# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
Entry point of the mha2sha (multi-head attention to single-head attention) v2 module.
"""
import os
from typing import Dict

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .optimizer import GraphOptimizer
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.encodings import (
    serialize_graph_encodings,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.logger import logger
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .validation.ort_accuracy_checker import (
    verify_onnx_with_inputs_list,
    verify_onnx_with_random_inputs,
)


def apply_mha2sha_optimization_in_memory(onnx_proto,
                                         named_encodings,
                                         named_safetensors,
                                         updatable_tensors,
                                         ws: str,
                                         extract_lorav2_alpha=False,
                                         permute_kv_cache_io=False,
                                         key_cache_name_pattern=None,
                                         value_cache_name_pattern=None,
                                         m2s_head_split_map=None,
                                         base_dir=None,
                                         enable_validation:bool=False,
                                         input_raw_list_path:str|None=None,
                                         input_raw_base_dir:str|None=None,
                                         onnx_file_path:str|None=None,
                                         ) -> Dict:
    # pylint: disable=[R0913,R0917,R0914]
    """
    Apply mha2sha optimization on the data in memory.

    Args:
        onnx_proto (onnx.ModelProto): The ONNX model in memory.
        named_encodings (dict): Dictionary of named encodings.
        named_safetensors (dict): Dictionary of named safetensors.
        updatable_tensors (list): List of updatable tensor names.
        ws (str): Workspace to store temporary files.
        extract_lorav2_alpha (bool, optional): Whether to extract LoRAv2 alpha values. Defaults to False.
        permute_kv_cache_io (bool, optional): Whether to permute key-value cache inputs/outputs. 
                                              Defaults to False.
        key_cache_name_pattern (str, optional): Pattern for key cache tensor names. Defaults to None.
        value_cache_name_pattern (str, optional): Pattern for value cache tensor names. Defaults to None.
        m2s_head_split_map (dict, optional): Mapping for splitting multi-head attention to single-head 
                                             attention. Defaults to None.
        base_dir (str, optional): Base directory for the ONNX model. Defaults to None.
        enable_validation (bool, optional): Whether to verify the generated ONNX model with ONNX Runtime. 
                                      Defaults to False.
        input_raw_list_path (str, optional): Path of raw input list for verification. (same format for 
                                        qairt-quantizer). Defaults to None. If not providen, random 
                                        inputs will be used for verification.
        input_raw_base_dir (str, optional): Base directory for raw input files. Defaults to None.
        onnx_file_path (str, optional): Path of the MHA ONNX file. Defaults to None.

    Returns:
        dict: A dictionary containing the optimizer, named encodings, named safetensors, 
              updatable tensors, tracing info, merged tracing info, and special inputs.
    """
    assert os.path.isdir(ws), f"Workspace {ws} does not exist or is not a directory"

    optimizer = GraphOptimizer(
        onnx_proto=onnx_proto,
        named_encodings=named_encodings,
        named_safetensors=named_safetensors,
        updatable_tensors=updatable_tensors,
        base_dir=base_dir,
    )

    if enable_validation and onnx_file_path is None:
        # onnx_file_path is not provided,
        # we firstly save the mha onnx, so that it can be used for comparison later
        onnx_file_path = os.path.join(ws, "mha/model.onnx")
        optimizer.save_onnx(onnx_file_path)

    out_onnx_special_inputs = {}

    if extract_lorav2_alpha:
        lora_alpha_np_value = optimizer.extract_lora_v2_alpha()
        out_onnx_special_inputs["lora_alpha"] = lora_alpha_np_value

    optimized_inputs_preprocs = []
    optimized_outputs_postprocs = []

    if permute_kv_cache_io:
        kv_cache_io_permutor = optimizer.permute_kv_cache_io(
            key_cache_name_pattern,
            value_cache_name_pattern
        )
        optimized_inputs_preprocs.append(
            kv_cache_io_permutor.preproc_inputs)
        optimized_outputs_postprocs.append(
            kv_cache_io_permutor.postproc_outputs)

    optimizer.apply_mha2sha(m2s_head_split_map)

    out_info = {
        # TODO: have to load weights entirely to fit the OnnxModel GraphManager api. pylint: disable=[W0511]
        "onnx_proto": optimizer.get_onnx_proto(load_weights=True),
        # but we actually don't need to do this expensive operation
        # if the api is in the form of onnxscript,
        # or in the form of onnx file (rather than proto in memory)
        "optimizer": optimizer,  # if possible, call optimizer.save_onnx(),
        # which is more memory efficient than
        # onnx.save(optimizer.get_onnx_proto(load_weights=True))
        "named_encodings": {k: serialize_graph_encodings(v) for k, v in
                            optimizer.get_encodings().items()},
        "named_safetensors": optimizer.get_safetensors(),
        "updatable_tensors": optimizer.get_updatable_tensor_names(),
        "tracing_info": optimizer.get_tracing_info(),
        "merged_tracing_info": optimizer.get_tracing_info(merged=True),
        "special_inputs": out_onnx_special_inputs,
    }

    if enable_validation:
        assert onnx_file_path is not None  # check for mypy, definitely true
        tmp_out_onnx_path = os.path.join(ws, "sha", os.path.basename(onnx_file_path))
        optimizer.save_onnx(tmp_out_onnx_path)

        if input_raw_list_path:
            logger.info(
                "input_raw_list_path is specified, validating onnx with onnxruntime and the given inputs")
            verify_onnx_with_inputs_list(onnx_file_path, tmp_out_onnx_path,
                                            input_raw_list_path, input_raw_base_dir,
                                            optimized_onnx_special_inputs=out_onnx_special_inputs,
                                            optimized_inputs_preprocs=optimized_inputs_preprocs,
                                            optimized_outputs_postprocs=optimized_outputs_postprocs
                                            )
        else:
            logger.info(
                "input_raw_list_path is not specified, validating onnx with onnxruntime and random inputs")
            verify_onnx_with_random_inputs(onnx_file_path,
                                            tmp_out_onnx_path,
                                            optimized_special_inputs=out_onnx_special_inputs,
                                            optimized_inputs_preprocs=optimized_inputs_preprocs,
                                            optimized_outputs_postprocs=optimized_outputs_postprocs)

    return out_info
