# =============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All rights reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================
import argparse
from os import path

from qti.aisw.accuracy_debugger.argparser.parser import Parser
from qti.aisw.accuracy_debugger.common_config import InputSample
from qti.aisw.accuracy_debugger.framework_runner.frameworks.onnx_framework import (
    CustomOnnxFramework,
)


class FrameworkRunnerParser(Parser):
    """Argument parser for the framework runner tool.

    Registers the framework-runner specific command-line options on top of the
    base ``Parser`` options. After parsing, it replaces ``args.input_sample``
    (a file path) with a list of ``InputSample`` objects. Their names, shapes and
    data types are read from the model's input tensor details.
    """

    def __init__(self, component="framework_runner"):
        super().__init__(component)

    def _initialize(self):
        """Create parser with framework runner tool specific arguments."""
        super()._initialize()
        self.required.add_argument(
            "-m", "--input_model", type=str, required=True, help="path to the model file"
        )
        self.required.add_argument(
            "--input_sample",
            required=True,
            type=str,
            help="Path to text file containing input sample.",
        )
        self.optional.add_argument(
            "--working_directory",
            type=str,
            required=False,
            default=None,
            # Implicit string concatenation: a backslash continuation inside the
            # literal would embed the next line's indentation in the help text.
            help="Path to working directory. If not specified a directory with name "
            "working_directory will be created in the current directory.",
        )
        self.optional.add_argument(
            "-o",
            "--output_tensor",
            type=str,
            required=False,
            action="append",
            help="Name of the graph's specified output tensor(s).",
        )
        self.optional.add_argument(
            "--onnx_define_symbol",
            default=None,
            nargs=2,
            action="append",
            required=False,
            metavar=("SYMBOL", "VALUE"),
            help="Option to override specific input dimension symbols.",
        )
        self.optional.add_argument(
            "--log_level",
            type=str,
            required=False,
            default="info",
            choices=["info", "debug", "warning", "error"],
            help="Log level. Default is info",
        )

    def _verify_and_update_parsed_args(self, args: argparse.Namespace) -> argparse.Namespace:
        """Validates parsed arguments and resolves the input sample file.

        Args:
            args (argparse.Namespace): parsed arguments
        Returns:
            argparse.Namespace: Verified arguments. ``args.input_sample`` is replaced
            by the list of ``InputSample`` objects built from the input sample file.
        Raises:
            argparse.ArgumentTypeError: If the model path or the input sample path does
            not exist, or the input sample file does not match the model inputs.
        """
        args = super()._verify_and_update_parsed_args(args)
        if not path.exists(args.input_model):
            raise argparse.ArgumentTypeError(
                "Input model path, {}, doesn't exist.".format(args.input_model)
            )

        if not path.exists(args.input_sample):
            raise argparse.ArgumentTypeError(
                "Input sample file, {}, doesn't exist.".format(args.input_sample)
            )
        args.input_sample = self._parse_input_sample(
            str(args.input_model), args.input_sample, args.onnx_define_symbol
        )
        return args

    def _parse_input_sample(
        self, input_model: str, input_sample: str, onnx_symbols: list[tuple] = None
    ) -> list[InputSample]:
        """Returns InputSample objects for the specified input sample file.

        Only the first non-empty line of the file is used. It must contain one
        whitespace-separated entry per model input, all in the same format:

        * ``name:=raw_file``: entries are matched to model inputs by tensor name.
        * ``raw_file``: entries are matched to model inputs by position, in the
          order returned by ``get_input_tensor_details``.

        Args:
            input_model (str): Model file path
            input_sample (str): Input sample file path
            onnx_symbols (list[tuple] | None): List of symbols and their values to
                override, or None.

        Returns:
            list[InputSample]: List of input sample objects, in the order their
            entries appear in the input sample file.

        Raises:
            argparse.ArgumentTypeError: If the file has no non-empty line, the entry
            count differs from the model's input count, the two formats are mixed,
            or a named entry is malformed, unknown to the model, or repeated.
        """
        input_sample_line = self._read_first_non_empty_line(input_sample)
        if not input_sample_line:
            raise argparse.ArgumentTypeError("Invalid input sample file supplied.")
        input_tensors = input_sample_line.split()

        # Get model input tensor details
        onnx_framework = CustomOnnxFramework(logger=None)
        model_inp_tensors = onnx_framework.get_input_tensor_details(
            model_path=input_model, onnx_symbols=onnx_symbols
        )

        # Check if the number of input tensors matches
        if len(input_tensors) != len(model_inp_tensors):
            raise argparse.ArgumentTypeError(
                f"Number of input tensors in input sample file: {len(input_tensors)} does not "
                f"match the number of input tensors in the model: {len(model_inp_tensors)}."
            )

        # Every entry must use the same format. Deciding from the first entry alone
        # would crash on, or silently misread, a line that mixes both formats.
        is_named = [":=" in tensor for tensor in input_tensors]
        if all(is_named):
            return self._build_named_input_samples(input_tensors, model_inp_tensors)
        if any(is_named):
            raise argparse.ArgumentTypeError(
                "Input sample file mixes name:=raw_file and raw_file entries; "
                "use one format for all input tensors."
            )

        # Format: raw_file -- entries map to model inputs by position.
        return [
            InputSample(
                raw_file=raw_file_path,
                name=model_inp_tensor["name"],
                dimensions=model_inp_tensor["shape"],
                data_type=model_inp_tensor["data_type"],
            )
            for raw_file_path, model_inp_tensor in zip(input_tensors, model_inp_tensors)
        ]

    @staticmethod
    def _read_first_non_empty_line(file_path: str) -> str:
        """Return the first line of ``file_path`` that is not blank (stripped), else ""."""
        with open(file_path, "r") as file:
            for line in file:
                stripped_line = line.strip()
                if stripped_line:
                    return stripped_line
        return ""

    @staticmethod
    def _build_named_input_samples(
        input_tensors: list[str], model_inp_tensors: list[dict]
    ) -> list[InputSample]:
        """Build InputSample objects from ``name:=raw_file`` entries matched by name.

        Args:
            input_tensors (list[str]): Entries from the input sample file, each
                containing ``:=``.
            model_inp_tensors (list[dict]): Model input details with ``name``,
                ``shape`` and ``data_type`` keys.

        Returns:
            list[InputSample]: One object per entry, in entry order.

        Raises:
            argparse.ArgumentTypeError: If an entry has an empty raw file path, names
            a tensor not in the model, or repeats a tensor name.
        """
        model_tensors_by_name = {tensor["name"]: tensor for tensor in model_inp_tensors}
        seen_names = set()
        input_sample_objs = []
        for entry in input_tensors:
            name, _, raw_file_path = entry.partition(":=")
            if not raw_file_path:
                raise argparse.ArgumentTypeError(
                    f"Invalid input sample entry: {entry}. Expected format is name:=raw_file."
                )
            model_inp_tensor = model_tensors_by_name.get(name)
            if model_inp_tensor is None:
                raise argparse.ArgumentTypeError(
                    f"Input tensor name: {name} not found in the model."
                )
            # The entry count already equals the model input count, so a repeated
            # name would leave some other model input without a sample.
            if name in seen_names:
                raise argparse.ArgumentTypeError(
                    f"Input tensor name: {name} is specified more than once."
                )
            seen_names.add(name)
            input_sample_objs.append(
                InputSample(
                    name=name,
                    raw_file=raw_file_path,
                    dimensions=model_inp_tensor["shape"],
                    data_type=model_inp_tensor["data_type"],
                )
            )
        return input_sample_objs
