#!/usr/bin/env python
# -*- coding: UTF-8 -*-
#
# Copyright 2022-2024 NXP
#
# SPDX-License-Identifier: BSD-3-Clause

"""CLI application for various cryptographic operations."""

import hashlib
import logging
import os
import sys
from typing import Any, Dict, List, Optional, Union

import click
from click_option_group import RequiredMutuallyExclusiveOptionGroup, optgroup

from spsdk.apps.nxpcertgen import main as cert_gen_main
from spsdk.apps.utils import spsdk_logger
from spsdk.apps.utils.common_cli_options import (
    CommandsTreeGroup,
    spsdk_apps_common_options,
    spsdk_family_option,
    spsdk_output_option,
)
from spsdk.apps.utils.utils import SPSDKAppError, catch_spsdk_error
from spsdk.crypto.hash import EnumHashAlgorithm
from spsdk.crypto.keys import (
    ECDSASignature,
    PrivateKey,
    PrivateKeyEcc,
    PrivateKeyRsa,
    PrivateKeySM2,
    PublicKey,
    PublicKeyEcc,
    SPSDKKeyPassphraseMissing,
    get_ecc_curve,
    get_supported_keys_generators,
    prompt_for_passphrase,
)
from spsdk.crypto.signature_provider import get_signature_provider
from spsdk.crypto.types import SPSDKEncoding
from spsdk.crypto.utils import extract_public_key
from spsdk.exceptions import SPSDKError, SPSDKIndexError, SPSDKSyntaxError, SPSDKValueError
from spsdk.utils.crypto.rot import Rot
from spsdk.utils.misc import Endianness, load_binary, write_file

logger = logging.getLogger(name=__name__)


@click.group(cls=CommandsTreeGroup, name="nxpcrypto", no_args_is_help=True)
@spsdk_apps_common_options
def main(log_level: int) -> None:
    """Collection of utilities for cryptographic operations."""
    # Configure SPSDK logging at the level chosen via the common CLI options.
    spsdk_logger.install(level=log_level)


@main.command(name="digest", no_args_is_help=True)
@click.option(
    "-h",
    "--hash",
    "hash_name",
    required=True,
    # Sorted so the help listing is stable (``algorithms_available`` is a set).
    # Extendable-output functions (SHAKE) are excluded: their ``hexdigest()``
    # requires an explicit output length, so they cannot be used here.
    type=click.Choice(
        sorted(alg for alg in hashlib.algorithms_available if "shake" not in alg.lower()),
        case_sensitive=False,
    ),
    help="Name of a hash to use.",
)
@click.option(
    "-i",
    "--input-file",
    type=click.Path(exists=True, dir_okay=False),
    required=True,
    help="Path to a file to digest.",
)
@click.option(
    "-c",
    "--compare",
    metavar="PATH | DIGEST",
    help="Reference digest to compare. It may be directly on the command line or fetched from a file.",
)
def digest(hash_name: str, input_file: str, compare: str) -> None:
    """Computes digest/hash of the given file."""
    data = load_binary(input_file)
    hasher = hashlib.new(hash_name.lower())
    hasher.update(data)
    hexdigest = hasher.hexdigest()
    click.echo(f"{hash_name.upper()}({input_file})= {hexdigest}")
    if not compare:
        return

    if os.path.isfile(compare):
        # Only the first line of the reference file is considered.
        # "utf-8-sig" also tolerates a BOM written by some Windows editors.
        with open(compare, encoding="utf-8-sig") as ref_file:
            first_line = ref_file.readline().strip()
        # OpenSSL-style "ALG(file)= <hex>" keeps the text after the last "=";
        # without "=", the whole first line is taken as the digest.
        ref = first_line.rpartition("=")[2].strip()
    else:
        ref = compare

    if ref.lower() == hexdigest.lower():
        click.echo("Digests are the same.")
    else:
        raise SPSDKAppError("Digests differ!")


@main.group(name="rot", no_args_is_help=True)
def rot_group() -> None:
    """Group of RoT commands."""


