# ==============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All Rights Reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================


import os
from typing import Optional

import onnx

from qairt.api.configs.common import BackendType
from qairt.api.transforms.model_transformer_config import (
    ModelTransformerConfig,
    QuantizationStage,
    SplitModelConfig,
)
from qairt.utils.loggers import get_logger
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.onnx_model import OnnxModel


def transform(
    model: str | os.PathLike | onnx.ModelProto,
    backend: BackendType = BackendType.HTP,
    quantization_stage: QuantizationStage | None = None,
    encodings: str | os.PathLike | dict | None = None,
    **extra_args,
) -> list[OnnxModel]:
    """
    Transforms an ONNX model by performing MHA2SHA (MultiHead Attention to Single Head Attention)
    and/or model splitting.

    The requested backend / quantization stage combination is validated *before* the model is
    loaded, so unsupported requests fail fast without paying the cost of loading the model.

    Args:
        model (str | PathLike | onnx.ModelProto): The ONNX model to be transformed, either as a
            file path or a ModelProto object.
        backend (BackendType): The backend type for which the model is being transformed.
            Defaults to BackendType.HTP. Only BackendType.HTP is currently supported.
        quantization_stage (Optional[QuantizationStage]): The quantization stage for the
            transformation. If None, a warning is logged and QuantizationStage.POST_QUANT is
            used. With POST_QUANT, model splitting is performed and then MHA2SHA is applied to
            every split. QuantizationStage.PRE_QUANT is not currently supported.
        encodings (Optional[str | PathLike | dict]): Encodings to be used for the
            transformation. Can be a file path, or an in-memory dict. A dict is used
            directly only when ``model`` is an ``onnx.ModelProto``; otherwise the value is
            forwarded to ``OnnxModel.load`` as ``encodings_path``. Defaults to None.

        extra_args:
            lora_adapters_path: path to a yaml file declaring the LoRA adapters (use cases).
            lora_tensor_names_path: path to a text file containing the LoRA tensor names.
            split_model: SplitModelConfig or serialized (as dictionary),
                See :class:`qairt.api.transforms.model_transformer_config` for more details.
            mha_config: MhaConfig or serialized (as dictionary).  If mha_config is omitted,
                transform will use the legacy (v1) MHA transformation algorithm.
                See :class:`qairt.api.transforms.model_transformer_config` for more details.

    Examples:
        .. code-block:: python

            from qairt.api.configs.common import BackendType
            from qairt.api.transforms.model_transformer_config import QuantizationStage

            fw_model = "path/to/model"
            transformed_models = transform(fw_model,
                                           backend=BackendType.HTP,
                                           quantization_stage=QuantizationStage.POST_QUANT)

    Returns:
        list[OnnxModel]: A list of OnnxModel objects containing the split/transformed model and encodings.

    Raises:
        NotImplementedError: If ``backend`` is not BackendType.HTP, or if ``quantization_stage``
            is QuantizationStage.PRE_QUANT.
        ValueError: If ``quantization_stage`` is neither PRE_QUANT nor POST_QUANT.
    """

    _transform_logger = get_logger("qairt.transform")
    # LoRA paths are consumed here; the remaining kwargs describe the transformation itself.
    lora_adapters_path = extra_args.pop("lora_adapters_path", None)
    lora_tensor_names_path = extra_args.pop("lora_tensor_names_path", None)
    # Parse relevant kwargs from extra_args using ModelTransformerConfig
    config = ModelTransformerConfig.from_dict(extra_args)

    if quantization_stage is None:
        _transform_logger.warning("`quantization_stage` not set. Defaulting to POST_QUANT")
        quantization_stage = QuantizationStage.POST_QUANT

    # Validate the requested target before touching the model so unsupported combinations do
    # not pay for a (potentially large) model load. The backend is checked first to keep the
    # same error precedence as before.
    if backend != BackendType.HTP:
        raise NotImplementedError(f"Backend type {backend} not supported")
    if quantization_stage == QuantizationStage.PRE_QUANT:
        raise NotImplementedError(
            "Pre-quantization transformations are not currently supported through this API."
        )
    if quantization_stage != QuantizationStage.POST_QUANT:
        raise ValueError(
            f"Invalid value for quantization_stage: {quantization_stage}. Expected one of: "
            "QuantizationStage.PRE_QUANT or QuantizationStage.POST_QUANT"
        )

    # An in-memory ModelProto paired with an in-memory encodings dict is wrapped directly;
    # every other combination is delegated to OnnxModel.load.
    if isinstance(model, onnx.ModelProto) and isinstance(encodings, dict):
        onnx_model = OnnxModel(
            model=model,
            encodings=encodings,
            lora_adapters=lora_adapters_path,
            lora_tensor_names=lora_tensor_names_path,
        )
    else:
        onnx_model = OnnxModel.load(
            model_path=model,
            encodings_path=encodings,
            lora_adapters_path=lora_adapters_path,
            lora_tensor_names_path=lora_tensor_names_path,
        )

    # HTP + POST_QUANT: split the model first, then apply MHA2SHA to each split.
    splits = onnx_model.split(
        num_splits=config.split_model.num_splits,
        split_embedding=config.split_model.split_embedding,
        split_lm_head=config.split_model.split_lm_head,
    )
    for _split in splits:
        if config.mha_config:
            _split.mha2sha_v2(**config.mha_config.__dict__)
        else:
            # No mha_config supplied: fall back to the legacy (v1) algorithm.
            _split.mha2sha_v1()
    return splits
