# ==============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All Rights Reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
import json
import math
import os
import pathlib
import re
from types import SimpleNamespace
from typing import List, Tuple

import onnx

from qairt.api.converter.converter_config import InputTensorConfig
from qairt.modules.genie_execution.genie_config import PositionalEncodingType
from qairt.utils import loggers

_logger = loggers.get_logger(name=__name__)


def get_input_layout(onnxfile) -> List[InputTensorConfig]:
    """
    Build one InputTensorConfig per graph input of an ONNX model file.

    Every entry uses the fixed layout "NONTRIVIAL"; the tensor name is copied from the graph input
    and the datatype is the lower-cased ONNX element-type name of that input.

    Args:
        onnxfile (str): The path to the ONNX model file.

    Returns:
        List[InputTensorConfig]: One InputTensorConfig per graph input, in graph order.

    """
    # Only graph metadata is read here, so the model is loaded without its external data.
    model = onnx.load(onnxfile, load_external_data=False)

    layouts = []
    for graph_input in model.graph.input:
        elem_type = graph_input.type.tensor_type.elem_type
        datatype_name = onnx.TensorProto.DataType.Name(elem_type)
        layouts.append(
            InputTensorConfig(
                layout="NONTRIVIAL",
                name=graph_input.name,
                datatype=datatype_name.lower(),
            )
        )
    return layouts


def get_kv_dim(onnxmodel: onnx.ModelProto) -> int:
    """
    Return the size of the last dimension of the first "past_value_<n>_in" graph input.

    Graph inputs are scanned in order and the first one whose name starts with
    ``past_value_<digits>_in`` is used.

    Args:
        onnxmodel (onnx.ModelProto): The onnx model to inspect

    Returns:
        int: The ``dim_value`` of the last dimension of the matching input.

    Raises:
        KeyError: If no graph input name matches the pattern.
    """
    past_value_pattern = re.compile(r"past_value_\d+_in")
    first_hit = next(
        (
            graph_input
            for graph_input in onnxmodel.graph.input
            if past_value_pattern.match(graph_input.name)
        ),
        None,
    )
    if first_hit is None:
        raise KeyError("past_value_[n]_in not found in graph.")
    return first_hit.type.tensor_type.shape.dim[-1].dim_value


def get_pos_id_dim(onnxmodel: onnx.ModelProto) -> int:
    """
    Return the size of the last dimension of the "position_ids_sin" graph input.

    Args:
        onnxmodel (onnx.ModelProto): The onnx model to inspect

    Returns:
        int: The ``dim_value`` of the last dimension of the first input named "position_ids_sin".

    Raises:
        KeyError: If the graph has no input named "position_ids_sin".
    """
    sin_inputs = [
        graph_input for graph_input in onnxmodel.graph.input if graph_input.name == "position_ids_sin"
    ]
    if not sin_inputs:
        raise KeyError("position_ids_sin not found in the graph.")
    last_dim = sin_inputs[0].type.tensor_type.shape.dim[-1]
    return last_dim.dim_value


def get_positional_encodings_type(onnxmodel: onnx.ModelProto, ar: int, cl: int) -> PositionalEncodingType:
    """
    Determine the positional encoding type from the graph inputs of an ONNX model.

    Graph inputs are scanned in order and the first conclusive one decides the result:
      * an input named "position_ids_sin" yields ROPE;
      * an input named "position_ids" whose first two dimensions are (1, ar) yields ABSOLUTE;
      * an input named "position_ids" whose first two dimensions are (ar, cl) yields ALIBI.

    Args:
        onnxmodel (onnx.ModelProto): The onnx model to inspect
        ar (int): The ARn value the "position_ids" dimensions are compared against.
        cl (int): The context length the "position_ids" dimensions are compared against.

    Returns:
        PositionalEncodingType: The detected positional encoding type.

    Raises:
        KeyError: If no graph input allows the encoding type to be determined.
    """
    for graph_input in onnxmodel.graph.input:
        input_name = graph_input.name
        if input_name == "position_ids_sin":
            return PositionalEncodingType.ROPE
        if input_name != "position_ids":
            continue

        shape_dims = graph_input.type.tensor_type.shape.dim
        leading = shape_dims[0].dim_value
        # The second dimension is only read once the first one matches.
        # 1 x ARn means absolute positions.
        if leading == 1 and shape_dims[1].dim_value == ar:
            return PositionalEncodingType.ABSOLUTE
        # ARn x CL means alibi.
        if leading == ar and shape_dims[1].dim_value == cl:
            return PositionalEncodingType.ALIBI

    raise KeyError("Unable to detect/determine positional encoding in graph")