def _store_rot_result(data: bytes, output: Optional[str]) -> None:
    """Write ``data`` to ``output`` (binary) and report it, when an output path is given."""
    if not output:
        return
    write_file(data, path=output, mode="wb")
    click.echo(f"Result has been stored in: {output}")


@rot_group.command(name="export", no_args_is_help=True)
@spsdk_family_option(families=Rot.get_supported_families())
@click.option(
    "-k",
    "--key",
    type=click.Path(exists=True, dir_okay=False),
    multiple=True,
    help="Path to one or multiple keys or certificates.",
)
@click.option(
    "-p",
    "--password",
    help="Password when using encrypted private keys.",
)
@spsdk_output_option(required=False)
def export(family: str, key: List[str], password: str, output: str) -> None:
    """Export RoT table."""
    rot_table = Rot(family, keys_or_certs=key, password=password).export()
    _store_rot_result(rot_table, output)
    click.echo(f"RoT table: {rot_table.hex()}")


@rot_group.command(name="calculate-hash", no_args_is_help=True)
@spsdk_family_option(families=Rot.get_supported_families())
@click.option(
    "-k",
    "--key",
    type=click.Path(exists=True, dir_okay=False),
    multiple=True,
    help="Path to one or multiple keys or certificates.",
)
@click.option(
    "-p",
    "--password",
    help="Password when using encrypted private keys.",
)
@spsdk_output_option(required=False)
def calculate_hash(family: str, key: List[str], password: str, output: str) -> None:
    """Calculate RoT hash."""
    rot_digest = Rot(family, keys_or_certs=key, password=password).calculate_hash()
    _store_rot_result(rot_digest, output)
    click.echo(f"RoT hash: {rot_digest.hex()}")


@main.group(name="cert", no_args_is_help=True)
def cert() -> None:
    """Group of command for working with x509 certificates."""


# Expose selected nxpcertgen commands under `nxpcrypto cert`, keeping their names.
for _cert_cmd_name in ("generate", "get-template", "verify"):
    cert.add_command(cert_gen_main.commands[_cert_cmd_name], name=_cert_cmd_name)
del _cert_cmd_name


@main.group(no_args_is_help=True, name="key")
def key_group() -> None:
    """Group of commands for working with asymmetric keys."""
    # Container only: sub-commands are attached with @key_group.command.


@key_group.command(name="generate", no_args_is_help=True)
@click.option(
    "-k",
    "--key-type",
    # Required: the body dereferences key_type unconditionally.
    required=True,
    type=click.Choice(list(get_supported_keys_generators()), case_sensitive=False),
    metavar="KEY-TYPE",
    help=f"""\b
        Set of the supported key types.

        Note: NXP DAT protocol is using encryption keys by this table:

        NXP Protocol Version                Key Type
        1.0                                 RSA 2048
        1.1                                 RSA 4096
        2.0                                 SECP256R1
        2.1                                 SECP384R1
        2.2                                 SECP521R1

        All possible options:
        {", ".join(list(get_supported_keys_generators()))}.
        """,
)
@click.option(
    "--password",
    "password",
    metavar="PASSWORD",
    help="Password with which the output file will be encrypted. "
    "If not provided, the output will be unencrypted.",
)
@spsdk_output_option(force=True)
@click.option(
    "-e",
    "--encoding",
    type=click.Choice(list(SPSDKEncoding.all())),
    default="PEM",
    help="Encoding of the generated key files.",
)
def key_generate(key_type: str, output: str, password: str, encoding: str) -> None:
    """NXP Key Generator Tool."""
    key_param = key_type.lower().strip()
    encoding_enum = SPSDKEncoding.all()[encoding.upper().strip()]

    # The public key is stored next to the private key with a ".pub" extension.
    pub_key_path = os.path.splitext(output)[0] + ".pub"
    # If the private key path already ends with ".pub", both keys would target
    # the same file and the private key would be silently overwritten.
    if os.path.normcase(os.path.abspath(pub_key_path)) == os.path.normcase(
        os.path.abspath(output)
    ):
        raise SPSDKAppError(
            f"Private key output '{output}' collides with the public key path "
            f"'{pub_key_path}'. Use a different extension for the private key."
        )

    func, params = get_supported_keys_generators()[key_param]

    private_key = func(**params)
    public_key = private_key.get_public_key()

    # An empty password means "no encryption".
    private_key.save(output, password or None, encoding=encoding_enum)
    public_key.save(pub_key_path, encoding=encoding_enum)

    click.echo(f"The key pair has been created: {pub_key_path}, {output}")


