# ==============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All Rights Reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
import logging
import os
import pathlib
import sys
from enum import Enum
from typing import Optional, cast

# This import needs to be at the top of other imports so that
# required modules in transformers are modified during runtime to enable GGUF workflow.
try:
    from qti.aisw.converters.gguf_builder import gguf_builder

except ImportError as e:
    print("WARNING: Unable to import GGUF Builder")
from qairt.api.configs.common import BackendType
from qairt.api.converter.converter_config import CalibrationConfig
from qairt.gen_ai_api.builders.baichuan_builder_htp import BaichuanBuilderHTP
from qairt.gen_ai_api.builders.gen_ai_builder import GenAIBuilder
from qairt.gen_ai_api.builders.gen_ai_builder_cpu import GenAIBuilderCPU
from qairt.gen_ai_api.builders.gen_ai_builder_htp import GenAIBuilderHTP
from qairt.gen_ai_api.builders.gen_ai_utils import load_pretrained_config
from qairt.gen_ai_api.builders.indus_builder_htp import IndusBuilderHTP
from qairt.gen_ai_api.builders.jais_builder_htp import JaisBuilderHTP
from qairt.gen_ai_api.builders.llama_builder_htp import LlamaBuilderHTP
from qairt.gen_ai_api.builders.mistral_builder_htp import MistralBuilderHTP
from qairt.gen_ai_api.builders.phi_builder_htp import PhiBuilderHTP
from qairt.gen_ai_api.builders.plamo_builder_htp import PlamoBuilderHTP
from qairt.gen_ai_api.builders.qwen_builder_htp import QwenBuilderHTP

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Backends accepted by GenAIBuilderFactory.create; any other backend type is rejected with a ValueError.
_SUPPORTED_BACKENDS = [BackendType.CPU, BackendType.HTP]


# TODO: This should be co-located with HTP Builder
class SupportedLLMs(Enum):
    """Enumeration of preconfigured builder architectures for the HTP backend.

    Each member's value is an architecture name that ``GenAIBuilderFactory.create`` looks for in the
    ``architectures`` attribute of the loaded pretrained configuration in order to select the
    matching model-specific HTP builder.
    """

    LLAMA = "LlamaForCausalLM"
    BAICHUAN = "BaiChuanForCausalLM"
    PHI = "Phi3ForCausalLM"
    QWEN = "Qwen2ForCausalLM"
    MISTRAL = "MistralForCausalLM"
    JAIS = "JAISLMHeadModel"
    PLAMO = "PlamoForCausalLM"
    # IndusBuilderHTP is selected for this GPT-2 architecture name.
    INDUS = "GPT2LMHeadModel"


