# ==============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All Rights Reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
import copy
import os
import pathlib
from os import PathLike
from typing import Any, List, Optional

import onnx

from qairt.api.compiled_model import CompiledModel
from qairt.api.configs.common import BackendType
from qairt.api.converter.converter_config import CalibrationConfig, ConverterConfig
from qairt.api.transforms.model_transformer_config import (
    ARn_ContextLengthConfig,
    MhaConfig,
    ModelTransformerConfig,
    SplitModelConfig,
)
from qairt.gen_ai_api.builders.gen_ai_builder import GenAIBuilder
from qairt.gen_ai_api.builders.gen_ai_utils import (
    count_parameters,
    get_input_layout,
    get_kv_dim,
    get_pos_id_dim,
    get_positional_encodings_type,
    get_tensor_values,
    load_pretrained_config,
)
from qairt.gen_ai_api.builders.htp_mixin import HTPMixin
from qairt.gen_ai_api.configs.gen_ai_config import GenAIConfig
from qairt.gen_ai_api.containers.gen_ai_container import GenAIContainer
from qairt.gen_ai_api.containers.llm_container import LLMContainer
from qairt.modules.genie_execution.genie_config import PositionalEncoding, PositionalEncodingType
from qairt.modules.lora.lora_config import (
    LoraBuilderInputConfig,
    LoraBuilderOutputConfig,
    get_adapter_count_by_use_case,
)
from qairt.utils.loggers import get_logger
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.onnx_model import ExportedFiles

# Logger shared by everything in this file, obtained under the name "GenAIBuilder.HTP".
_logger = get_logger(name="GenAIBuilder.HTP")