@key_group.command(name="convert", no_args_is_help=True)
@click.option(
    "-e",
    "--encoding",
    type=click.Choice(["PEM", "DER", "RAW"], case_sensitive=False),
    required=True,
    help="Desired output format.",
)
@click.option(
    "-i",
    "--input-file",
    type=click.Path(exists=True, dir_okay=False),
    required=True,
    help="Path to key file to convert.",
)
@spsdk_output_option()
@click.option(
    "-p",
    "--puk",
    is_flag=True,
    default=False,
    help="Extract public key instead of converting private key.",
)
def convert(encoding: str, input_file: str, output: str, puk: bool) -> None:
    """Convert Asymmetric key into various formats."""
    key_data = load_binary(input_file)
    key = reconstruct_key(key_data=key_data)
    # Any private key can yield its public key; checking the base class ensures
    # --puk is never silently ignored (which would export the private key).
    if puk and isinstance(key, PrivateKey):
        key = key.get_public_key()

    # Normalize, so the comparison does not depend on how click reports the choice.
    encoding_name = encoding.upper()
    if encoding_name in ("PEM", "DER"):
        encoding_type = SPSDKEncoding.PEM if encoding_name == "PEM" else SPSDKEncoding.DER
        out_data = key.export(encoding=encoding_type)
    elif encoding_name == "RAW":
        if not isinstance(key, (PrivateKeyEcc, PublicKeyEcc)):
            raise SPSDKError("Converting to RAW is supported only for ECC keys")
        # Round the bit size up to whole bytes: P-521 needs 66 bytes, and floor
        # division (65) makes int.to_bytes overflow for large values.
        key_size = (key.key_size + 7) // 8
        if isinstance(key, PrivateKeyEcc):
            out_data = key.d.to_bytes(key_size, byteorder=Endianness.BIG.value)
        else:
            x = key.x.to_bytes(key_size, byteorder=Endianness.BIG.value)
            y = key.y.to_bytes(key_size, byteorder=Endianness.BIG.value)
            out_data = x + y
    else:
        raise SPSDKAppError("Desired output encoding format must be specified by -e/--encoding")

    write_file(out_data, output, mode="wb")


@key_group.command(name="verify", no_args_is_help=True)
@click.option(
    "-k1",
    "--key1",
    required=True,
    type=click.Path(exists=True, dir_okay=False),
    help="Path to key to verify.",
)
@click.option(
    "-k2",
    "--key2",
    required=True,
    type=click.Path(exists=True, dir_okay=False),
    help="Path to key for verification.",
)
def key_verify(key1: str, key2: str) -> None:
    """Check whether provided keys form a key pair or represent the same key.

    The key could be private key, public key, or certificate. All combinations are allowed.
    In case of certificates, the public key within certificate is considered.
    To verify certificate signature use `nxpcrypto cert verify`.
    """
    # Both inputs are reduced to their public keys, so private/public/certificate
    # inputs can be compared in any combination.
    first_public = extract_public_key(key1)
    second_public = extract_public_key(key2)
    if not first_public == second_public:
        raise SPSDKAppError("Keys are NOT a valid pair!")
    click.echo("Keys match.")


