# =============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All rights reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================
import argparse
from typing import Optional

from qti.aisw.accuracy_debugger.argparser.framework_runner_parser import FrameworkRunnerParser
from qti.aisw.accuracy_debugger.argparser.inference_engine_parser import InferenceEngineParser
from qti.aisw.accuracy_debugger.common_config import NetRunnerInputArguments
from qti.aisw.accuracy_debugger.model_snooper_module import validate_quantization_params
from qti.aisw.accuracy_debugger.utils.constants import (
    Algorithm,
    supported_backends,
    supported_platforms,
)
from qti.aisw.tools.core.modules.api.definitions.common import BackendType
from qti.aisw.tools.core.modules.context_bin_gen.context_bin_gen_module import GenerateConfig
from qti.aisw.tools.core.utilities.comparators.common import COMPARATORS
from qti.aisw.tools.core.utilities.comparators.factory import get_comparator


class ModelSnooperParser(FrameworkRunnerParser, InferenceEngineParser):
    """Command-line argument parser for the Snooper tool.

    Extends the arguments registered by ``FrameworkRunnerParser`` and
    ``InferenceEngineParser`` with snooping specific options (algorithm, backend,
    platform, golden reference, comparators, subgraph selection) and post-processes
    the parsed values in ``_verify_and_update_parsed_args``.
    """

    def __init__(self, component="snooping"):
        """Create the parser.

        Args:
            component (str): Component name forwarded to the parent parser
                constructor. Defaults to "snooping".
        """
        super().__init__(component)

    def _initialize(self):
        """Create parser with snooper tool specific arguments.

        ``--algorithm`` is registered before the parent arguments; every other
        snooper specific argument is registered after ``super()._initialize()``.
        """
        # NOTE(review): self.optional / self.required are presumably argument groups
        # created by the parent parser before this hook runs - confirm.
        self.optional.add_argument(
            "--algorithm",
            type=Algorithm,
            required=False,
            choices=[algo.value for algo in Algorithm],
            default=Algorithm.ONESHOT,
            help="Algorithm to use to debug the model.",
        )
        super()._initialize()
        self.required.add_argument(
            "--backend",
            # Upper-case the user input before it is matched against the choices.
            type=str.upper,
            required=True,
            choices=[backend.value for backend in supported_backends],
            help="Backend type for inference to be run",
        )
        self.required.add_argument(
            "--platform",
            type=str,
            required=True,
            choices=[platform.value for platform in supported_platforms],
            help="The type of device platform to be used for inference",
        )
        # Accepted but hidden from the help output.
        self.optional.add_argument("--input_list", help=argparse.SUPPRESS)
        self.optional.add_argument(
            "--golden_reference",
            required=False,
            help="The path of directory where golden reference tensor files are saved.",
        )
        self.optional.add_argument(
            "--is_qnn_golden_reference",
            action="store_true",
            required=False,
            default=False,
            help="""Specifies that outputs passed with --golden_reference are dumped by QNN.
            This option should be used only when --golden_reference is supplied.""",
        )
        self.optional.add_argument(
            "--retain_compilation_artifacts",
            action="store_true",
            required=False,
            default=False,
            help="Flag to retain compilation artifacts.",
        )
        self.optional.add_argument(
            "--comparator",
            type=COMPARATORS,
            nargs="+",
            required=False,
            choices=[comp.value for comp in COMPARATORS],
            default=[COMPARATORS.MSE],
            help="Comparator to use to compare tensors. For multiple comparators, "
            "specify as follows: --comparator mse std",
        )
        self.optional.add_argument(
            "--offline_prepare",
            action="store_true",
            required=False,
            # None (flag not given) is distinguished from True so that
            # _verify_and_update_parsed_args can enable it automatically.
            default=None,
            help="Boolean to indicate offline prepare of the graph",
        )
        self.optional.add_argument(
            "--debug_subgraph_inputs",
            type=str,
            default=None,
            required=False,
            help="Pass comma separated inputs for the subgraph which is to be debugged. "
            "Currently support is limited to layerwise and cumulative algorithms.",
        )
        self.optional.add_argument(
            "--debug_subgraph_outputs",
            type=str,
            default=None,
            required=False,
            help="Pass comma separated outputs for the subgraph which is to be debugged. "
            "Currently support is limited to layerwise and cumulative algorithms.",
        )

        # Suppress arguments related to custom ops. Enable them once custom op support is
        # enabled in the snooping algorithms.
        self.optional.add_argument("--op_package_config", default=[], help=argparse.SUPPRESS)
        self.optional.add_argument("--converter_op_package_lib", default=[], help=argparse.SUPPRESS)
        self.optional.add_argument("--package_name", default="", type=str, help=argparse.SUPPRESS)
        self.optional.add_argument("--op_package_lib", default=[], help=argparse.SUPPRESS)
        self.optional.add_argument("--op_packages", default=[], help=argparse.SUPPRESS)

    def _get_context_bin_args(self, args: argparse.Namespace) -> Optional[GenerateConfig]:
        """Build context bin arguments from the parsed arguments.

        Args:
            args (argparse.Namespace): The parsed arguments.

        Returns:
            Optional[GenerateConfig]: The context bin arguments when
            ``args.offline_prepare`` is truthy, otherwise None.
        """
        if not args.offline_prepare:
            return None

        # In case of offline_prepare, create context_bin_args to pass to context_bin_gen
        return GenerateConfig(
            profiling_level=args.profiling_level,
            op_packages=args.op_packages,
        )

    def _get_net_run_args(self, args: argparse.Namespace) -> NetRunnerInputArguments:
        """Build net run arguments from the parsed arguments.

        Args:
            args (argparse.Namespace): The parsed arguments.

        Returns:
            NetRunnerInputArguments: The net run arguments.
        """
        return NetRunnerInputArguments(
            perf_profile=args.perf_profile,
            profiling_level=args.profiling_level,
            op_packages=args.op_packages,
        )

    @staticmethod
    def _split_tensor_names(names):
        """Split a comma separated string into a list of names.

        Surrounding whitespace is stripped from every entry and empty entries
        (e.g. produced by a trailing comma) are dropped, so "a, b," gives ["a", "b"].

        Args:
            names (str): Comma separated names.

        Returns:
            list[str]: The individual, non-empty names.
        """
        stripped = (name.strip() for name in names.split(","))
        return [name for name in stripped if name]

    def _verify_and_update_parsed_args(self, args):
        """Validates and updates parsed args.

        The following updates are applied:
            - Sets offline_prepare=True if it was not supplied and the backend supports it.
            - Sets float_fallback=True if quantization_overrides is provided without
              calibration_input_list.
            - Replaces each entry of comparator with the result of get_comparator().
            - Splits debug_subgraph_inputs / debug_subgraph_outputs into lists of names.

        Args:
            args (argparse.Namespace): parsed arguments

        Returns:
            argparse.Namespace: Verified and updated arguments

        Raises:
            Exception: If --is_qnn_golden_reference is given without --golden_reference.
        """
        backend = BackendType(args.backend)
        # Enable offline prepare if backend supports.
        if args.offline_prepare is None and backend in BackendType.offline_preparable_backends():
            args.offline_prepare = True

        # Either quantization_overrides or calibration_input_list should be provided
        validate_quantization_params(
            args.calibration_input_list, args.quantization_overrides, args.algorithm, backend
        )

        # If user has provided quantization_overrides and no calibration data then
        # make float_fallback True
        if args.quantization_overrides and not args.calibration_input_list:
            args.float_fallback = True

        args = super()._verify_and_update_parsed_args(args)

        # Update comparator by getting the relevant comparator class
        args.comparator = [get_comparator(comp) for comp in args.comparator]

        if args.is_qnn_golden_reference and args.golden_reference is None:
            raise Exception(
                "--is_qnn_golden_reference is allowed only when --golden_reference is supplied."
            )

        # Parse debug_subgraph inputs and outputs; unset/empty values are left untouched.
        if args.debug_subgraph_inputs:
            args.debug_subgraph_inputs = self._split_tensor_names(args.debug_subgraph_inputs)
        if args.debug_subgraph_outputs:
            args.debug_subgraph_outputs = self._split_tensor_names(args.debug_subgraph_outputs)
        return args
