# ==============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All Rights Reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
import json
import os
import shutil
from os import PathLike
from pathlib import Path
from typing import Dict, List, Optional, Union

from qairt.api.compiled_model import CompileConfig, CompiledModel
from qairt.api.configs.common import AISWBaseModel, BackendType
from qairt.api.configs.device import Device
from qairt.gen_ai_api.configs.gen_ai_config import GenAIConfig
from qairt.gen_ai_api.containers.gen_ai_container import GenAIContainer
from qairt.gen_ai_api.executors.t2t_executor import T2TExecutor
from qairt.modules.cache_module import CacheModule
from qairt.modules.dlc_module import DlcModule
from qairt.utils import loggers


class ContainerMetadata(AISWBaseModel):
    """
    Container-level metadata that is written next to the saved artifacts and parsed back on load.
    """

    # Backend that the container's artifacts were prepared for.
    backend: BackendType


class LLMContainer(GenAIContainer):
    """
    Holds the artifacts of a compiled LLM: one or more model splits, the GenAI configuration, the tokenizer,
    and an optional backend extensions configuration.

    Produced by a GenAiBuilder and consumed by a GenAiExecutor (``get_executor`` returns a ``T2TExecutor``).
    The on-disk layout written by ``save`` and read back by ``load`` is::

        <dest>/
            metadata.json
            gen_ai_config.json
            tokenizer.json
            backend_extensions.json        (only when a backend extensions config is present)
            models/
                use_cases.json
                split_<i>/
                    model.bin | model.dlc
                    <use_case_name>.bin    (LoRA binaries; only written for context-binary splits)
    """

    _logger = loggers.get_logger(name=__name__)

    def __init__(
        self,
        models: List[CompiledModel],
        gen_ai_config: GenAIConfig,
        backend: BackendType,
        *,
        backend_extensions_config: Optional[Dict] = None,
        compile_config: Optional[CompileConfig] = None,
    ):
        """
        Create an LLMContainer

        Args:
            models (List[CompiledModel]): CompiledModels containing the LLM split(s) prepared for QAIRT
                execution. Must be non-empty.
            gen_ai_config (GenAIConfig): contains the configuration metadata for the GenAI model.
            backend (BackendType): The backend the artifacts in this container were prepared for
            backend_extensions_config (Optional[Dict]): Backend extensions configuration for execution.
                A non-empty value takes precedence over ``compile_config``.
            compile_config (Optional[CompileConfig]): contains configuration metadata for the compiled
                configuration. Used only to derive the backend extensions configuration when
                ``backend_extensions_config`` is not provided (or is empty). Defaults to None.

        Raises:
            ValueError: If ``models`` is empty.
        """
        if not models:
            raise ValueError("LLMContainer must have at least one CompiledModel")
        self._models: List[CompiledModel] = models
        self._gen_ai_config: GenAIConfig = gen_ai_config
        self._backend: BackendType = backend
        self._backend_extensions_config: Dict | None = backend_extensions_config

        if compile_config and backend_extensions_config:
            self._logger.warning(
                "Both a CompileConfig and backend extensions config were provided. Ignoring CompileConfig."
            )

        elif isinstance(compile_config, CompileConfig):
            # Derive the backend extensions configuration from the compile configuration.
            self._backend_extensions_config = compile_config.model_dump(context={"backend_extensions": True})

    # ------------------------------------------------------------------
    # On-disk layout helpers
    # ------------------------------------------------------------------

    @classmethod
    def _metadata_file(cls, path: str | PathLike) -> str:
        return os.path.join(path, "metadata.json")

    @classmethod
    def _gen_ai_config_file(cls, path: str | PathLike) -> str:
        return os.path.join(path, "gen_ai_config.json")

    @classmethod
    def _backend_ext_file(cls, path: str | PathLike) -> str:
        return os.path.join(path, "backend_extensions.json")

    @classmethod
    def _tokenizer_file(cls, path: str | PathLike) -> str:
        return os.path.join(path, "tokenizer.json")

    @classmethod
    def _model_dir(cls, path: str | PathLike) -> str:
        return os.path.join(path, "models")

    @classmethod
    def _use_cases_file(cls, path: str | PathLike) -> str:
        """Path of the JSON file listing the LoRA use case names, inside the models directory."""
        return os.path.join(cls._model_dir(path), "use_cases.json")

    @classmethod
    def _create_split_dir(cls, path: Union[str, PathLike], index: int) -> str:
        split_dir = cls._split_dir(path, index)
        os.makedirs(split_dir, exist_ok=True)
        return split_dir

    @classmethod
    def _split_dir(cls, path: Union[str, PathLike], index: int) -> str:
        return os.path.join(cls._model_dir(path), f"split_{index}")

    @classmethod
    def _model_dlc(cls, path: str | PathLike, index: int) -> str:
        return os.path.join(cls._split_dir(path, index), "model.dlc")

    @classmethod
    def _model_ctx_bin(cls, path: str | PathLike, index: int) -> str:
        return os.path.join(cls._split_dir(path, index), "model.bin")

    @classmethod
    def _lora_dir(cls, path: str | PathLike) -> str:
        return os.path.join(path, "lora")

    @classmethod
    def _lora_bin(cls, path: str | PathLike, use_case_name: str, index: int) -> str:
        return os.path.join(cls._split_dir(path, index), f"{use_case_name}.bin")

    # ------------------------------------------------------------------
    # Serialization
    # ------------------------------------------------------------------

    def write_use_cases_json(self, dest: str | os.PathLike):
        """
        Extracts all unique use case names from the models' lora_use_case_binary_map
        and writes them to a JSON file named 'use_cases.json' inside the models directory.

        Models without a ``lora_use_case_binary_map`` attribute, or whose map is None/empty,
        contribute no use cases.

        Args:
            dest (str | PathLike): Base destination path where the models directory resides.
        """
        use_cases: set[str] = set()
        for model in self._models:
            lora_map = getattr(model, "lora_use_case_binary_map", None)
            if lora_map:
                use_cases.update(lora_map.keys())

        os.makedirs(self._model_dir(dest), exist_ok=True)

        # Sorted so the file content is deterministic across runs.
        with open(self._use_cases_file(dest), "w", encoding="utf-8") as f:
            json.dump({"use_cases": sorted(use_cases)}, f, indent=2)

    def save(self, dest: str | PathLike, *, exist_ok: bool = False):
        """
        Save all artifacts to disk.  Note, this will copy artifacts into the destination directory, and update any
        configurations accordingly.

        Args:
            dest (str | PathLike): Path to save the artifacts
            exist_ok (bool): If True, ``dest`` may be an existing directory and existing files with the same
                names may be overwritten. Defaults to False.

        Raises:
            NotADirectoryError: If ``dest`` exists but is not a directory.
            ValueError: If ``dest`` already exists and ``exist_ok`` is False.
            TypeError: If a model's module is neither a CacheModule nor a DlcModule.
        """
        if os.path.exists(dest):
            if not os.path.isdir(dest):
                raise NotADirectoryError(f"Destination path {dest} exists but is not a directory")
            if not exist_ok:
                raise ValueError(
                    f"Destination path {dest} already exists.  To use an existing directory, specify exist_ok=True"
                )
        else:
            os.makedirs(dest, exist_ok=exist_ok)

        # copy the tokenizer found at self._gen_ai_config.tokenizer_path to dest
        shutil.copyfile(self._gen_ai_config.tokenizer_path, self._tokenizer_file(dest))

        if self._models:
            os.makedirs(self._model_dir(dest), exist_ok=True)

            for i, model in enumerate(self._models):
                self._create_split_dir(dest, i)

                if isinstance(model.module, CacheModule):
                    model.module.save(self._model_ctx_bin(dest, i))

                    # Same tolerance as write_use_cases_json: the map may be absent, None or empty.
                    lora_map = getattr(model, "lora_use_case_binary_map", None) or {}
                    for use_case_name, lora_path in lora_map.items():
                        if use_case_name == "base":
                            continue  # The base binary was just written as model.bin
                        shutil.copy2(lora_path, self._lora_bin(dest, use_case_name, i))

                elif isinstance(model.module, DlcModule):
                    model.module.save(self._model_dlc(dest, i))
                else:
                    raise TypeError(f"Unsupported Compiled model module type {type(model.module)}")

            self.write_use_cases_json(dest)

        with open(self._metadata_file(dest), "w", encoding="utf-8") as f:
            f.write(ContainerMetadata(backend=self._backend).model_dump_json())

        with open(self._gen_ai_config_file(dest), "w", encoding="utf-8") as f:
            f.write(self._gen_ai_config.model_dump_json(by_alias=True, exclude_none=True, indent=2))

        if self._backend_extensions_config:
            with open(self._backend_ext_file(dest), "w", encoding="utf-8") as f:
                f.write(json.dumps(self._backend_extensions_config, indent=2))

    @classmethod
    def _load_lora_binaries(
        cls, path: Union[str, PathLike], index: int, use_cases: List[str], base_filename: str
    ) -> dict:
        """
        Load LoRA binaries from the split directory for a given model index.

        Args:
            path (str | PathLike): Base path to the container.
            index (int): Model index.
            use_cases (List[str]): Use case names to look for. Each maps to ``<use_case_name>.bin`` in the
                split directory; the ``"base"`` entry is skipped.
            base_filename (str): File name of the split's base model. A candidate with this name is never
                treated as a LoRA binary.

        Returns:
            dict: Mapping of use_case_name to Path of LoRA binary, for use cases whose file exists.
        """
        split_dir = Path(cls._split_dir(path, index))
        lora_map = {}

        for use_case_name in use_cases:
            if use_case_name == "base":
                continue
            candidate = split_dir / f"{use_case_name}.bin"
            # Prevent a use case whose name collides with the base model file (e.g. "model")
            # from being registered as a LoRA binary.
            if candidate.name == base_filename:
                continue
            if candidate.is_file():
                lora_map[use_case_name] = candidate

        return lora_map

    @classmethod
    def _read_use_cases(cls, path: str | PathLike) -> List[str]:
        """Read the use case names from use_cases.json; a missing file or key yields an empty list."""
        try:
            with open(cls._use_cases_file(path), "r", encoding="utf-8") as f:
                return json.load(f).get("use_cases", [])
        except FileNotFoundError:
            return []

    @classmethod
    def load(cls, path: str | PathLike) -> "LLMContainer":
        """
        Load LLMContainer assets from disk

        Args:
            path (str | PathLike): Path to load the artifacts from.  This should be a directory that is produced
            from a previous call to LLMContainer.save().

        Returns:
            LLMContainer: Instance of ``cls`` with the loaded artifacts

        Raises:
            NotADirectoryError: If ``path`` or its models directory is not a directory.
            FileNotFoundError: If the metadata, GenAI config, or tokenizer file is missing.
            ValueError: If no model splits are found.
        """
        if not os.path.isdir(path):
            raise NotADirectoryError(
                f"Path {path} is not a directory.  This should be a directory containing (minimally) a config file "
                f"{cls._gen_ai_config_file('')}, a tokenizer file ({cls._tokenizer_file('')}), and a model directory "
                f"({cls._model_dir('')}). There may be other files and directories as well."
            )

        with open(cls._metadata_file(path), "r", encoding="utf-8") as f:
            backend = ContainerMetadata(**json.load(f)).backend

        with open(cls._gen_ai_config_file(path), "r", encoding="utf-8") as f:
            gen_ai_config = GenAIConfig(**json.load(f))

        # Update tokenizer path to loaded location
        gen_ai_config.tokenizer_path = Path(cls._tokenizer_file(path)).resolve()
        if not os.path.exists(gen_ai_config.tokenizer_path):
            raise FileNotFoundError(f"Tokenizer file not found at location: {gen_ai_config.tokenizer_path}")

        backend_extensions_config = None
        backend_ext_file = cls._backend_ext_file(path)
        if os.path.exists(backend_ext_file):
            with open(backend_ext_file, "r", encoding="utf-8") as f:
                backend_extensions_config = json.load(f)

        model_dir = cls._model_dir(path)
        if not os.path.isdir(model_dir):
            raise NotADirectoryError(f"Serialized model directory does not exist: {model_dir}")

        use_cases = cls._read_use_cases(path)

        # Scan split_0, split_1, ... and stop at the first index that has neither a context binary
        # nor a DLC. A context binary wins when both are present.
        models = []
        i = 0
        while True:
            ctx_bin = cls._model_ctx_bin(path, i)
            dlc = cls._model_dlc(path, i)
            if os.path.isfile(ctx_bin):
                base_path = Path(ctx_bin)
                model = CompiledModel(CacheModule.load(path=base_path))
            elif os.path.isfile(dlc):
                base_path = Path(dlc)
                model = CompiledModel(DlcModule.load(path=base_path))
            else:
                break

            lora_use_case_binary_map = {"base": base_path}
            lora_use_case_binary_map.update(cls._load_lora_binaries(path, i, use_cases, base_path.name))
            model.lora_use_case_binary_map = lora_use_case_binary_map
            models.append(model)
            i += 1

        if not models:
            raise ValueError(f"No serialized model splits found under {model_dir}")

        return cls(models, gen_ai_config, backend, backend_extensions_config=backend_extensions_config)

    # ------------------------------------------------------------------
    # Execution
    # ------------------------------------------------------------------

    def get_executor(self, device: Optional[Device] = None, clean_up: bool = True, **kwargs) -> T2TExecutor:
        """
        Create a T2TExecutor for the models held by this container.

        Args:
            device (Optional[Device]): Device to execute on; forwarded to T2TExecutor. Defaults to None.
            clean_up (bool): Forwarded to T2TExecutor. Defaults to True.
            **kwargs: Currently not forwarded; a warning is logged if any are given.

        Returns:
            T2TExecutor: Executor built from this container's models, GenAI config, backend and
            backend extensions configuration.

        Raises:
            ValueError: If the container holds no models.
        """
        if not self._models:
            raise ValueError("No models were loaded into the container. Nothing to execute.")

        if kwargs:
            self._logger.warning("Ignoring unsupported get_executor arguments: %s", sorted(kwargs))

        return T2TExecutor(
            self._models,
            self._gen_ai_config,
            self._backend,
            device=device,
            backend_extensions_config=self._backend_extensions_config,
            clean_up=clean_up,
        )
