# ==============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All Rights Reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
import os
import tempfile
from enum import Enum
from os import PathLike
from typing import Optional

from qairt.api.configs.common import BackendType
from qairt.gen_ai_api.builders.gen_ai_builder import GenAIBuilder
from qairt.gen_ai_api.builders.gen_ai_utils import load_pretrained_config
from qairt.gen_ai_api.configs.gen_ai_config import GenAIConfig
from qairt.gen_ai_api.containers.gen_ai_container import GenAIContainer
from qairt.gen_ai_api.containers.llm_container import LLMContainer
from qairt.modules.genie_preparation.genie_preparation_module import GeniePreparationModule, QuantizationLevel


# TODO: The name of this object should be SupportedLLM or the one in GenAIBuilderFactory should be changed
class SupportedModel(Enum):
    """
    Enumeration of models that are supported for CPU execution.

    Every member maps a symbolic model name to a string value. How that string
    is consumed is not visible in this file (presumably it is a model
    identifier understood by the rest of the toolchain -- TODO confirm).
    """

    # NOTE(review): neither the member names (LLAMA_2_13B vs LLAMA2_7B) nor the
    # values (underscore-separated vs hyphen-separated) follow one separator
    # convention. They are left untouched because callers may depend on the
    # exact spelling; confirm before normalizing.
    BAICHUAN_7B = "baichuan_7b_huggingface"
    LLAMA_2_13B = "llama-2_13b_huggingface"
    LLAMA2_7B = "llama-2-7b-huggingface"
    QWEN_7B = "qwen_7b_huggingface"
    LLAMA_V3_8B = "llama_v3_8b_huggingface"


class GenAIBuilderCPU(GenAIBuilder):
    """
    Builder that prepares a Generative AI model for execution on the CPU backend.

    It specializes GenAIBuilder: the backend is fixed to ``BackendType.CPU`` and
    ``build`` produces a container holding the model prepared by
    ``GeniePreparationModule``.
    """

    # TODO: Maybe add caching support to CPU?
    def __init__(
        self,
        framework_model_path: PathLike | str,
        config: GenAIConfig,
    ):
        """
        Initialize the builder for the CPU backend.

        Args:
            framework_model_path: Location of the framework model to build from.
            config: The Gen AI configuration forwarded to the base builder.
        """
        super().__init__(framework_model_path, config, BackendType.CPU)
        # The system temporary directory is the fallback when QAIRT_TMP_DIR is
        # not present in the environment.
        fallback_directory = tempfile.gettempdir()
        # No quantization level is selected until set_quantization_level() is called.
        self._quantization_level: Optional[QuantizationLevel] = None
        self._working_directory = os.getenv("QAIRT_TMP_DIR", default=fallback_directory)

    @classmethod
    def from_pretrained(cls, pretrained_model_path: str | PathLike) -> "GenAIBuilderCPU":
        """
        Alternate constructor: create a builder from a pretrained model path.

        The pretrained configuration is loaded from the given path and converted
        into a Gen AI configuration, which is then used to instantiate the builder.

        Args:
            pretrained_model_path: The path to the pretrained model.
        Returns:
            A new builder instance for the given model path.
        """
        pretrained_config = load_pretrained_config(pretrained_model_path)
        return cls(pretrained_model_path, cls._create_config_from_pretrained(pretrained_config))

    def set_quantization_level(
        self, cpu_quantization_level: Optional[QuantizationLevel] = None
    ) -> "GenAIBuilderCPU":
        """
        Select the quantization level used when the model is prepared for the CPU backend.

        Args:
            cpu_quantization_level: The quantization level to use when preparing the model
                for execution. ``QuantizationLevel.FP32`` is stored as None, so passing FP32
                and passing None (the default) are treated identically.

        Returns:
            GenAIBuilderCPU: This builder, so calls can be chained.
        """
        if cpu_quantization_level == QuantizationLevel.FP32:
            # FP32 is represented internally by the absence of a quantization level.
            self._quantization_level = None
        else:
            self._quantization_level = cpu_quantization_level
        return self

    def build(self) -> GenAIContainer:
        """
        Build the Generative AI model for execution on the CPU backend.

        The framework model is handed to ``GeniePreparationModule.prepare`` together with
        the currently selected quantization level, and the builder's working directory
        (``QAIRT_TMP_DIR`` if set, otherwise the system temporary directory) is passed as
        the output location. The prepared model is wrapped in a container that also
        carries the builder's configuration.

        Examples:

            .. code-block:: python

                from qairt.gen_ai_api.gen_ai_builder_factory import GenAIBuilderFactory

                # Create a gen AI builder instance and set the appropriate targets
                gen_ai_builder: GenAIBuilder = GenAIBuilderFactory.create("model_dir", "CPU")
                gen_ai_builder.set_quantization_level(QuantizationLevel.Z4)

                # Build the model
                gen_ai_container = gen_ai_builder.build()

        Returns:
            GenAIContainer: The GenAIContainer with all artifacts prepared for execution.

        """
        prepared_model = GeniePreparationModule.prepare(
            self.framework_model_path,
            quantization_level=self._quantization_level,
            outfile=self._working_directory,
        )
        prepared_models = [prepared_model]
        return LLMContainer(prepared_models, self.config, BackendType.CPU)