class GenAIBuilderHTP(GenAIBuilder, HTPMixin):
    """
    GenAIBuilderHTP builds a GenAIContainer object for the HTP backend.

    """

    def __init__(
        self,
        framework_model_path: str | PathLike,
        config: GenAIConfig,
        cache_dir: Optional[str | PathLike] = None,
        enable_weight_sharing: bool = True,
    ):
        """
        Set up a builder that targets the HTP backend.

        The base builder is initialized first, the model path is resolved to a concrete
        file, and the configuration is reconciled with the model. Default transformation
        options (auto-regression numbers [1, 128] and a split plan derived from the
        parameter count) and default conversion options (16-bit activations, 32-bit
        bias) are then installed.

        Args:
            framework_model_path (str | PathLike): The path to the model to be prepared.
            config (GenAIConfig): The LLM configuration object.
            cache_dir (str | PathLike), Optional: user-provided root directory where artifacts
                may be stored and referenced in subsequent builder operations.
            enable_weight_sharing (bool): Flag to enable weight sharing during model compilation.

        Returns:
            None
        """
        GenAIBuilder.__init__(self, framework_model_path, config, BackendType.HTP)
        self._confirm_framework_model_path()
        self._validate_config()

        self._enable_weight_sharing = enable_weight_sharing

        # Start with one split, plus one more for every 2 * 1024**3 parameters.
        param_count = count_parameters(self.framework_model_path)
        split_count = param_count // (2 * 1024**3) + 1

        # Both the LM head and the embedding are split unless the embedding width is
        # known, in which case the size check decides.
        split_lm_head = split_embedding = True
        if config.n_embd:
            embedding_dim = config.n_vocab * config.n_embd
            split_lm_head = self._should_split_embedding(embedding_dim)
            split_embedding = self._should_split_embedding(embedding_dim)

        # Each separately split piece adds one to the split count.
        for piece_is_split in (split_lm_head, split_embedding):
            if piece_is_split:
                split_count += 1

        arn_options = ARn_ContextLengthConfig(auto_regression_number=[1, 128])
        split_options = SplitModelConfig(
            num_splits=split_count,
            split_lm_head=split_lm_head,
            split_embedding=split_embedding,
        )
        self.set_transformation_options(
            config=ModelTransformerConfig(arn_cl_options=arn_options, split_model=split_options),
        )

        HTPMixin.__init__(self, cache_dir=cache_dir, enable_weight_sharing=enable_weight_sharing)

        default_calibration = CalibrationConfig(act_precision=16, bias_precision=32)
        self.set_conversion_options(ConverterConfig(), default_calibration)

        # Neither encodings nor a LoRA config are known until the caller sets them.
        self._encodings_path: Optional[str | PathLike] = None
        self._lora_config: Optional[LoraBuilderInputConfig] = None

    @property
    def enable_weight_sharing(self) -> bool:
        """
        Getter for enable_weight_sharing.

        Returns:
            bool: Whether weight sharing is enabled.
        """
        return self._enable_weight_sharing

    @enable_weight_sharing.setter
    def enable_weight_sharing(self, value: bool):
        """
        Setter for enable_weight_sharing.

        Args:
            value (bool): Whether to enable weight sharing.
        """
        self._enable_weight_sharing = value
        _logger.debug(f"Enable weight sharing set to: {value}")

        # Reset AR config to [1, 128] when weight sharing is updated to True
        if (
            value
            and len(
                self._transformation_config.model_transformer_config.arn_cl_options.auto_regression_number
            )
            < 2
        ):
            self._transformation_config.model_transformer_config.arn_cl_options.auto_regression_number = [
                1,
                128,
            ]
            _logger.debug("AR config reset to [1, 128].")

    def _validate_config(self):
        """
        Reconcile the builder configuration with values read from the ONNX model.

        The model is loaded without its external data. The configured context length is
        replaced by the model's value when they differ. The KV dimension and positional
        encoding settings are filled in on a best-effort basis: a KeyError from any of
        the lookup helpers is logged at debug level and the corresponding config value
        is left as it was.
        """
        onnxmodel = onnx.load(self.framework_model_path, load_external_data=False)
        # NOTE(review): `arn` is presumably the auto-regression number; it is only
        # forwarded to the positional-encoding lookup below.
        cl, arn = get_tensor_values(onnxmodel)
        if self.config.context_length != cl:
            _logger.debug(
                f"Configured context length {self.config.context_length} replaced by model value {cl}."
            )
            self.config.context_length = cl

        try:
            kv_dim = get_kv_dim(onnxmodel)
            self.config.kv_dim = kv_dim
        except KeyError as err:
            _logger.debug(f"kv_dim could not be read from the model; leaving config unchanged: {err}")

        try:
            positional_encoding_type = get_positional_encodings_type(onnxmodel, arn, cl)
            self.config.positional_encoding = PositionalEncoding(type=positional_encoding_type)
            # RoPE parameters start cleared and are only populated for RoPE models.
            self.config.positional_encoding.rope_dim = None
            self.config.positional_encoding.rope_theta = None
            if positional_encoding_type == PositionalEncodingType.ROPE:
                try:
                    pos_id_dim = get_pos_id_dim(onnxmodel)
                    self.config.positional_encoding.rope_dim = pos_id_dim
                    if self.config.rope_theta:
                        self.config.positional_encoding.rope_theta = self.config.rope_theta
                except KeyError as err:
                    _logger.debug(f"RoPE dimension could not be read from the model: {err}")

        except KeyError as err:
            _logger.debug(f"Positional encoding could not be derived from the model: {err}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_path: str | os.PathLike, cache_root: Optional[str | os.PathLike]
    ) -> "GenAIBuilderHTP":
        """
        This method creates a GenAIBuilderHTP instance from a pretrained model by loading the model's
        configuration and creating a new instance with the loaded configuration.

        Args:
            pretrained_model_path (str | PathLike): The path to the pretrained model.
            cache_root (str | PathLike) Optional: user-provided root directory where artifacts
                may be stored and referenced in subsequent builder operations.

        Returns:
            GenAIBuilderHTP: An instance of GenAIBuilderHTP.

        Raises:
            FileNotFoundError: If the pretrained model path does not exist.

        """
        if not os.path.exists(pretrained_model_path):
            raise FileNotFoundError(f"Pretrained model path '{pretrained_model_path}' does not exist")
        pretrained_model_path_dir = pretrained_model_path
        if os.path.isfile(pretrained_model_path):
            pretrained_model_path_dir = os.path.dirname(pretrained_model_path)

        config = load_pretrained_config(pretrained_model_path_dir)

        gen_ai_config = cls._create_config_from_pretrained(config)

        builder = cls(pretrained_model_path, gen_ai_config, cache_root)

        return builder

    def _should_split_embedding(self, embedding_size: int) -> bool:
        return embedding_size >= 1024 * 1024 * 1024

    @property
    def encodings_path(self) -> Optional[str | os.PathLike]:
        """
        Gets the path to the encodings.

        Returns:
            Optional[str | os.PathLike]: The path to the encodings, or None if not set.
        """
        return self._encodings_path

    @encodings_path.setter
    def encodings_path(self, encodings_path: Optional[str | os.PathLike] = None):
        """
        Sets the path to the encodings.

        Args:
            encodings_path (Optional[str | os.PathLike], optional): The path to the encodings. Defaults to None.
        """
        self._encodings_path = encodings_path

    @property
    def lora_config(self) -> Optional[LoraBuilderInputConfig]:
        """
        Gets the lora config.

        Returns:
            Optional[LoraBuilderInputConfig]: The lora_config, or None if not set.
        """
        return self._lora_config

    @lora_config.setter
    def lora_config(self, lora_config: Optional[LoraBuilderInputConfig] = None):
        """
        Sets the lora_config which can be a LoraConfig class object or path to lora_config file.

        Args:
            lora_config (Optional[LoraBuilderInputConfig]): The lora_config. Defaults to None.
        """
        num_splits = self._transformation_config.model_transformer_config.split_model.num_splits
        if not self._transformation_config.model_transformer_config.split_model.split_embedding:
            self._transformation_config.model_transformer_config.split_model.split_embedding = True
            num_splits = num_splits + 1
        if not self._transformation_config.model_transformer_config.split_model.split_lm_head:
            self._transformation_config.model_transformer_config.split_model.split_lm_head = True
            num_splits = num_splits + 1
        self._transformation_config.model_transformer_config.split_model.num_splits = num_splits
        # enable MHA v2 (doesn't seem to work with v1)
        if not self._transformation_config.model_transformer_config.mha_config:
            self._transformation_config.model_transformer_config.mha_config = MhaConfig()
        self._lora_config = lora_config

    def _assert_configurations_are_valid(self):
        """
        Validates the builder's configurations.

        For compilation: there must be at least one target DSP architecture.
        """
        if (
            not self._compilation_config
            or not self._compilation_config.device_custom_configs
            or not self._compilation_config.device_custom_configs[0].dsp_arch
        ):
            raise ValueError(
                "Target DSP architecture required.  Please set_compilation_options to "
                "include soc_details or add a device_custom_config to set the target DSP architecture."
            )
        for c in self._compilation_config.device_custom_configs:
            if c.dsp_arch == "v0":
                raise ValueError("Target DSP architecture undetected or unknown.")

    def _confirm_framework_model_path(self):
        """
        If the user has provided a directory instead of a model, attempt to find an onnx model
        within that directory.  If multiple models are found, selection is arbitrary.
        """
        if self.framework_model_path.is_dir():
            matches = list(self.framework_model_path.glob("*.onnx"))
            if len(matches) == 0:
                raise FileNotFoundError(f"No onnx models found at {self.framework_model_path}")
            self.framework_model_path = matches[0]
            if len(matches) > 1:
                _logger.warning(
                    f"More than one model found.  Proceeding with {self.framework_model_path.name}."
                )

    def _confirm_encodings(self):
        """
        If the user has not explicitly provided an encodings path, attempt to find encodings within the
        same directory as the model, preferring (model_name).encodings if present, and then any
        .encodings file if there is exactly one.
        """
        if not self.encodings_path:
            base_name = pathlib.Path(self.framework_model_path.name).stem
            possible_encodings_location = self.framework_model_path.parent / f"{base_name}.encodings"
            if possible_encodings_location.exists():
                _logger.warning(f"Using implicit encodings: {possible_encodings_location}")
                self.encodings_path = possible_encodings_location
            else:
                matches = list(self.framework_model_path.parent.glob("*.encodings"))
                if len(matches) == 1:
                    _logger.warning(f"Using implicit encodings: {matches[0]}")
                    self.encodings_path = matches[0]
                elif len(matches) > 1:
                    raise ValueError(
                        "Multiple/ambiguous encodings found.  Please set the builder encodings_path directly."
                    )
                else:
                    if self._calibration_config and not self._calibration_config.dataset:
                        raise ValueError("GenAIBuilder requires either calibration or encodings data.")
                    _logger.warning("Proceeding without encodings, but using calibration dataset")

    def build(self) -> GenAIContainer:
        """
        Builds a GenAIContainer instance from the source framework model. The build process consumes the source
        framework model and its pre-computed quantization encodings, and performs the following steps:

         - For a LoRA use case, builds a max rank, concatenated LoRA configuration graph

         - Apply transformations to the model (necessary to run on HTP). These transformations improve
           model performance while executing on QAIRT. This may include:

           - Splitting the model based on the number of parameters

           - Additional transformations on each split such as Multi Head Attention to Single Head Attention.

         - For each split, the builder will:

          - Convert the model into a QAIRT model. This may include quantization using encodings, and/or calibration if sample inputs are provided.

          - Compile the model ahead-of-time into a compiled model instance.

        Examples:

            .. code-block:: python

                from qairt.gen_ai_api.gen_ai_builder_factory import GenAIBuilderFactory

                # Create a gen AI builder instance and set the appropriate targets
                gen_ai_builder: GenAIBuilder = GenAIBuilderFactory.create("model_dir/model.onnx", "HTP", cache_root=CACHE_ROOT)
                gen_ai_builder.set_targets(["chipset:SC8380XP"])

                # Build the model
                gen_ai_container = gen_ai_builder.build()


        Returns:
            GenAIContainer: An instance of GenAIContainer, which can be used to create an executor to execute the model, or saved
            to disk as a directory of artifacts generated during build.

        Raises:
            ValueError: If no usable target DSP architecture is configured, or if neither
                encodings nor a calibration dataset are available.
            FileNotFoundError: If the model path is a directory that contains no onnx model.

        """

        self._assert_configurations_are_valid()

        # TODO: if it's a pytorch directory, load the model and export to onnx
        # see AISW-129859

        if not os.getenv("QAIRT_TMP_DIR"):
            _logger.warning(
                "QAIRT_TMP_DIR is not set. "
                "Be aware that building large models may require large amounts of temporary space. "
                "Set QAIRT_TMP_DIR to avoid default temporary space usage."
            )

        self._confirm_framework_model_path()

        self._confirm_encodings()

        lora_output_config = None

        if self._lora_config:
            self.config.alpha_tensor_name = self._lora_config.alpha_tensor_name
            self.config.adapter_count_by_use_case = get_adapter_count_by_use_case(self._lora_config)

        if self._lora_config and self._lora_config.create_lora_graph:
            # These return parameters will be used by transform and convert APIs
            lora_output_config = self.build_lora_graph(self._lora_config, self.path_root)
            # From here on the LoRA base model and its encodings replace the originals.
            self.framework_model_path = lora_output_config.base_model_artifacts["onnx"]
            self.encodings_path = lora_output_config.base_model_artifacts["encodings"]

        transformed_models = self.transform(
            model_path=self.framework_model_path,
            config=self._transformation_config,
            encodings_path=self.encodings_path,
            lora_output_config=lora_output_config,
        )

        prepared_models: List[CompiledModel] = []

        _logger.debug(f"Transformed models: {transformed_models}")
        # zip(*...) regroups the transform output so that each iteration sees the files
        # that are converted and compiled together; split_number is that group's
        # 1-based position and is used for logging only.
        for split_number, split_files in enumerate(zip(*transformed_models), start=1):
            models_for_split = []
            for exported_files in split_files:
                self._conversion_config.input_tensor_config = get_input_layout(exported_files.onnx_path)
                convert_kwargs = {}
                if self._lora_config and exported_files.lora_importer_config:
                    convert_kwargs["quant_updatable_mode"] = self._lora_config.quant_updatable_mode
                model = self.convert(
                    exported_files, self._conversion_config, self._calibration_config, **convert_kwargs
                )
                models_for_split.append(model)

            if self._compilation_config and self._compilation_config.graph_custom_configs:
                # Clone the first graph config once per converted model, renamed to
                # match that model, and append it to the compilation config.
                for model in models_for_split:
                    new_graph_config: Any = self._compilation_config.graph_custom_configs[0].model_copy()
                    new_graph_config.name = model.name
                    self._compilation_config.graph_custom_configs.append(new_graph_config)
                    _logger.debug(f"Created new graph config for model: {model.name}")
                _logger.debug(
                    f"Updated graph names for split {split_number}: {[model.name for model in models_for_split]}"
                )
                # Drop the "placeholder" template once real graph configs exist.
                if self._compilation_config.graph_custom_configs[0].name == "placeholder":
                    self._compilation_config.graph_custom_configs.pop(0)

            compiled_model = self.compile(models_for_split, self._compilation_config)
            prepared_models.append(compiled_model)

        return LLMContainer(
            prepared_models, self.config, BackendType.HTP, compile_config=self._compilation_config
        )
