# =============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All rights reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================

import logging
from argparse import Namespace
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional

from pydantic import Field
from qti.aisw.converters.common.backend_awareness import BackendInfo
from qti.aisw.converters.common.converter_ir.op_graph_optimizations import IROptimizations

# Converter Imports
from qti.aisw.converters.common.utils import converter_utils

# Module Imports
from qti.aisw.tools.core.modules.api import (
    AISWBaseModel,
    Module,
    ModuleSchema,
    ModuleSchemaVersion,
    expect_module_compliance,
)
from qti.aisw.tools.core.modules.converter.common import BackendInfoConfig
from qti.aisw.tools.core.modules.converter.constants import *
from qti.aisw.tools.core.utilities.qairt_logging import QAIRTLogger


class OptimizerInputConfig(AISWBaseModel):
    """Configuration model for the optimizer input."""

    # TODO: Once DLC serialization is enabled for IRGraph, accept a DLC as input and make
    # ir_graph optional.
    # NOTE: the class docstring and the `description` strings below are emitted into the
    # pydantic JSON schema, so they are runtime-visible text.

    # --- Required inputs --------------------------------------------------------------
    ir_graph: Any = Field(
        description="IRgraph generated by model conversion.",
    )
    framework: str = Field(
        description="Source framework from which IRgraph generated.",
    )

    # --- Target description (None -> a default BackendInfoConfig is used downstream) ---
    backend_info: Optional[BackendInfoConfig] = Field(
        None,
        description="Backend information.",
    )

    # --- Optimization toggles -----------------------------------------------------------
    disable_batchnorm_folding: bool = Field(
        False,
        description="Disable BatchNorm folding optimization.",
    )
    expand_lstm_op_structure: bool = Field(
        False,
        description="Enables optimization that breaks the LSTM op to equivalent math ops.",
    )
    multi_time_steps_lstm: bool = Field(
        False,
        description="Enable multi-time-steps LSTM optimization.",
    )
    multi_time_steps_gru: bool = Field(
        False,
        description="Enable multi-time-steps GRU optimization.",
    )

    # --- Pass pipeline selection --------------------------------------------------------
    optimization_pass_mode: str = Field(
        "ir_optimizer_mainline",
        description="The pass mode to use for IrOptimizer Library",
    )


class OptimizerOutputConfig(AISWBaseModel):
    """Configuration model for the optimizer output."""

    # Graph returned by the optimization passes; None until populated.
    optimized_graph: Any = Field(
        None,
        description="Optimized IR graph",
    )
    # Argument namespace the optimizer ran with, as a plain dict (QAIRTOptimizer.optimize
    # fills this from vars(args)). A factory keeps each instance's dict independent.
    optimizer_args: Dict[str, Any] = Field(
        default_factory=dict,
        description="Optimizer arguments",
    )


class OptimizerModuleSchemaV1(ModuleSchema):
    """Schema definition for the OptimizerModule"""

    # NOTE: this docstring is runtime-visible (pydantic uses it as the schema description).
    # Leading-underscore names are not schema fields (pydantic treats them as private
    # attributes); they only supply the defaults of `version` and `backends` below.
    _VERSION = ModuleSchemaVersion(major=0, minor=3, patch=0)
    # Presumably None means the module is not restricted to specific backends -- confirm
    # against ModuleSchema.
    _BACKENDS = None
    name: Literal["OptimizerModule"] = "OptimizerModule"
    # Path of this source file.
    path: Path = Path(__file__)
    arguments: OptimizerInputConfig
    outputs: OptimizerOutputConfig
    backends: Optional[List[str]] = _BACKENDS
    version: ModuleSchemaVersion = _VERSION