def reconstruct_key(
    key_data: bytes,
) -> Union[PrivateKey, PublicKey]:
    """Reconstruct Crypto key from PEM,DER or RAW data.

    Structured encodings are tried first (private key, then public key). If both
    parsers reject the data, it is treated as raw big-endian ECC material:
    up to 48 bytes is a private scalar, 64 or 96 bytes are concatenated X||Y
    public coordinates.

    :param key_data: Encoded key bytes.
    :raises SPSDKError: When the data cannot be interpreted as a key.
    :return: The reconstructed private or public key.
    """
    for parse_key in (PrivateKey.parse, PublicKey.parse):
        try:
            return parse_key(key_data)
        except SPSDKError:
            continue

    # Raw ECC fallback; the curve is chosen from the data length.
    data_len = len(key_data)
    ecc_curve = get_ecc_curve(data_len)
    big_endian = Endianness.BIG.value

    if data_len <= 48:
        # pylint: disable=invalid-name   # 'd' is regular name for private key number
        d = int.from_bytes(key_data, byteorder=big_endian)
        return PrivateKeyEcc.recreate(d=d, curve=ecc_curve)

    if data_len not in (64, 96):
        raise SPSDKError(f"Can't recognize key with length {data_len}")

    half = data_len // 2
    coord_x = int.from_bytes(key_data[:half], byteorder=big_endian)
    coord_y = int.from_bytes(key_data[half:], byteorder=big_endian)
    return PublicKeyEcc.recreate(coor_x=coord_x, coor_y=coord_y, curve=ecc_curve)


@main.group(no_args_is_help=True, name="signature")
def signature_group() -> None:
    """Group of commands for working with signature."""
    # Container only: sub-commands are attached with @signature_group.command.


def cut_off_data_regions(data: bytes, regions: List[str]) -> bytes:
    """Get the data chunks from the input data.

    The regions are individual strings written in python-like syntax, for example '[:0x10]'.
    Each region is applied to ``data`` and the selected chunks are concatenated
    in the given order.

    :param data: Input data to cut the regions from.
    :param regions: Region expressions such as ``[1]``, ``[:20]``, ``[0x10:0x20]``
        or ``[-20:]``. When empty, ``data`` is returned unchanged.
    :raises SPSDKSyntaxError: When a region expression is invalid or does not select bytes.
    :raises SPSDKIndexError: When a single-byte index lies outside the data.
    :return: Concatenation of all selected chunks.
    """
    if not regions:
        return data
    data_chunks: List[bytes] = []
    for region in regions:
        try:
            # SECURITY NOTE(review): ``region`` comes from the command line and is
            # evaluated as Python code. Only pass trusted input; a literal-only
            # slice parser would be a safer replacement.
            # pylint: disable=eval-used
            data_chunk = eval(f"data{region}")
        except (SyntaxError, NameError, TypeError, ValueError, AttributeError) as exc:
            raise SPSDKSyntaxError(f"Invalid region expression '{region}'") from exc
        except IndexError as exc:
            raise SPSDKIndexError(
                f"The region expression '{region}' is outside the data length {len(data)}"
            ) from exc
        # A single index such as [0] yields an int; convert it to one byte.
        if isinstance(data_chunk, int):
            data_chunk = bytes([data_chunk])
        # Explicit check instead of ``assert`` (which is stripped under ``python -O``).
        if not isinstance(data_chunk, (bytes, bytearray)):
            raise SPSDKSyntaxError(f"Invalid region expression '{region}'")
        data_chunks.append(bytes(data_chunk))
    return b"".join(data_chunks)


