# ==============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All Rights Reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
import asyncio
import json
import os
import platform
import tempfile
from collections import defaultdict
from typing import Any, Dict, List, Optional, Sequence

from typing_extensions import Self

from qairt import CompileConfig, CompiledModel
from qairt.api.configs.common import BackendType
from qairt.api.configs.device import Device, DevicePlatformType
from qairt.gen_ai_api.configs.gen_ai_config import GenAIConfig
from qairt.gen_ai_api.executors.gen_ai_executor import (
    GenAIExecutor,
    GenerationExecutionResult,
    GenerationMetrics,
    parse_genie_profile_record,
)
from qairt.modules.cache_module import CacheModule
from qairt.modules.dlc_module.dlc_module import DlcModule
from qairt.modules.genie_execution.genie_config import (
    Context,
    Dialog,
    DialogEngine,
    DialogType,
    EngineBackend,
    EngineBackendType,
    EngineModel,
    EngineModelType,
    GenieConfig,
    LoraConfig,
    LoraConfigAdapter,
    ModelBinary,
    ModelLibrary,
    QnnGenAiTransformerBackend,
    QnnHtpBackend,
    Sampler,
    Tokenizer,
)
from qairt.modules.genie_execution.genie_t2t_run_module import GenieT2TRunExecutionConfig, GenieT2TRunner
from qairt.modules.genie_execution.native_t2t_module import GenieNativeT2TRunner
from qairt.modules.lora.lora_config import UseCaseRunConfig
from qairt.utils import loggers


