# ==============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All Rights Reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from qairt.api.configs.common import AISWBaseModel


class GenerationMetrics(AISWBaseModel):
    """
    Timing and throughput figures gathered for a text-generation run.

    Every field defaults to ``None``, meaning the figure was not reported.
    """

    init_time: Optional[int] = None
    """Microseconds spent loading the model, before prompt processing begins."""
    prompt_processing_time: Optional[int] = None
    """Microseconds spent processing the prompt, between init and token generation."""
    prompt_processing_rate: Optional[float] = None
    """Tokens per second: tokens in the prompt divided by the prompt processing time."""
    token_generation_time: Optional[int] = None
    """Microseconds between the end of prompt processing and the final response."""
    token_generation_rate: Optional[float] = None
    """Tokens per second: tokens in the response divided by the token generation time."""
    adapter_switch_time: Optional[int] = None
    """Microseconds taken to switch between LoRA adapters."""

    def __str__(self):
        """
        Renders the metrics as a multi-line, human-readable report.

        Unset timings are shown as ``0.0`` and unset rates as ``0``. The adapter
        switch line is only included when a non-zero switch time was recorded.
        """
        rule = "-" * 20

        report = [
            f"{rule} {'Metrics'}{rule} \n",
            f"Timing (microseconds): \n\n",
            f"  Init = {self.init_time or 0.0} us \n",
            f"  Prompt Processing Time = {self.prompt_processing_time or 0.0} us \n",
            f"  Token Generation Time = {self.token_generation_time or 0.0} us \n",
        ]

        # None and 0 are both falsy, so one truthiness test covers "unset" and "zero".
        if self.adapter_switch_time:
            report.append(f"  Adapter Switch Time = {self.adapter_switch_time} us \n")

        # Blank separator line between the timing section and the rate section.
        report.append("\n")
        report.append(f"Tokens per second (toks/sec): \n\n")
        report.append(f"  Prompt Processing Rate = {self.prompt_processing_rate or 0} toks/sec \n")
        report.append(f"  Token Generation Rate = {self.token_generation_rate or 0} toks/sec \n")

        return "".join(report)


def _events_of_type(events: Any, event_type: str) -> list:
    """Returns the dict entries of ``events`` whose "type" equals ``event_type``."""
    return [e for e in events or [] if isinstance(e, dict) and e.get("type") == event_type]


def _metric_value(event: Dict[str, Any], key: str) -> Any:
    """Returns ``event[key]["value"]``, or None when the key or its value is absent."""
    entry = event.get(key)
    if isinstance(entry, dict):
        return entry.get("value")
    return None


def parse_genie_profile_record(profile_record: Dict[str, Any]) -> GenerationMetrics:
    """
    Extracts generation metrics from a Genie profile record.

    Only the first component of type "dialog" is inspected. Within its events:

    - the first "GenieDialog_create" event supplies ``init_time``,
    - the last "GenieDialog_query" event supplies the prompt processing and
      token generation figures,
    - the first "GenieDialog_applyLora" event supplies ``adapter_switch_time``.

    Each figure is read from the "value" entry of the corresponding event field.
    Missing components, events, fields or values are skipped, leaving the
    matching metric as None.

    Args:
        profile_record: Parsed profile record (a dict with an optional
            "components" list).

    Returns:
        GenerationMetrics populated with whichever figures were present.
    """
    metrics = GenerationMetrics()

    components = profile_record.get("components") or []
    dialog = next(
        (c for c in components if isinstance(c, dict) and c.get("type") == "dialog"),
        None,
    )
    if dialog is None:
        return metrics

    events = dialog.get("events") or []

    create_events = _events_of_type(events, "GenieDialog_create")
    if create_events:
        init_time = _metric_value(create_events[0], "init-time")
        if init_time is not None:
            metrics.init_time = int(init_time)

    query_events = _events_of_type(events, "GenieDialog_query")
    if query_events:
        # The most recent query is the one being reported on.
        query_event = query_events[-1]

        prompt_rate = _metric_value(query_event, "prompt-processing-rate")
        if prompt_rate is not None:
            metrics.prompt_processing_rate = prompt_rate

            num_prompt_tokens = _metric_value(query_event, "num-prompt-tokens")
            # The rate is tokens per second (tokens / time), so time = tokens / rate
            # seconds, scaled to microseconds to match the other timing fields.
            # A zero rate would divide by zero, so the time is left unset then.
            if num_prompt_tokens is not None and prompt_rate > 0:
                metrics.prompt_processing_time = int(1_000_000 * num_prompt_tokens / prompt_rate)

        generation_time = _metric_value(query_event, "token-generation-time")
        if generation_time is not None:
            metrics.token_generation_time = int(generation_time)

        generation_rate = _metric_value(query_event, "token-generation-rate")
        if generation_rate is not None:
            metrics.token_generation_rate = generation_rate

    apply_lora_events = _events_of_type(events, "GenieDialog_applyLora")
    if apply_lora_events:
        switch_time = _metric_value(apply_lora_events[0], "lora-adapter-switching-time")
        if switch_time is not None:
            # Stored as int, like the other timing fields.
            metrics.adapter_switch_time = int(switch_time)

    return metrics


class GenerationExecutionResult(AISWBaseModel):
    """
    Outcome of a text-generation run: the raw output and error text plus the
    parsed response and metrics.
    """

    output: str = ""
    """Unprocessed output produced by text generation."""
    error: str = ""
    """Unprocessed error response from text generation; empty when the run succeeded."""
    generated_text: str = ""
    """Parsed response: the generated text with the metrics stripped out."""
    metrics: Optional[GenerationMetrics] = None
    """Metrics parsed from the response; None by default."""


class GenAIExecutor(ABC):
    """
    Abstract base for GenAI executors.

    Concrete subclasses must implement preparation and cleanup of the target
    environment; the base class cannot be instantiated directly.
    """

    @abstractmethod
    def prepare_environment(self) -> "GenAIExecutor":
        """
        Stages the artifacts required for execution on the target.

        Returns:
            A GenAIExecutor (presumably this instance, to allow chaining -
            confirm against concrete implementations).
        """

    @abstractmethod
    def clean_environment(self) -> "GenAIExecutor":
        """
        Deletes the artifacts from the target environment.

        Returns:
            A GenAIExecutor (presumably this instance, to allow chaining -
            confirm against concrete implementations).
        """