@expect_module_compliance
class QAIRTOptimizer(Module):
    """User interface class for optimizer API.

    Translates an ``OptimizerInputConfig`` into the ``argparse.Namespace`` consumed by the
    converter's ``IROptimizations`` passes, runs them on the supplied IR graph and wraps the
    result in an ``OptimizerOutputConfig``.
    """

    _SCHEMA = OptimizerModuleSchemaV1
    _PREVIOUS_SCHEMAS = []

    def __init__(self, logger: Optional[logging.Logger] = None) -> None:
        """Create the optimizer and configure logging.

        Args:
            logger: Optional parent logger. When given, an "OptimizerLogger" child logger is
                derived from it; otherwise a standalone "OptimizerLogger" is created with level
                INFO, the "extended" formatter and the "dev_console" handler.

        Note:
            This rebinds ``converter_utils.LOGGER`` and ``converter_utils.LOG_LEVEL``, which are
            module-level globals, so the most recently constructed instance determines how the
            converter internals log.
        """
        if logger:
            self._logger = QAIRTLogger.get_logger("OptimizerLogger", parent_logger=logger)
        else:
            self._logger = QAIRTLogger.get_logger(
                "OptimizerLogger", level="INFO", formatter_val="extended", handler_list=["dev_console"]
            )
        # Route converter-internal logging through this module's logger.
        converter_utils.LOGGER = self._logger
        # Change this log level to debug Converter related issues.
        converter_utils.LOG_LEVEL = logging.ERROR
        self._debug_level = QAIRTLogger.get_default_logging_level("OptimizerLogger")

    def optimize(self, config: OptimizerInputConfig) -> OptimizerOutputConfig:
        """Performs optimizations on the IRGraph contained in the config.

        Optimizations are intended to increase performance while maintaining mathematical equivalence.

        Args:
            config (OptimizerInputConfig): Configuration object containing the IR graph,
                framework details, and backend settings.

        Returns:
            OptimizerOutputConfig: Contains the optimized IR graph & optimizer arguments.

        Raises:
            Exception: Any error raised while resolving the backend or running the optimization
                passes is logged and re-raised unchanged.

        Example:
            ```python
            from qti.aisw.tools.core.modules.converter import QAIRTConverter, ConverterInputConfig
            converter_config = QAIRTConverter().convert(ConverterInputConfig("/path/to/model"))
            ir_graph = converter_config.ir_graph
            framework = converter_config.framework
            optimizer = QAIRTOptimizer()
            out_config = optimizer.optimize(
                OptimizerInputConfig(ir_graph=ir_graph, framework=framework)
            )
            ```
        """
        # Translate the config into the namespace expected by the converter internals.
        args = QAIRTOptimizer._get_namespace_args(config)
        # Add the log level to args so that internal libraries use the same log level as the API.
        args.debug = self._debug_level

        try:
            backend_info_obj = BackendInfo.get_instance(args.backend, args.soc_model)
            optimized_graph = IROptimizations(args).optimize(config.ir_graph, backend_info_obj)
        except Exception:
            self._logger.error("IRgraph optimization failed.")
            # Bare raise re-raises the original exception with its traceback intact.
            raise

        return OptimizerOutputConfig(optimized_graph=optimized_graph, optimizer_args=vars(args))

    def enable_debug(self, debug_level: int, **kwargs) -> Optional[bool]:
        """Sets the optimizer log level.

        Args:
            debug_level: Requested level. Levels below ``LOGLEVEL.INFO`` are rejected and leave
                the current level unchanged.
            **kwargs: Ignored by this implementation.

        Returns:
            bool: ``True`` if the level was applied (converter logging is then reconfigured via
            ``converter_utils.setup_logging``), ``False`` if it was rejected.
        """
        if debug_level < LOGLEVEL.INFO:
            return False
        self._debug_level = debug_level
        converter_utils.setup_logging(self._debug_level)
        return True

    @property
    def _schema(self):
        """The schema class (``OptimizerModuleSchemaV1``) describing this module."""
        return self._SCHEMA

    def get_logger(self) -> Any:
        """Return the logger this module writes to."""
        return self._logger

    def properties(self) -> Dict[str, Any]:
        """Return the JSON schema of this module's schema model."""
        return self._schema.model_json_schema()

    @staticmethod
    def _get_namespace_args(config: OptimizerInputConfig) -> Namespace:
        """Convert the optimizer input config into the namespace used by the internal optimizer.

        1. Drops ``ir_graph`` (passed to the optimizer separately) and ``framework`` (used only
           to select framework-specific optimizations).
        2. Flattens ``backend_info`` into ``backend`` and ``soc_model`` attributes.
        3. Sets default values for arguments not exposed through the config, then applies the
           framework-specific overrides.

        Args:
            config: Optimizer input arguments config.

        Returns:
            Namespace containing the optimizer arguments.
        """
        option_dict = config.model_dump()
        option_dict.pop("ir_graph")
        framework = option_dict.pop("framework")

        # model_dump() turns the nested BackendInfoConfig into a dict; rebuild the model, falling
        # back to its defaults when no backend info was provided.
        backend_info_dict = option_dict.pop("backend_info", None)
        if backend_info_dict is None:
            backend_info = BackendInfoConfig()
        else:
            backend_info = BackendInfoConfig(**backend_info_dict)

        option_dict["backend"] = backend_info.backend
        option_dict["soc_model"] = backend_info.soc_model

        args = Namespace(**option_dict)

        # Arguments the internal optimizer reads but OptimizerInputConfig does not expose.
        # _set_optimization_args() below switches on the framework-specific subset.
        suppressed_defaults: Dict[str, Any] = {
            "align_matmul_ranks": False,
            "dumpIR": False,
            "disable_match_lstms": False,
            "squash_box_decoder": False,
            "match_caffe_ssd_to_tf": False,
            "adjust_nms_features_dims": False,
            "extract_color_transform": False,
            "perform_axes_to_spatial_first_order": False,
            "perform_layout_transformation": False,
            "preprocess_roi_pool_inputs": False,
            "unroll_lstm_time_steps": False,
            "unroll_gru_time_steps": False,
            "expand_gru_op_structure": False,
            "force_prune_cast_ops": False,
            "inject_cast_for_gather": False,
            "use_convert_quantization_nodes": False,
            "prepare_inputs_as_params": False,
            "handle_gather_negative_indices": False,
            "enable_match_gathernd": False,
            "expand_sparse_op_structure": False,
            "keep_disconnected_nodes": False,
            "apply_masked_softmax": "uncompressed",
            "packed_masked_softmax_inputs": "uncompressed",
            "packed_max_seq": 1,
            "op_package_lib": None,
            "keep_int64_inputs": False,
            "enable_match_topk": False,
            "preserve_onnx_output_order": False,
        }
        for name, value in suppressed_defaults.items():
            setattr(args, name, value)

        QAIRTOptimizer._set_optimization_args(args, framework)

        # align_matmul_ranks is switched on for ONNX by _set_optimization_args but is not
        # applied when targeting the AIC backend.
        if backend_info.backend == "AIC" and framework == OnnxFrameworkInfo.name:
            args.align_matmul_ranks = False

        return args

    @staticmethod
    def _set_optimization_args(args: Namespace, framework: str) -> None:
        """Enable the optimizations that apply to the given source framework.

        Args:
            args: Namespace produced by ``_get_namespace_args``; updated in place.
            framework: Name of the source framework the IR graph was produced from.
        """
        # TODO: Align optimizations for all frameworks
        onnx = OnnxFrameworkInfo.name
        pytorch = PytorchFrameworkInfo.name
        tensorflow = TensorflowFrameworkInfo.name

        if framework == onnx:
            args.expand_gru_op_structure = True
            args.unroll_gru_time_steps = True
            # Explicitly kept off for ONNX.
            args.expand_sparse_op_structure = False

        if framework in (onnx, pytorch):
            args.perform_layout_transformation = True
            args.preprocess_roi_pool_inputs = True

        if framework in (onnx, tensorflow):
            args.unroll_lstm_time_steps = True
            args.align_matmul_ranks = True
            args.handle_gather_negative_indices = True

        if framework in (tensorflow, pytorch):
            args.match_caffe_ssd_to_tf = True

        # Enable/Disable following optimizations for every framework except TFLite.
        if framework != TFLiteFrameworkInfo.name:
            args.squash_box_decoder = True
            args.adjust_nms_features_dims = True
            args.extract_color_transform = True
            args.inject_cast_for_gather = True
            # Explicitly kept off.
            args.force_prune_cast_ops = False