def get_tensor_values(onnxmodel: onnx.ModelProto) -> Tuple[int, int]:
    """
    Extracts the tensor values from an ONNX model, specifically the dimensions of the attention mask.
    The attention mask is expected to be the second input in the ONNX graph; if it is not there
    (including when the graph has fewer than two inputs), it is looked up by the name "attention_mask".
    The two largest dimensions of the attention mask are taken as ARn and context length.

    Args:
        onnxmodel (onnx.ModelProto): The onnx model to inspect

    Returns:
        Tuple[int, int]: A tuple containing the context length and the ARn of the attention mask.

    Raises:
        ValueError: If the attention mask is not found among the graph inputs, or if it has fewer
            than 2 dimensions.
    """
    graph_inputs = onnxmodel.graph.input

    # Fast path: the second graph input. Guard the index so that graphs with fewer than two
    # inputs fall through to the name search instead of raising IndexError.
    attention_mask = graph_inputs[1] if len(graph_inputs) > 1 else None
    if attention_mask is None or attention_mask.name != "attention_mask":
        _logger.warning("Attention mask not found in expected location of input layers. Searching by name. ")
        attention_mask = next(
            (graph_input for graph_input in graph_inputs if graph_input.name == "attention_mask"), None
        )
        if attention_mask is None:
            raise ValueError(
                "input layers do not conform to expected format: attention_mask not found in input layers"
            )

    dims = sorted(dim.dim_value for dim in attention_mask.type.tensor_type.shape.dim)
    if len(dims) < 2:
        raise ValueError("Attention mask does not contain at least 2 dimensions")

    # After sorting, the largest value is presumed to be the context length and the second largest
    # the ARn (ARn <= context length). This assumes the remaining dimensions are 1 -- unconfirmed.
    _logger.debug(f"Extracted dims: {dims} from attention mask.  Presuming AR={dims[-2]} and CL={dims[-1]}")
    return (dims[-1], dims[-2])


def load_pretrained_config(pretrained_model_path_dir: str | os.PathLike) -> SimpleNamespace:
    """
    Loads the configuration of a pre-trained model from a JSON file.
    The configuration is loaded from a file named "config.json" within the specified directory.
    The loaded configuration is augmented with a new attribute `_name_or_path` set to the provided `pretrained_model_path_dir`
    to be compliant with AutoConfig.from_pretrained behavior.

    Args:
        pretrained_model_path_dir (str | os.PathLike): The directory path of the pre-trained model.

    Returns:
        SimpleNamespace: The loaded configuration as a SimpleNamespace object.

    """
    with open(pathlib.Path(pretrained_model_path_dir) / "config.json", "r") as f:
        config = json.load(f, object_hook=lambda d: SimpleNamespace(**d))
        config._name_or_path = pretrained_model_path_dir
    return config


def count_parameters(onnx_model_path):
    """
    Count the total number of elements across all initializers of an ONNX model.

    The size of each initializer is the product of its ``dims``; an initializer with no dims
    counts as 1. Only the entries of ``graph.initializer`` are counted.

    Args:
        onnx_model_path: The path to the ONNX model file.

    Returns:
        int: The sum of the element counts of all graph initializers.
    """
    # Only the dims metadata of each initializer is read, so external tensor data is not loaded.
    # This matches get_input_layout and avoids reading the weight files.
    model = onnx.load(onnx_model_path, load_external_data=False)
    return sum(math.prod(initializer.dims) for initializer in model.graph.initializer)
