# =============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All rights reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================
import argparse

from qti.aisw.accuracy_debugger.argparser.parser import Parser
from qti.aisw.accuracy_debugger.common_config import (
    ConverterInputArguments,
    NetRunnerInputArguments,
    QuantizerInputArguments,
    RemoteHostDetails,
)
from qti.aisw.tools.core.modules.api.definitions.common import BackendType, OpPackageIdentifier
from qti.aisw.tools.core.modules.context_bin_gen.context_bin_gen_module import GenerateConfig
from qti.aisw.tools.core.modules.converter.converter_module import (
    InputTensorConfig,
    OutputTensorConfig,
)
from qti.aisw.tools.core.utilities.devices.api.device_definitions import (
    DevicePlatformType,
    RemoteDeviceIdentifier,
)


class InferenceEngineParser(Parser):
    """This is a parser for Inference Engine tool."""

    def __init__(self, component: str = "inference_engine") -> None:
        """Initialize the parser by delegating to the base ``Parser`` constructor.

        Args:
            component (str): Component name forwarded unchanged to the base
                class. Defaults to "inference_engine".
        """
        super().__init__(component)

    def _initialize(self):
        """Create parser with inference engine tool specific arguments.

        Runs the base class initialization first, then registers the tool's
        command line options. Options are split across four dedicated argument
        groups created here (converter, quantizer, netrun, offline prepare) plus
        the ``self.required`` and ``self.optional`` groups, which are presumably
        created by the base ``Parser`` -- confirm against the base class.

        The values registered here are raw strings/lists; they are converted
        into their typed counterparts in ``_verify_and_update_parsed_args``.
        """
        super()._initialize()
        # Group creation order determines the section order in the --help output.
        converter_args = self._parser.add_argument_group("converter arguments")
        quantizer_args = self._parser.add_argument_group("quantizer arguments")
        net_run_args = self._parser.add_argument_group("netrun arguments")
        offline_prepare_args = self._parser.add_argument_group("offline prepare arguments")

        # ----------------------------------------------------------------------
        # Required arguments
        # ----------------------------------------------------------------------
        self.required.add_argument(
            "--input_model", type=str, required=True, help="Path to the source model/dlc/bin file"
        )

        # ----------------------------------------------------------------------
        # Converter arguments
        # ----------------------------------------------------------------------
        # nargs="+" with action="append" yields a list of token lists, one per
        # occurrence of the option on the command line.
        converter_args.add_argument(
            "--desired_input_shape",
            "--input_tensor",
            dest="desired_input_shape",
            nargs="+",
            action="append",
            required=False,
            default=None,
            help="The name,dimension,datatype and layout of all the input buffers to the network "
            "specified in the format [input_name comma-separated-dimensions data-type layout]. "
            "Dimension, datatype and layout are optional. "
            "For example: 'data' 1,224,224,3. Note that the quotes should always be included in "
            "order to handle special characters, spaces, etc. "
            "For multiple inputs, specify multiple --desired_input_shape on the command line like: "
            '--desired_input_shape "data1" 1,224,224,3 float32 '
            '--desired_input_shape "data2" 1,50,100,3 int64 ',
        )
        converter_args.add_argument(
            "--output_tensor",
            type=str,
            required=False,
            action="append",
            help="Name of the graph's specified output tensor(s).",
        )
        converter_args.add_argument(
            "--converter_float_bitwidth",
            type=int,
            required=False,
            default=32,
            choices=[32, 16],
            help="Use this option to convert the graph to the specified float \
                bitwidth, either 32 (default) or 16.",
        )
        converter_args.add_argument(
            "--float_bias_bitwidth",
            default=32,
            required=False,
            type=int,
            choices=[32, 16],
            help="Option to select the bitwidth to use for float bias tensor, either 32(default) \
                or 16",
        )
        converter_args.add_argument(
            "--quantization_overrides",
            required=False,
            default=None,
            type=str,
            help="Path to quantization overrides json file.",
        )
        converter_args.add_argument(
            "--onnx_define_symbol",
            default=None,
            nargs=2,
            action="append",
            required=False,
            metavar=("SYMBOL", "VALUE"),
            help="Option to override specific input dimension symbols.",
        )
        converter_args.add_argument(
            "--onnx_defer_loading",
            default=False,
            action="store_true",
            required=False,
            help="Option to have the model not load weights. "
            "If False, the model will be loaded eagerly.",
        )
        converter_args.add_argument(
            "--enable_framework_trace",
            default=False,
            action="store_true",
            required=False,
            help="Use this option to enable converter to trace the o/p tensor change information.",
        )
        converter_args.add_argument(
            "--op_package_config",
            default=[],
            required=False,
            nargs="+",
            type=str,
            help="Absolute paths to Qnn Op Package XML configuration file that "
            "contains user defined custom operations. "
            "Note: Only one of: {'op_package_config', 'package_name'} can be specified.",
        )
        # Received as one comma separated string; split into a list later.
        converter_args.add_argument(
            "--converter_op_package_lib",
            default=[],
            type=str,
            help="Absolute path to converter op package library compiled by the OpPackage "
            "generator. Must be separated by a comma for multiple package libraries. "
            "Note: Libraries must follow the same order as the xml files. "
            "E.g.1: --converter_op_package_lib absolute_path_to/libExample.so "
            "E.g.2: --converter_op_package_lib "
            "absolute_path_to/libExample1.so,absolute_path_to/libExample2.so",
        )
        converter_args.add_argument(
            "--package_name",
            default="",
            required=False,
            type=str,
            help="A global package name to be used for each node in the Model.cpp file. "
            "Defaults to Qnn header defined package name. "
            "Note: Only one of: {'op_package_config', 'package_name'} can be specified.",
        )

        # ----------------------------------------------------------------------
        # Quantizer arguments
        # ----------------------------------------------------------------------
        quantizer_args.add_argument(
            "--calibration_input_list",
            type=str,
            required=False,
            default=None,
            help="Path to the inputs list text file to run quantization(used with qairt-quantizer)",
        )
        quantizer_args.add_argument(
            "--bias_bitwidth",
            type=int,
            required=False,
            default=8,
            choices=[8, 32],
            help="Option to select the bitwidth to use when quantizing the bias. default 8",
        )
        quantizer_args.add_argument(
            "--act_bitwidth",
            type=int,
            required=False,
            default=8,
            choices=[8, 16],
            help="Option to select the bitwidth to use when quantizing the activations. default 8",
        )
        quantizer_args.add_argument(
            "--weights_bitwidth",
            type=int,
            required=False,
            default=8,
            choices=[8, 4],
            help="Option to select the bitwidth to use when quantizing the weights. default 8",
        )
        quantizer_args.add_argument(
            "--quantizer_float_bitwidth",
            type=int,
            required=False,
            default=32,
            choices=[32, 16],
            help="Use this option to select the bitwidth to use for float tensors, \
                either 32 (default) or 16.",
        )
        # type=str.lower makes the choice matching case-insensitive.
        quantizer_args.add_argument(
            "--act_quantizer_calibration",
            type=str.lower,
            required=False,
            default="min-max",
            choices=["min-max", "sqnr", "entropy", "mse", "percentile"],
            help="Specify which quantization calibration method to use for activations. "
            "Supported values: min-max (default), sqnr, entropy, mse, percentile. "
            "This option can be paired with --act_quantizer_schema to override the "
            "quantization schema to use for activations otherwise the default "
            "schema (asymmetric) will be used.",
        )
        quantizer_args.add_argument(
            "--param_quantizer_calibration",
            type=str.lower,
            required=False,
            default="min-max",
            choices=["min-max", "sqnr", "entropy", "mse", "percentile"],
            help="Specify which quantization calibration method to use for parameters. "
            "Supported values: min-max (default), sqnr, entropy, mse, percentile. "
            "This option can be paired with --param_quantizer_schema to override the "
            "quantization schema to use for parameters otherwise the default "
            "schema (asymmetric) will be used.",
        )
        quantizer_args.add_argument(
            "--act_quantizer_schema",
            type=str.lower,
            required=False,
            default="asymmetric",
            choices=["asymmetric", "symmetric", "unsignedsymmetric"],
            help="Specify which quantization schema to use for activations. \
                Note: Default is asymmetric.",
        )
        quantizer_args.add_argument(
            "--param_quantizer_schema",
            type=str.lower,
            required=False,
            default="asymmetric",
            choices=["asymmetric", "symmetric", "unsignedsymmetric"],
            help="Specify which quantization schema to use for parameters. \
                Note: Default is asymmetric.",
        )
        quantizer_args.add_argument(
            "--percentile_calibration_value",
            type=float,
            required=False,
            default=99.99,
            help="Value must lie between 90 and 100. Default is 99.99",
        )
        quantizer_args.add_argument(
            "--use_per_channel_quantization",
            action="store_true",
            default=False,
            help="Use per-channel quantization for convolution-based op weights. \
                Note: This will replace built-in model QAT encodings when used for a given weight.",
        )

        quantizer_args.add_argument(
            "--use_per_row_quantization",
            action="store_true",
            default=False,
            help="Use this option to enable rowwise quantization of Matmul and FullyConnected ops.",
        )

        quantizer_args.add_argument(
            "--float_fallback",
            action="store_true",
            default=False,
            help="Use this option to enable fallback to floating point (FP) instead of fixed "
            "point. This option can be paired with --quantizer_float_bitwidth to indicate the "
            "bitwidth for FP (by default 32). If this option is enabled, then input list must "
            "not be provided and --ignore_encodings must not be provided. "
            "The external quantization encodings (encoding file/FakeQuant encodings) "
            "might be missing quantization parameters for some interim tensors. "
            "First it will try to fill the gaps by propagating across math-invariant "
            "functions. If the quantization parameters are still missing, "
            "then it will apply fallback to nodes to floating point.",
        )
        quantizer_args.add_argument(
            "--quantization_algorithms",
            required=False,
            default=[],
            type=str,
            nargs="+",
            help="Use this option to select quantization algorithms. Usage is: \
                --quantization_algorithms <algo_name1> ... ",
        )

        # Received as one whitespace separated string; split into a list later.
        quantizer_args.add_argument(
            "--restrict_quantization_steps",
            required=False,
            default=[],
            type=str,
            help="Specifies the number of steps to use for computing "
            'quantization encodings. E.g. --restrict_quantization_steps "-0x80 0x7F" indicates '
            "an example 8 bit range.",
        )
        quantizer_args.add_argument(
            "--dump_encodings_json",
            required=False,
            default=False,
            action="store_true",
            help="Dump encoding of all the tensors in a json file",
        )

        quantizer_args.add_argument(
            "--ignore_encodings",
            required=False,
            default=False,
            action="store_true",
            help="Use only quantizer generated encodings, "
            "ignoring any user or model provided encodings.",
        )
        # Received as one comma separated string; split into a list later.
        quantizer_args.add_argument(
            "--op_package_lib",
            type=str,
            default=[],
            required=False,
            help="Use this argument to pass an op package library for quantization. "
            "Must be in the form <op_package_lib_path:interfaceProviderName> and "
            "be separated by a comma for multiple package libs",
        )

        # ----------------------------------------------------------------------
        # Net-run arguments
        # ----------------------------------------------------------------------
        net_run_args.add_argument(
            "--perf_profile",
            type=str.lower,
            required=False,
            default="balanced",
            choices=[
                "low_balanced",
                "balanced",
                "default",
                "high_performance",
                "sustained_high_performance",
                "burst",
                "low_power_saver",
                "power_saver",
                "high_power_saver",
                "extreme_power_saver",
                "system_settings",
            ],
            help='Specifies performance profile to set. Valid settings are "low_balanced" ,'
            '"balanced", "default", "high_performance" ,"sustained_high_performance", '
            '"burst", "low_power_saver", "power_saver", "high_power_saver", '
            '"extreme_power_saver", and "system_settings". '
            "Note: perf_profile argument is now deprecated for "
            "HTP backend, user can specify performance profile through "
            "backend extension config now.",
        )
        net_run_args.add_argument(
            "--profiling_level",
            type=str.lower,
            required=False,
            default=None,
            help="Enables profiling and sets its level. "
            'For QNN executor, valid settings are "basic", "detailed" and "client". '
            "Default is detailed.",
        )
        net_run_args.add_argument(
            "--input_list",
            type=str,
            required=False,
            help="Path to the input list text file to run inference(used with net-run). "
            "Note: When having multiple entries in text file, in order to save "
            "memory and time.",
        )
        net_run_args.add_argument(
            "--netrun_backend_extension_config",
            type=str,
            required=False,
            default=None,
            help="Path to config to be used with qnn-net-run",
        )

        # ----------------------------------------------------------------------
        # Offline prepare arguments
        # ----------------------------------------------------------------------
        offline_prepare_args.add_argument(
            "--offline_prepare_backend_extension_config",
            type=str,
            required=False,
            default=None,
            help="Path to config to be used with qnn-context-binary-generator.",
        )

        # ----------------------------------------------------------------------
        # General optional arguments
        # ----------------------------------------------------------------------
        # type=str.upper normalizes the value before it is matched against the
        # BackendType values.
        self.optional.add_argument(
            "--backend",
            type=str.upper,
            required=False,
            choices=[backend.value for backend in BackendType],
            default=None,
            help="Backend type for inference to be run",
        )
        self.optional.add_argument(
            "--platform",
            type=str,
            required=False,
            choices=[platform.value for platform in DevicePlatformType],
            default=None,
            help="The type of device platform to be used for inference",
        )
        self.optional.add_argument(
            "--offline_prepare",
            action="store_true",
            required=False,
            default=False,
            help="Boolean to indicate offline prepare of the graph",
        )
        self.optional.add_argument(
            "--working_directory",
            type=str,
            required=False,
            default=None,
            help="Path to the directory to store the output result",
        )
        self.optional.add_argument(
            "--device_id",
            type=str,
            required=False,
            default=None,
            help="The serial number of the device to use. If not available, "
            "the first in a list of queried devices will be used for inference.",
        )
        self.optional.add_argument(
            "--log_level",
            type=str.upper,
            required=False,
            default="INFO",
            choices=["ERROR", "WARN", "INFO", "DEBUG", "VERBOSE"],
            help="Logging level to use. Default is INFO.",
        )
        # Received as one comma separated string of path:provider pairs; parsed
        # into identifier objects later.
        self.optional.add_argument(
            "--op_packages",
            type=str,
            required=False,
            default=[],
            help="Provide a comma separated list of op package and interface providers to "
            "register during graph preparation. "
            "Usage: op_package_path:interface_provider[,op_package_path:interface_provider...]",
        )
        self.optional.add_argument(
            "--soc_model",
            type=str,
            required=False,
            default="",
            help="Option to specify the SOC on which the model needs to run. "
            "This can be found from SOC info of the device and it starts with strings "
            "such as SDM, SM, QCS, IPQ, SA, QC, SC, SXR, SSG, STP, QRB, or AIC.",
        )

    def _get_converter_args(self, args: argparse.Namespace) -> "ConverterInputArguments":
        """Normalize the converter related arguments and build the converter arguments.

        The following attributes of ``args`` are rewritten in place when they
        hold a truthy value:

        * ``desired_input_shape``: list of token lists -> list of
          ``InputTensorConfig``.
        * ``output_tensor``: list of names -> list of ``OutputTensorConfig``.
        * ``converter_op_package_lib``: comma separated string -> list of str.
        * ``onnx_define_symbol``: list of [symbol, value] -> list of
          ``(symbol, int(value))`` tuples.

        Args:
            args (argparse.Namespace): The parsed arguments.

        Returns:
            ConverterInputArguments: The converter arguments.

        Raises:
            ValueError: If an ``onnx_define_symbol`` value cannot be converted
                with ``int()``.
        """
        # Each entry is [name, shape?, datatype?, layout?]. Missing optional
        # tokens fall back to "" / "float32" / None; tokens beyond the fourth
        # are ignored.
        if args.desired_input_shape:
            args.desired_input_shape = [
                InputTensorConfig(
                    name=spec[0],
                    source_model_input_shape=spec[1] if len(spec) > 1 else "",
                    source_model_input_datatype=spec[2] if len(spec) > 2 else "float32",
                    source_model_input_layout=spec[3] if len(spec) > 3 else None,
                )
                for spec in args.desired_input_shape
            ]

        if args.output_tensor:
            args.output_tensor = [OutputTensorConfig(name=name) for name in args.output_tensor]

        # The libraries arrive as a single comma separated string.
        if args.converter_op_package_lib:
            args.converter_op_package_lib = args.converter_op_package_lib.split(",")

        # Symbol values are parsed as strings on the command line; convert to int.
        if args.onnx_define_symbol:
            args.onnx_define_symbol = [
                (entry[0], int(entry[1])) for entry in args.onnx_define_symbol
            ]

        # Collect every argument that is forwarded to the converter.
        return ConverterInputArguments(
            input_tensors=args.desired_input_shape,
            output_tensors=args.output_tensor,
            float_bitwidth=args.converter_float_bitwidth,
            float_bias_bitwidth=args.float_bias_bitwidth,
            quantization_overrides=args.quantization_overrides,
            onnx_define_symbol=args.onnx_define_symbol,
            onnx_defer_loading=args.onnx_defer_loading,
            enable_framework_trace=args.enable_framework_trace,
            op_package_config=args.op_package_config,
            converter_op_package_lib=args.converter_op_package_lib,
            package_name=args.package_name,
        )

    def _get_quantizer_args(self, args: argparse.Namespace) -> argparse.Namespace:
        """Validate and build quantizer arguments from the parsed arguments.

        Args:
            args (argparse.Namespace): The parsed arguments.

        Returns:
            QuantizerInputArguments: The quantizer arguments.
        """
        # In case of calibration input list or float_fallback passed, create quantizer_args
        # to be passed to the quantizer
        if args.calibration_input_list or args.float_fallback:
            # Update restrict_quantizaiton_steps argument
            if args.restrict_quantization_steps:
                args.restrict_quantization_steps = args.restrict_quantization_steps.split()
            # Update op_package_lib argument
            if args.op_package_lib:
                args.op_package_lib = args.op_package_lib.split(",")
            quantizer_args = QuantizerInputArguments(
                input_list=args.calibration_input_list,
                bias_bitwidth=args.bias_bitwidth,
                act_bitwidth=args.act_bitwidth,
                weights_bitwidth=args.weights_bitwidth,
                float_bitwidth=args.quantizer_float_bitwidth,
                act_quantizer_calibration=args.act_quantizer_calibration,
                param_quantizer_calibration=args.param_quantizer_calibration,
                act_quantizer_schema=args.act_quantizer_schema,
                param_quantizer_schema=args.param_quantizer_schema,
                percentile_calibration_value=args.percentile_calibration_value,
                use_per_channel_quantization=args.use_per_channel_quantization,
                use_per_row_quantization=args.use_per_row_quantization,
                float_fallback=args.float_fallback,
                algorithms=args.quantization_algorithms,
                restrict_quantization_steps=args.restrict_quantization_steps,
                dump_encoding_json=args.dump_encodings_json,
                ignore_encodings=args.ignore_encodings,
                op_package_lib=args.op_package_lib,
            )
            return quantizer_args
        else:
            return None

    def _get_context_bin_args(self, args: argparse.Namespace) -> argparse.Namespace:
        """Validate and build context bin arguments from the parsed arguments.

        Args:
            args (argparse.Namespace): The parsed arguments.

        Returns:
            GenerateConfig: The context bin arguments.
        """
        # In case of offline_prepare, create context_bin_args to pass to context_bin_gen
        if args.offline_prepare:
            context_bin_args = GenerateConfig(
                profiling_level=args.profiling_level,
                enable_intermediate_outputs=True,
                op_packages=args.op_packages,
            )
            return context_bin_args
        else:
            return None

    def _get_net_run_args(self, args: argparse.Namespace) -> argparse.Namespace:
        """Validate and build net run arguments from the parsed arguments.

        Args:
            args (argparse.Namespace): The parsed arguments.

        Returns:
            NetRunnerInputArguments: The net run arguments.
        """
        # In case of input_list provided, create net_run_args to pass to net_runner
        if args.input_list:
            net_run_args = NetRunnerInputArguments(
                perf_profile=args.perf_profile,
                profiling_level=args.profiling_level,
                debug=True,
                op_packages=args.op_packages,
            )
            return net_run_args
        else:
            return None

    def _verify_and_update_parsed_args(self, args: argparse.Namespace) -> argparse.Namespace:
        """Validate the parsed arguments and attach the derived argument objects.

        After the base class verification, the following attributes are set on
        ``args``: ``converter_args``, ``quantizer_args``, ``context_bin_args``,
        ``net_run_args`` and ``remote_host_details``. ``op_packages``,
        ``backend`` and ``platform`` are replaced by their typed counterparts
        when they hold a truthy value.

        Args:
            args (argparse.Namespace): parsed arguments

        Returns:
            argparse.Namespace: Verified and updated arguments

        Raises:
            ValueError: If an ``op_packages`` entry does not contain a ':'.
        """
        args = super()._verify_and_update_parsed_args(args)

        args.converter_args = self._get_converter_args(args)
        args.quantizer_args = self._get_quantizer_args(args)

        # op_packages must be converted before the context_bin and net_run
        # arguments are built, since both receive the converted list.
        if args.op_packages:
            identifiers = []
            for op_package in args.op_packages.split(","):
                # Split on the last ':' so the package path itself may contain ':'.
                path_part, separator, provider_part = op_package.rpartition(":")
                if not separator:
                    raise ValueError(
                        f"Invalid op_package format: {op_package}. Expected format: 'package_path:interface_provider'"
                    )
                identifiers.append(
                    OpPackageIdentifier(package_path=path_part, interface_provider=provider_part)
                )
            args.op_packages = identifiers

        args.context_bin_args = self._get_context_bin_args(args)

        runner_args = self._get_net_run_args(args)
        args.net_run_args = runner_args
        # With offline prepare, intermediate outputs are already enabled through
        # the context_bin_gen args, so debug is switched off for net_run.
        if runner_args and args.offline_prepare:
            runner_args.debug = False

        # Wrap the device serial id into the remote host details.
        device_identifier = RemoteDeviceIdentifier(serial_id=args.device_id)
        args.remote_host_details = RemoteHostDetails(identifier=device_identifier)

        # Replace the raw backend string with its BackendType member.
        if args.backend:
            args.backend = BackendType(args.backend)

        # Replace the raw platform string with its DevicePlatformType member.
        if args.platform:
            args.platform = DevicePlatformType(args.platform)

        return args