@signature_group.command(name="create", no_args_is_help=True)
@optgroup.group("Signee type", cls=RequiredMutuallyExclusiveOptionGroup)
@optgroup.option(
    "-k",
    "--private-key",
    type=click.Path(exists=True, dir_okay=False),
    help=f"""\b
        Path to private key to be used for signing.
        Supported private keys:
        {", ".join(list(get_supported_keys_generators()))}.
        """,
)
@optgroup.option(
    "-sp",
    "--signature-provider",
    type=click.STRING,
    help="Signature provider configuration string.",
)
@click.option(
    "-p",
    "--password",
    type=click.STRING,
    help="Password when using encrypted private keys.",
)
@click.option(
    "-a",
    "--algorithm",
    type=click.Choice(EnumHashAlgorithm.labels()),
    help="Hash algorithm used when signing the message.",
)
@click.option(
    "-i",
    "--input-file",
    required=True,
    type=click.Path(exists=False, dir_okay=False),
    help="Path to file containing binary data to be signed.",
)
@click.option(
    "-e",
    "--encoding",
    type=click.Choice([SPSDKEncoding.NXP.value, SPSDKEncoding.DER.value]),
    default=SPSDKEncoding.DER.value,
    help="Encoding of output signature. This option is applicable only when signing with ECC keys.",
)
@click.option(
    "-pp",
    "--pss-padding",
    is_flag=True,
    default=False,
    help="Use PSS padding in case of RSA",
)
@click.option(
    "-r",
    "--regions",
    type=click.STRING,
    multiple=True,
    help="""\b
        Region(s) of data that will be signed. Multiple regions can be specified.

        Format of region option is similar to Python's list indices syntax:

        +--------------+--------------------------+
        | [1]          | Byte with index 1        |
        +--------------+--------------------------+
        | [:20]        | First 20 bytes           |
        +--------------+--------------------------+
        | [0x10:0x20]  | Between 0x10 and 0x20    |
        +--------------+--------------------------+
        | [-20:]       | Last 20 bytes            |
        +--------------+--------------------------+
        """,
)
@spsdk_output_option(force=True)
def signature_create(
    private_key: str,
    signature_provider: str,
    password: str,
    algorithm: str,
    input_file: str,
    encoding: str,
    pss_padding: bool,
    regions: List[str],
    output: str,
) -> None:
    """Sign the data with given private key."""
    # The signing itself lives in signature_create_command; this wrapper only
    # maps CLI options onto it and writes the result.
    signature = signature_create_command(
        private_key,
        signature_provider,
        password,
        algorithm,
        input_file,
        encoding,
        pss_padding,
        regions,
    )
    write_file(signature, output, mode="wb")
    click.echo(f"The data have been signed. Signature saved to: {output}")


def signature_create_command(
    private_key: Optional[str],
    signature_provider_cfg: Optional[str],
    password: Optional[str],
    algorithm: Optional[str],
    input_file: str,
    encoding_str: str,
    pss_padding: bool,
    regions: List[str],
) -> bytes:
    """Sign the data with given private key.

    The signature provider is used when ``signature_provider_cfg`` is set;
    otherwise ``private_key`` is loaded and used directly.

    :param private_key: Path to the private key (used when no signature provider is given).
    :param signature_provider_cfg: Signature provider configuration string.
    :param password: Password for an encrypted private key; when the key needs one
        and none is given, the user is prompted.
    :param algorithm: Hash algorithm label; not applied when using a signature provider.
    :param input_file: Path to the data to sign (loaded with search path ``.``).
    :param encoding_str: Output encoding for ECC signatures (e.g. ``"DER"`` or ``"NXP"``).
    :param pss_padding: Request PSS padding (RSA).
    :param regions: Optional region expressions selecting the data to sign.
    :raises SPSDKAppError: When neither a private key nor a signature provider is given.
    :return: Signature bytes; ECC signatures are re-encoded to ``encoding_str``.
    """
    data = load_binary(input_file, search_paths=["."])
    data = cut_off_data_regions(data, regions)
    hash_alg = EnumHashAlgorithm.from_label(algorithm) if algorithm else None
    if signature_provider_cfg:
        signature_provider = get_signature_provider(
            signature_provider_cfg, search_paths=["."], pss_padding=pss_padding
        )
        if hash_alg:
            logger.warning("Hash algorithm was not applied when using signature provider.")
        if password:
            logger.warning("Password was not applied when using signature provider.")
        signature = signature_provider.get_signature(data)
    else:
        # Explicit check instead of ``assert``: this function is also callable
        # outside the CLI, where the mutually-exclusive option group doesn't apply.
        if not private_key:
            raise SPSDKAppError(
                "Either a private key or a signature provider configuration must be specified."
            )
        try:
            prv_key = PrivateKey.load(private_key, password)
        except SPSDKKeyPassphraseMissing:
            prv_key = PrivateKey.load(private_key, prompt_for_passphrase())
        extra_params: Dict[str, Any] = {"pss_padding": pss_padding}
        if hash_alg:
            if not isinstance(prv_key, PrivateKeySM2):
                extra_params["algorithm"] = hash_alg
            elif hash_alg != EnumHashAlgorithm.SM3:
                logger.warning("Only SM3 hash algorithm is supported for OSCCA")
        signature = prv_key.sign(data, **extra_params)

    encoding = SPSDKEncoding.all()[encoding_str.upper().strip()]
    try:
        ecc_signature = ECDSASignature.parse(signature)
    except SPSDKValueError:
        # Not an ECC signature: return it as-is. Warn only if the user explicitly
        # passed --encoding on the command line. ``silent=True`` avoids a
        # RuntimeError when this function runs outside a click context.
        ctx = click.get_current_context(silent=True)
        parameter_source = ctx.get_parameter_source("encoding") if ctx else None
        if parameter_source is not None and parameter_source.name == "COMMANDLINE":
            logger.warning("Signature encoding is supported only for ECC keys.")
        return signature
    return ecc_signature.export(encoding=encoding)