class GenAIBuilderFactory:
    """
    Factory class to create :class:`qairt.gen_ai_api.gen_ai_builder.GenAIBuilder` instances
    """

    # Ordered (architecture, builder class) pairs for the HTP backend. The first architecture found in
    # the model configuration wins, so the order of this table defines the selection priority.
    _HTP_BUILDERS = (
        (SupportedLLMs.LLAMA, LlamaBuilderHTP),
        (SupportedLLMs.QWEN, QwenBuilderHTP),
        (SupportedLLMs.PHI, PhiBuilderHTP),
        (SupportedLLMs.MISTRAL, MistralBuilderHTP),
        (SupportedLLMs.BAICHUAN, BaichuanBuilderHTP),
        (SupportedLLMs.JAIS, JaisBuilderHTP),
        (SupportedLLMs.PLAMO, PlamoBuilderHTP),
        (SupportedLLMs.INDUS, IndusBuilderHTP),
    )

    @classmethod
    def create(
        cls,
        pretrained_model_path: str | os.PathLike,
        backend_type: str | BackendType = BackendType.HTP,
        *,
        cache_root: Optional[str | os.PathLike] = None,
    ) -> GenAIBuilder:
        """
        Creates a GenAIBuilder instance based on the provided pretrained model path and backend type.

        This function makes the following assumptions:

         - Directory Contents: The following directory and naming structure is expected

            - For the HTP Backend:

               <pretrained_model_path>.dir
                - <model>.onnx
                - <model>.encodings (a file containing quantization overrides)
                - <model>.data (optional)
                - config.json (a Hugging Face transformers configuration for the model)
                - tokenizer.json (the corresponding configuration for the tokenizer from Hugging Face)

             If the pretrained model path is a file, the directory containing it will be used to locate
             additional required artifacts. If the pretrained model path is a directory, then the directory
             is assumed to contain a single onnx model and a single set of corresponding encodings. A warning
             will be returned if more than one model or encodings are found.

             If the pretrained model path ends with ``.gguf``, the GGUF builder is run first and the
             first artifact path it returns is used as the pretrained model path.

            - For the CPU Backend:

               <pretrained_model_path>.dir where the directory contains the transformers configuration,
               model weights, tokenizer and other artifacts as if the model were downloaded directly from
               Hugging Face.

         - Model Architecture:

           The builder will attempt to identify the model and provide a pre-configured instance based on the
           architecture. If HTP is requested and the model architecture is not recognized,
           then a default GenAIBuilderHTP instance will be returned, with a warning.
           See :class:`qairt.gen_ai_api.gen_ai_builder_factory.SupportedLLMs` for a list of pre-configured
           architectures.

        Args:
            pretrained_model_path (str): The path to the pretrained model.  The pretrained model path may
             be a directory containing the model or a file path to the model itself.
            backend_type (BackendType): The type of backend to use. Defaults to BackendType.HTP.
            cache_root (Path, optional): The root directory for caching, if desired. Keyword-only. It is
             forwarded to the HTP builders only.

        Returns:
            GenAIBuilder: The created GenAIBuilder instance.

        Raises:
            ValueError: If the pretrained model path does not exist, if the backend type is not
             supported, or if required files are missing.
            ImportError: If a ``.gguf`` model is supplied but the GGUF builder module could not be
             imported when this module was loaded.
        """

        if not os.path.exists(pretrained_model_path):
            raise ValueError(f"Pretrained model path '{pretrained_model_path}' does not exist")

        # NOTE(review): for a plain ``str`` backend_type this membership test only succeeds if
        # BackendType members compare equal to their string form -- confirm against BackendType.
        if backend_type not in _SUPPORTED_BACKENDS:
            raise ValueError(f"Backend type '{backend_type}' is not supported")

        if backend_type == BackendType.CPU:
            return GenAIBuilderCPU.from_pretrained(pretrained_model_path)

        apply_gguf_config = False
        if os.path.splitext(pretrained_model_path)[1] == ".gguf":
            # The module-level import of gguf_builder is optional; when it failed the name is simply
            # absent from the module globals, so report that explicitly instead of a bare NameError.
            gguf_module = globals().get("gguf_builder")
            if gguf_module is None:
                raise ImportError(
                    "GGUF Builder could not be imported; "
                    f"unable to build from GGUF model '{pretrained_model_path}'"
                )
            gguf_artifacts_paths = gguf_module.GGUFBuilder(pretrained_model_path).build_from_gguf()
            pretrained_model_path = gguf_artifacts_paths[0]
            apply_gguf_config = True

        # if pretrained_model_path is a file, get the path to the directory containing it
        pretrained_model_path_dir = pretrained_model_path
        if os.path.isfile(pretrained_model_path):
            pretrained_model_path_dir = os.path.dirname(pretrained_model_path)

        config = load_pretrained_config(pretrained_model_path_dir)

        # The attribute may be missing or explicitly None; both mean "architecture unknown" and must
        # fall through to the default builder rather than fail the membership test.
        architectures = getattr(config, "architectures", None) or ()
        builder_cls = next(
            (candidate for llm, candidate in cls._HTP_BUILDERS if llm.value in architectures),
            None,
        )

        model_path = os.fspath(pretrained_model_path)
        builder = None
        if builder_cls is not None:
            builder = builder_cls.from_pretrained(model_path, cache_root)
        if builder is None:
            logger.warning(
                "Architecture is unknown or unsupported; Returning default. "
                "This builder may work but will probably require additional configuration."
            )
            builder = GenAIBuilderHTP.from_pretrained(model_path, cache_root)
        assert builder is not None

        if apply_gguf_config:
            # Builders created from GGUF artifacts get their calibration and auto-regression
            # settings overridden below.
            calib_config = cast(CalibrationConfig, builder._calibration_config)
            calib_config.keep_weights_quantized = True
            setattr(calib_config, "float_bitwidth", 16)
            builder._transformation_config.model_transformer_config.arn_cl_options.auto_regression_number = [
                1
            ]
        return builder
