# ==============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All Rights Reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
import pathlib
from os import PathLike
from typing import Dict, Optional

from pydantic import field_validator

from qairt.api.configs.common import AISWBaseModel
from qairt.modules.genie_execution.genie_config import PositionalEncoding


class GenAIConfig(AISWBaseModel):
    """
    GenAIConfig holds common configuration information for a Generative AI model that is needed
    for Genie execution.  Subclasses inherit all of the attributes declared here.

    Required attributes: ``tokenizer_path``, ``context_length``, ``n_vocab``, ``bos_token`` and
    ``eos_token``.  Every other attribute is optional and has a default.

    ``tokenizer_path`` is checked at validation time and must point to an existing file;
    otherwise ``FileNotFoundError`` is raised.
    """

    tokenizer_path: str | PathLike
    """
    The path to the tokenizer.  Must point to an existing file.
    """

    context_length: int
    """The context length of the model."""

    n_vocab: int
    """The number of tokens in the vocabulary, which is also the first dimension of the embeddings matrix"""

    n_heads: Optional[int] = None
    """The number of attention heads used in the multi-head attention layers of the model"""

    n_layer: Optional[int] = None
    """The number of blocks in the model"""

    n_embd: Optional[int] = None
    """The hidden size of the model"""

    bos_token: int
    """The id of the beginning of stream token."""

    eos_token: int | list[int]
    """The id of the end of stream token, or a list of such ids."""

    eot_token: Optional[int] = None
    """The id of the end of turn token."""

    positional_encoding: Optional[PositionalEncoding] = None
    """An object describing the positional encodings"""

    kv_dim: Optional[int] = None
    """The dimension of the KV cache."""

    rope_theta: Optional[float] = None
    """The theta value used for rotary positional encoding."""

    alpha_tensor_name: Optional[str] = ""
    """
    Name of the tensor where LoRA adapter is being applied.
    """

    # A literal ``{}`` default is safe on a pydantic model (assuming AISWBaseModel is a pydantic
    # BaseModel): pydantic copies mutable defaults for each instance instead of sharing one dict.
    adapter_count_by_use_case: Optional[Dict[str, int]] = {}
    """
    Dict of number of adapters per use case.
    """

    @field_validator("tokenizer_path")
    @classmethod
    def validate_tokenizer_path(cls, v: str | PathLike) -> str | PathLike:
        """
        Ensure ``tokenizer_path`` points to an existing file.

        :param v: The candidate tokenizer path, as supplied by the caller.
        :return: ``v`` unchanged (the path is not normalized or resolved in the stored value).
        :raises FileNotFoundError: If ``v`` does not point to an existing file.  Pydantic only
            converts ``ValueError``/``AssertionError`` into a ``ValidationError``, so this
            exception reaches the caller unwrapped.
        """
        if not pathlib.Path(v).resolve().is_file():
            raise FileNotFoundError(f"The tokenizer_path '{v}' does not point to an existing file.")
        return v