class T2TExecutor(GenAIExecutor):
    """
    The T2TExecutor handles text-to-text generation on target via Genie. It supports two modes of execution:

     - Native: This is the default mode of execution on platforms with native python support if no device is
       specified. Execution is performed via native python bindings.
     - Device: This is the mode of execution when a device is specified. Execution is performed via subprocess.
       See :class:`qairt.api.configs.device.Device` for supported device types.

    """

    _logger = loggers.get_logger(name=__name__)

    def __init__(
        self,
        models: List[CompiledModel],
        genai_config: GenAIConfig,
        backend: BackendType,
        device: Optional[Device] = None,
        backend_extensions_config: Optional[Dict] = None,
        qairt_sdk_root: Optional[str | os.PathLike] = None,
        clean_up: bool = True,
    ):
        """
        Store the models and settings needed for text-to-text execution.

        Nothing is created on the target here: the runners and the working directory start out as
        None and are populated by :meth:`prepare_environment`.

        Args:
            models (List[CompiledModel]): The compiled models to run. Must not be empty.
            genai_config (GenAIConfig): GenAIConfig that the GenieConfig is derived from.
            backend (BackendType): Backend to use for execution.
            device (Device): Device to be used for execution. If None, native execution is presumed.
            backend_extensions_config (Dict): Backend extensions configuration. A warning is logged
                when this is supplied together with the CPU backend.
            qairt_sdk_root (str | os.PathLike): Path to the QAIRT SDK with execution libraries. When
                not given, the ``QNN_SDK_ROOT`` environment variable is used (empty string if unset).
            clean_up (bool): Passed to the device runner and also controls whether the environment
                is cleaned in ``__exit__`` / ``__del__``.

        Raises:
            ValueError: If ``models`` is empty.
        """
        # Everything clean_environment() touches is assigned before any validation can raise,
        # because __del__ may call clean_environment() on a partially constructed instance.
        self._clean_up = clean_up
        self._runner: Optional[GenieT2TRunner] = None
        self._native_runner: Optional[GenieNativeT2TRunner] = None

        self._environment_prepared = False
        self._work_dir: Optional[tempfile.TemporaryDirectory] = None

        self._models: List[CompiledModel] = models
        if not self._models:
            raise ValueError("No models provided for execution")

        self._gen_ai_config = genai_config
        self._device: Optional[Device] = device
        self._backend: BackendType = backend
        self._backend_extensions_config: Optional[Dict] = backend_extensions_config

        self._qairt_sdk_root: str | os.PathLike
        if qairt_sdk_root:
            self._qairt_sdk_root = qairt_sdk_root
        else:
            self._qairt_sdk_root = str(os.environ.get("QNN_SDK_ROOT", ""))

        extensions_unsupported = self._backend == BackendType.CPU
        if extensions_unsupported and self._backend_extensions_config:
            self._logger.warning(
                f"Ignoring provided backend extensions as requested backend, {self._backend}, does not support extensions"
            )

    def prepare_environment(self) -> Self:
        """
        Prepares artifacts for execution on target.

        Creates a temporary working directory, writes the backend extensions (if any) into it as
        ``backend_extensions.json``, builds the Genie config and then instantiates either the native
        runner or the device runner (which is also loaded). Calling this again once the environment
        is prepared is a no-op.

        Returns:
            Self: This executor, to allow call chaining.
        """
        if self._environment_prepared:
            return self

        self._work_dir = tempfile.TemporaryDirectory()

        # An empty path tells _build_genie_config that there are no extensions to attach.
        extensions_file = ""
        if self._backend_extensions_config:
            extensions_file = os.path.join(self._work_dir.name, "backend_extensions.json")
            with open(extensions_file, "w") as handle:
                serialized = json.dumps(self._backend_extensions_config, indent=2)
                handle.write(serialized)

        config = self._build_genie_config(extensions_file)

        if not self._is_native_execution():
            assert isinstance(self._device, Device)  # for mypy

            self._runner = GenieT2TRunner(
                config,
                self._backend,
                self._device.info,
                self._qairt_sdk_root,
                clean_up=self._clean_up,
            )
            self._runner.load()
        else:
            self._native_runner = GenieNativeT2TRunner(genie_config=config)

        self._environment_prepared = True
        return self

    def clean_environment(self) -> Self:
        """
        Removes artifacts from target environment.

        Unloads the device runner (if one was created), deletes the temporary working directory and
        drops the native runner. The local resources are released even when unloading the device
        runner raises, and the executor is always marked as not prepared afterwards so that the next
        generation call prepares a fresh environment.

        Returns:
            Self: This executor, to allow call chaining.
        """
        self._logger.info("Cleaning environment")
        try:
            if self._runner:
                self._runner.unload()
                self._runner = None
        finally:
            # Runs even if unload() failed: otherwise the temporary directory and the native
            # runner would stay alive and the executor would still claim to be prepared.
            if self._work_dir is not None:
                self._work_dir.cleanup()
                # Drop the handle so it never refers to a directory that no longer exists.
                self._work_dir = None

            self._native_runner = None
            self._environment_prepared = False
        return self

    def generate(
        self, prompt: str, *, lora_config: Optional[UseCaseRunConfig] = None
    ) -> GenerationExecutionResult:
        """
        Generates a response from a given prompt.

        The environment is prepared on first use. With native execution the prompt is sent to the
        native runner and its dialog is reset after the query; otherwise the request is delegated to
        :meth:`_generate_non_native`.

        Args:
            prompt (str): The prompt to be used for generation.
            lora_config (UseCaseRunConfig): Configuration used to control how LoRA adapters are applied during model inference.

        Returns:
            GenerationExecutionResult: The result of the generation containing the output text and associated
            generation metrics.

        Raises:
            RuntimeError: If non-native execution is selected but no device runner is available.
        """
        if not self._environment_prepared:
            self.prepare_environment()

        if self._is_native_execution() and self._native_runner:
            result = self._native_runner.query(prompt, lora_config=lora_config)
            # Reset so the next call does not continue this dialog.
            self._native_runner.reset_dialog()
            return result

        return self._generate_non_native(prompt, lora_config)

    def stream_generate(self, prompt: str, q: asyncio.Queue) -> asyncio.Task:
        """
        Starts streaming generation and returns the task that will produce the final result.

        Streaming is only available with native execution. The execution mode is checked before the
        environment is prepared, so an unsupported request does not create and load a device runner
        only to be rejected afterwards.

        Args:
            prompt (str): The prompt to be used for generation.
            q (asyncio.Queue): An asyncio queue used to stream output chunks back to the caller.

        Returns:
            asyncio.Task: A task that will eventually return GenerationExecutionResult.

        Raises:
            NotImplementedError: If execution is not native, or no native runner is available.
            Exception: Any error raised while creating the task is logged and re-raised
                (``asyncio.create_task`` requires a running event loop).
        """
        # Fail fast: the execution mode only depends on the configured device and the host.
        if not self._is_native_execution():
            raise NotImplementedError("Non-native execution is not supported.")

        if not self._environment_prepared:
            self.prepare_environment()

        if not self._native_runner:
            raise NotImplementedError("Non-native execution is not supported.")

        try:
            return asyncio.create_task(self._native_runner.stream_query(prompt, q))
        except Exception as e:
            self._logger.warning(f"Streaming failed to start: {e}")
            raise

    def _generate_non_native(
        self, prompt: str, lora_config: Optional[UseCaseRunConfig] = None
    ) -> GenerationExecutionResult:
        """
        Run a prompt through the GenieT2TRunner and package its output.

        The raw output is the runner's stdout; on a zero return code stderr is appended to it, on a
        non-zero return code stderr is stored as the error instead. The generated text is whatever
        lies between the ``[BEGIN]:`` and ``[END]`` markers of the raw output, when both are present.

        Args:
            prompt (str): The prompt to be used for generation.
            lora_config (UseCaseRunConfig): Configuration used to control how LoRA adapters are applied during model inference.

        Returns:
            GenerationExecutionResult: The result of the generation containing the output text and associated
            generation metrics.

        Raises:
            RuntimeError: If no device runner has been created.
        """
        result = GenerationExecutionResult()

        runner = self._runner
        if not runner:
            raise RuntimeError("Environment preparation failed")

        run_config = GenieT2TRunExecutionConfig(prompt=prompt, lora_config=lora_config)
        t2t_result = runner.run(run_config)

        result.output = t2t_result.stdout
        if t2t_result.return_code == 0:
            result.output += "\n" + t2t_result.stderr
        else:
            result.error = t2t_result.stderr
        self._logger.debug(f"stdout: {t2t_result.stdout}\nstderr:{t2t_result.stderr}")

        # Extract the text enclosed by the begin/end markers, if both were emitted in order.
        begin_marker = "[BEGIN]:"
        raw_output = result.output
        start = raw_output.find(begin_marker)
        stop = raw_output.find("[END]", start)
        if -1 < start < stop:
            result.generated_text = raw_output[start + len(begin_marker) : stop].strip()

        profile = t2t_result.profile_record
        if profile:
            result.metrics = parse_genie_profile_record(profile)

        return result

    def _is_native_execution(self) -> bool:
        """
        Decide whether execution happens natively on the host.

        Returns:
            bool: True when no device is configured, or when the device has no identifier and its
            type matches the host platform (x86_64 Linux, x86_64 Windows or Windows on ARM).
            False in every other case.
        """
        device = self._device
        if not device:
            return True

        # A device with an explicit identifier is never treated as the local host.
        if device.identifier is not None:
            return False

        host_os = platform.system()

        if (
            host_os == "Linux"
            and platform.machine() == "x86_64"
            and device.type == DevicePlatformType.X86_64_LINUX
        ):
            return True

        if host_os == "Windows":
            cpu = platform.processor()
            host_is_x86 = any(tag in cpu for tag in ("AMD64", "Intel64"))
            host_is_arm = any(tag in cpu for tag in ("ARM64", "AARCH64", "ARMv8"))

            if host_is_x86 and device.type == DevicePlatformType.X86_64_WINDOWS_MSVC:
                return True
            if host_is_arm and device.type == DevicePlatformType.WOS:
                return True

        return False

    def _build_genie_config(
        self,
        backend_extensions_path: Optional[str | os.PathLike] = None,
    ) -> GenieConfig:
        """
        Build the GenieConfig describing the models, backend and dialog settings.

        Args:
            backend_extensions_path (str | os.PathLike): Path to a backend extensions file. Only
                attached to the engine backend for the HTP backend, and only when non-empty.

        Returns:
            GenieConfig: Configuration with a basic dialog, the context/tokenizer settings taken
            from the GenAIConfig, and an engine built for the selected backend.

        Raises:
            RuntimeError: If there are no models, no model path is available for the CPU backend,
                or the backend is neither HTP nor CPU.
            ValueError: If a DLC module does not hold exactly one context binary cache.
            KeyError: If a LoRA use case is missing from ``adapter_count_by_use_case``.
        """
        if not self._models:
            raise RuntimeError("No models were loaded into the container. Cannot build Genie config.")

        should_collect_lora = any(
            m.lora_use_case_binary_map for m in self._models if isinstance(m.module, CacheModule)
        )

        # A dict is used as an ordered set so adapters appear in first-seen order on every run
        # (iterating a set of strings is not stable across interpreter runs).
        use_case_names: Dict[str, None] = {}
        if should_collect_lora:
            for m in self._models:
                for uc_name in m.lora_use_case_binary_map or {}:
                    if uc_name != "base":
                        use_case_names[uc_name] = None

        model_paths: List[str] = []
        lora_use_cases: Dict[str, List[Any]] = defaultdict(list)
        for model in self._models:
            if isinstance(model.module, DlcModule):
                caches = list(model.module.caches.values())
                # Also rejects a DLC without any cache, which would otherwise fail on caches[0].
                if len(caches) != 1:
                    raise ValueError("Expected only a single context binary cache per dlc")
                model_paths.append(str(caches[0].path))

            elif isinstance(model.module, CacheModule):
                model_paths.append(str(model.module.path))

            if should_collect_lora:
                # Models without adapter info contribute an empty string "bin path" for every use
                # case, so each use case has one entry per model.
                binary_map = model.lora_use_case_binary_map or {}
                for uc_name in use_case_names:
                    lora_use_cases[uc_name].append(binary_map.get(uc_name, ""))

        # Build LoRA config
        lora_config = None
        if lora_use_cases:
            lora_config_adapters = []
            adapter_count_dict = self._gen_ai_config.adapter_count_by_use_case
            # NOTE(review): when adapter_count_by_use_case is None, no adapter entries are emitted
            # and the LoraConfig is created with an empty adapter list - confirm this is intended.
            if adapter_count_dict is not None:
                for uc_name, bin_paths in lora_use_cases.items():
                    # Collect LoRA adapter alphas
                    if uc_name == "default_adapter":
                        alphas = []
                    elif adapter_count_dict.get(uc_name) is not None:
                        alphas = [f"alpha{i}" for i in range(adapter_count_dict[uc_name])]
                    else:
                        raise KeyError(
                            f"The use case {uc_name} cannot be found in the adapter_count_by_use_case dict."
                        )
                    lora_config_adapters.append(
                        LoraConfigAdapter(name=uc_name, alphas=alphas, bin_sections=bin_paths)
                    )

            lora_config = LoraConfig(
                alpha_tensor_name=self._gen_ai_config.alpha_tensor_name,
                adapters=lora_config_adapters,
            )

        if self._backend == BackendType.HTP:
            engine_backend = EngineBackend(
                type=EngineBackendType.QNN_HTP,
                QnnHtp=QnnHtpBackend(
                    **{
                        "poll": True,
                        "use-mmap": not self._is_native_execution(),
                        "spill-fill-bufsize": 0,
                        "mmap-budget": 40,
                        "kv-dim": self._gen_ai_config.kv_dim,
                    }
                ),
            )
            if backend_extensions_path:
                engine_backend.extensions = backend_extensions_path
            engine_model = EngineModel(
                type=EngineModelType.BINARY,
                binary=ModelBinary(ctx_bins=model_paths, lora=lora_config),
                positional_encoding=self._gen_ai_config.positional_encoding,
            )
        elif self._backend == BackendType.CPU:
            if not model_paths:
                raise RuntimeError("No model path available for CPU execution. Cannot build Genie config.")
            engine_backend = EngineBackend(
                type=EngineBackendType.QNN_GEN_AI_TRANSFORMER,
                QnnGenAiTransformer=QnnGenAiTransformerBackend(
                    n_layer=self._gen_ai_config.n_layer,
                    n_embd=self._gen_ai_config.n_embd,
                    n_heads=self._gen_ai_config.n_heads,
                ),
            )
            # Only the first model path is used as the model library for this backend.
            engine_model = EngineModel(
                type=EngineModelType.LIBRARY, library=ModelLibrary(model_bin=model_paths[0], lora=lora_config)
            )

        else:
            raise RuntimeError(f"Unsupported backend: {self._backend}")

        return GenieConfig(
            dialog=Dialog(
                type=DialogType.BASIC,
                context=Context(
                    size=self._gen_ai_config.context_length,
                    n_vocab=self._gen_ai_config.n_vocab,
                    bos_token=self._gen_ai_config.bos_token,
                    eos_token=self._gen_ai_config.eos_token,
                    eot_token=self._gen_ai_config.eot_token,
                ),
                sampler=Sampler(seed=42, temp=1.2, top_k=20, top_p=0.75),
                tokenizer=Tokenizer(path=self._gen_ai_config.tokenizer_path),
                engine=DialogEngine(backend=engine_backend, model=engine_model),
            )
        )

    def __enter__(self) -> Self:
        """
        Enter the context manager.

        Returns:
            Self: This executor. The environment is not prepared here.
        """
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Leave the context manager, cleaning the environment when clean-up is enabled.

        Returns:
            bool: Always False, so an exception raised inside the ``with`` block is not suppressed.
        """
        clean_requested = self._clean_up
        if clean_requested:
            self.clean_environment()

        return False

    def __new__(cls, *args, **kwargs):
        """
        Allocate the instance with clean-up disabled.

        ``__del__`` reads ``_clean_up``; giving it a value here means the attribute exists even if
        ``__init__`` never runs or fails before assigning it.
        """
        obj = super().__new__(cls)
        # Overwritten by __init__ with the caller's choice.
        obj._clean_up = False
        return obj

    def __del__(self):
        """
        Clean the environment when the executor is finalized, if clean-up is enabled.
        """
        if not self._clean_up:
            return
        self.clean_environment()