@signature_group.command(name="verify", no_args_is_help=True)
@click.option(
    "-k",
    "--public-key",
    required=True,
    type=click.Path(exists=True, dir_okay=False),
    help=f"""\b
        Path to public key to be used for verification.

        Supported public keys:
        {", ".join(list(get_supported_keys_generators()))}.
        """,
)
@click.option(
    "-a",
    "--algorithm",
    type=click.Choice(EnumHashAlgorithm.labels()),
    help="Hash algorithm used when signing the message. If not set, default algorithm will be used.",
)
@click.option(
    "-i",
    "--input-file",
    required=True,
    type=click.Path(exists=False, dir_okay=False),
    help="Path to file containing original binary data.",
)
@click.option(
    "-s",
    "--signature",
    required=True,
    type=click.Path(exists=True, dir_okay=False),
    help="Path to file containing data signature.",
)
@click.option(
    "-pp",
    "--pss-padding",
    is_flag=True,
    default=False,
    help="Indicate whether the signature uses PSS padding in case of RSA",
)
@click.option(
    "-r",
    "--regions",
    type=click.STRING,
    multiple=True,
    help="""\b
        Region(s) of data that were signed. Multiple regions can be specified.

        Format of region option is similar to Python's list indices syntax:

        +--------------+--------------------------+
        | [1]          | Byte with index 1        |
        +--------------+--------------------------+
        | [:20]        | First 20 bytes           |
        +--------------+--------------------------+
        | [0x10:0x20]  | Between 0x10 and 0x20    |
        +--------------+--------------------------+
        | [-20:]       | Last 20 bytes            |
        +--------------+--------------------------+
        """,
)
def signature_verify(
    public_key: str,
    algorithm: Optional[str],
    input_file: str,
    signature: str,
    pss_padding: bool,
    regions: List[str],
) -> None:
    """Verify the given signature with public key."""
    # Verification lives in signature_verify_command; this wrapper only reports it.
    result = signature_verify_command(
        public_key, algorithm, input_file, signature, pss_padding, regions
    )
    click.echo(f"Signature {'IS' if result else 'IS NOT'} matching the public key.")


def signature_verify_command(
    public_key: str,
    algorithm: Optional[str],
    input_file: str,
    signature: str,
    pss_padding: bool,
    regions: List[str],
) -> bool:
    """Check a detached signature of (regions of) a file against a public key.

    :param public_key: Path to the public key.
    :param algorithm: Optional hash algorithm label; the key's default is used otherwise.
    :param input_file: Path to the original data.
    :param signature: Path to the signature file.
    :param pss_padding: Whether the signature uses PSS padding (RSA).
    :param regions: Optional region expressions selecting the signed data.
    :return: Result of ``verify_signature`` on the loaded public key.
    """
    verifier = PublicKey.load(public_key)
    verify_kwargs: Dict[str, Any] = {"pss_padding": pss_padding}
    if algorithm:
        verify_kwargs["algorithm"] = EnumHashAlgorithm.from_label(algorithm)
    signature_bin = load_binary(signature)
    signed_data = cut_off_data_regions(load_binary(input_file), regions)
    return verifier.verify_signature(signature_bin, signed_data, **verify_kwargs)


@catch_spsdk_error
def safe_main() -> None:
    """Invoke the ``main`` CLI group and pass its return value to ``sys.exit``."""
    exit_code = main()  # pylint: disable=no-value-for-parameter
    sys.exit(exit_code)


if __name__ == "__main__":
    safe_main()
