# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides the infrastructure of the post graph optimization,
and also the entry of the mha2sha optimization
"""

import copy
import itertools
import json
import os
from typing import Dict, List

import onnx
import onnxscript
from onnxscript import ir

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha.passes.adaption.extract_lora_alpha import (
    LoraAlphaExtractor,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha.passes.adaption.permute_kv_cache import (
    PermuteKVCacheRewriter,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha.passes.mha2sha.mha2sha_rewriter import (
    MHA2SHARewriter,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha.passes.opt.layout_opt.layout_opt import (
    LayoutOptRewriter,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha.passes.protect_io import (
    ProtectIO,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha.passes.shape_infer.shape_infer import (
    ShapeInference,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha.split import split
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha.utils.encodings import (
    GraphEncodingInfo,
    TensorEncodingInfo,
    deserialize_encodings,
    infer_encodings_chn_block_axis,
    serialize_graph_encodings,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha.utils.ir_extra_info import (
    GraphExtraInfo,
    VariableExtraInfo,
    chain_m2s_tracing_info,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha.utils.utils import (
    iter_all_values,
    load_external_data_for_constant,
    logger,
)


class PostQuantGraphOptimizer:
    """
    Base class to perform any post-quantization optimization (pre-quant is also possible).

    The encodings, safetensors and updatable-tensor lists are all embedded into the
    ``ir.Graph``: every value gets a ``VariableExtraInfo`` in ``value.meta["extra_info"]``
    carrying its per-set encodings, per-set safetensors and updatable flag, and the graph
    itself gets a ``GraphExtraInfo`` in ``graph.meta["extra_info"]``.
    """

    # pylint: disable=[R0913,R0917]
    def __init__(
        self,
        onnx_proto: onnx.ModelProto,  # external data is not required to load
        named_encodings: Dict[str, Dict] | None = None,
        named_safetensors: Dict[str, Dict] | None = None,
        updatable_tensors: List[str] | None = None,
        base_dir: str | None = None,  # directory of the external data
        naming_prefix: str = "post_quant_opt",
    ):
        """Base class of the graph optimizer that focused on post-quantization optimization.

        Deserializes the model into onnxscript IR, embeds all side information
        (encodings, safetensors, updatable flags) into the values' meta, runs shape
        inference and finally infers the channel/block axis of every encoding.

        Args:
            onnx_proto: The onnx model proto; external data does not need to be loaded.
            named_encodings: A dictionary containing all the encodings information of the graph.
                The key is the encoding-set name, the value is the corresponding serialized
                graph encodings (parsed with ``deserialize_encodings``).
            named_safetensors: A dictionary containing all the LoRA safetensors information of
                the graph. The key is the encoding-set name, the value maps tensor names to
                the corresponding safetensors entries.
            updatable_tensors: The updatable tensor names of the model.
            base_dir: The base directory of the external data.
            naming_prefix: The name prefix of the new nodes/tensors created by the optimizer.
        """
        # load graph
        self.model_ir = ir.serde.deserialize_model(onnx_proto)
        self.graph_ir = self.model_ir.graph

        if base_dir is not None:
            # presumably lets external-data tensors be resolved relative to base_dir
            ir.external_data.set_base_dir(self.model_ir.graph, base_dir)

        if named_encodings is None:
            named_encodings = {}
        if named_safetensors is None:
            named_safetensors = {}
        if updatable_tensors is None:
            updatable_tensors = []

        self.base_dir = base_dir

        # deserialize every encoding set into a GraphEncodingInfo
        self.src_named_encodings: Dict[str, GraphEncodingInfo] = {
            enc_set_name: deserialize_encodings(encodings)
            for enc_set_name, encodings in named_encodings.items()
        }
        self.src_named_safetensors = named_safetensors
        self.src_updatable_tensors = updatable_tensors

        self.embed_encodings_into_graph(self.src_named_encodings)
        self.embed_safetensors_into_graph(self.src_named_safetensors)
        self.embed_updatable_tensors_into_graph(self.src_updatable_tensors)

        self.graph_ir.meta["extra_info"] = GraphExtraInfo(naming_prefix=naming_prefix)
        self.graph_ir.meta["extra_info"].naming_policy.init_from_graph(self.graph_ir)

        # shapes must be known before the encodings' channel/block axes can be inferred
        ShapeInference(self.graph_ir, self.model_ir).apply()

        self.infer_encodings_chn_block_axis()

    def infer_encodings_chn_block_axis(self):
        """
        Infer encodings' channel axis and block axis for all the tensors in the graph.
        ShapeInference should be already applied.
        """
        for v in iter_all_values(self.graph_ir):
            if "extra_info" not in v.meta:
                continue
            for enc in v.meta["extra_info"].named_encodings.values():
                infer_encodings_chn_block_axis(enc, v)

    def embed_updatable_tensors_into_graph(self, updatable_tensors: List[str]):
        """Embed updatable tensors into ir graph.

        Every value receives a ``VariableExtraInfo`` if it has none yet; values whose
        name is listed are flagged ``is_updatable``.

        Args:
            updatable_tensors: The updatable tensor names of the model.
        """
        updatable_tensors_set = set(updatable_tensors)

        for v in iter_all_values(self.graph_ir):
            if "extra_info" not in v.meta:
                v.meta["extra_info"] = VariableExtraInfo()
            if v.name in updatable_tensors_set:
                v.meta["extra_info"].is_updatable = True

    def embed_encodings_into_graph(self, named_encodings: Dict[str, GraphEncodingInfo]):
        """Embed encodings into ir graph.

        Activation and param encodings of each set are merged by tensor name and
        attached to the matching graph values under that set's name.

        Args:
            named_encodings: A dictionary containing all the encodings information of the graph.
                The key is the encoding-set name, the value is the corresponding graph encodings.

        Raises:
            AssertionError: If the same tensor name has different encodings in
                ``activation_encodings`` and ``param_encodings``.
        """
        for encset_name, graph_enc in named_encodings.items():
            tensors_encodings: Dict[str, TensorEncodingInfo] = {}
            for k, v in itertools.chain(
                graph_enc.activation_encodings.items(), graph_enc.param_encodings.items()
            ):
                if k in tensors_encodings and v != tensors_encodings[k]:
                    # Explicit raise instead of `assert`: the check must not vanish under
                    # `python -O`. AssertionError is kept for backward compatibility.
                    raise AssertionError(
                        f"Encoding '{k}' is defined twice in param_encodings and activation_encodings"
                    )
                tensors_encodings[k] = v

            for v in iter_all_values(self.graph_ir):
                if "extra_info" not in v.meta:
                    v.meta["extra_info"] = VariableExtraInfo()
                if v.name in tensors_encodings:
                    v.meta["extra_info"].named_encodings[encset_name] = tensors_encodings[v.name]

    def embed_safetensors_into_graph(self, named_safetensors: Dict[str, Dict]):
        """Embed safetensors into ir graph.

        Args:
            named_safetensors: A dictionary containing all the LoRA safetensors information of
                the graph. The key is the encoding-set name, the value maps tensor names to
                the corresponding safetensors entries.
        """
        for set_name, safetensors in named_safetensors.items():
            for v in iter_all_values(self.graph_ir):
                if "extra_info" not in v.meta:
                    v.meta["extra_info"] = VariableExtraInfo()
                if v.name in safetensors:
                    v.meta["extra_info"].named_safetensors[set_name] = safetensors[v.name]

    def save_onnx(self, path: str):
        """Save the ir Graph into an onnx file, with weights in a sibling ``<stem>.data`` file.

        This function is much more memory efficient than
        ``onnx.save(get_onnx_proto(load_weights=True))``.

        Args:
            path: The path to save the onnx file. Missing parent directories are created.
        """
        # bugfix for onnx_ir,
        # onnx_ir will not handle the path of external data for Constant node in the serialization,
        # so we need to load external data for constant node before serialization
        load_external_data_for_constant(self.model_ir)

        base_dir, basename = os.path.split(path)
        # a bare file name has an empty dirname, and os.makedirs("") would raise
        if base_dir:
            os.makedirs(base_dir, exist_ok=True)
        # "model.onnx" -> "model.data"; splitext also handles other/absent extensions
        external_data_name = os.path.splitext(basename)[0] + ".data"
        onnxscript.ir.save(self.model_ir, path, external_data=external_data_name)
        logger.debug("saved onnx to %s", path)

    def get_onnx_proto(self, load_weights=False):
        """Serialize the model into onnx.ModelProto.

        Args:
            load_weights: Whether to load the external data (only possible when
                ``base_dir`` was given at construction).

        Returns:
            The serialized ``onnx.ModelProto``.
        """
        onnx_proto = onnxscript.ir.serde.serialize_model(self.model_ir)
        if load_weights and self.base_dir is not None:
            onnx.load_external_data_for_model(onnx_proto, self.base_dir)
        return onnx_proto

    def get_encodings(self):
        """Extract encodings from the extra_info of each tensor in the ir Graph.

        Returns:
            A dict mapping encoding-set name to a ``GraphEncodingInfo``. Every source
            set is present (with its ``quantizer_args`` copied). A set found on a value
            but not in the source encodings gets a fresh ``GraphEncodingInfo`` instead
            of raising ``KeyError``.
        """
        named_graph_encodings: Dict[str, GraphEncodingInfo] = {}

        for encset_name, origin_graph_enc in self.src_named_encodings.items():
            graph_enc = GraphEncodingInfo()
            graph_enc.quantizer_args = copy.deepcopy(origin_graph_enc.quantizer_args)
            named_graph_encodings[encset_name] = graph_enc

        for v in iter_all_values(self.graph_ir):
            if "extra_info" not in v.meta:
                continue
            for encset_name, v_enc in v.meta["extra_info"].named_encodings.items():
                if encset_name not in named_graph_encodings:
                    named_graph_encodings[encset_name] = GraphEncodingInfo()
                named_graph_encodings[encset_name].add_tensor_encodings(v.name, v_enc)

        return named_graph_encodings

    def get_safetensors(self):
        """Extract safetensors from the extra_info of each initializer in the ir Graph.

        Returns:
            A dict mapping each source safetensors set name to ``{tensor_name: entry}``.
            Only graph initializers are inspected.
        """
        named_safetensors = {}
        # extract safetensors from extra_info
        for encset_name in self.src_named_safetensors.keys():
            graph_safetensors = {}
            for v in self.graph_ir.initializers.values():
                if "extra_info" in v.meta and encset_name in v.meta["extra_info"].named_safetensors:
                    graph_safetensors[v.name] = v.meta["extra_info"].named_safetensors[encset_name]
            named_safetensors[encset_name] = graph_safetensors
        return named_safetensors

    def get_updatable_tensor_names(self):
        """Extract updatable tensors from the extra_info of each tensor in the ir Graph.

        Returns:
            The names of all values flagged ``is_updatable``, in graph iteration order.
        """
        # extract updatable from extra_info
        updatable_tensor_names = []
        for v in iter_all_values(self.graph_ir):
            if "extra_info" in v.meta and v.meta["extra_info"].is_updatable:
                updatable_tensor_names.append(v.name)

        return updatable_tensor_names


class GraphOptimizer(PostQuantGraphOptimizer):
    """Entry class of the mha2sha optimizations.

    All mha2sha related optimizations should be called from this class object.
    New nodes/tensors created by the optimizations use the ``"m2s"`` naming prefix.
    """

    # pylint: disable=[R0913,R0917]

    def __init__(
        self,
        onnx_proto: onnx.ModelProto,  # external data is not required to load
        named_encodings: Dict[str, Dict] | None = None,
        named_safetensors: Dict[str, Dict] | None = None,
        updatable_tensors: List[str] | None = None,
        base_dir: str | None = None,  # directory of the external data
    ):
        """Entry class of the optimizations.

        All mha2sha related optimizations should be called from this class object.
        The naming prefix of created nodes/tensors is fixed to ``"m2s"``.

        Args:
            onnx_proto: The onnx model proto; external data does not need to be loaded.
            named_encodings: A dictionary containing all the encodings information of the graph.
                The key is the encoding-set name, the value is the corresponding graph encodings.
            named_safetensors: A dictionary containing all the LoRA safetensors information of
                the graph. The key is the encoding-set name, the value is the corresponding
                safetensors.
            updatable_tensors: The updatable tensor names of the model.
            base_dir: The base directory of the external data.
        """

        super().__init__(
            onnx_proto=onnx_proto,
            named_encodings=named_encodings,
            named_safetensors=named_safetensors,
            updatable_tensors=updatable_tensors,
            base_dir=base_dir,
            naming_prefix="m2s",
        )

    def permute_kv_cache_io(self, key_name_pattern, value_name_pattern):
        """
        Permute batch-dim and head-dim of the kvcache input/output.

        Args:
            key_name_pattern: The key name re pattern of the kvcache.
            value_name_pattern: The value name re pattern of the kvcache.

        Returns:
            The applied ``PermuteKVCacheRewriter`` instance.
        """
        rewriter = PermuteKVCacheRewriter(self.graph_ir, key_name_pattern, value_name_pattern)
        rewriter.apply()
        return rewriter

    def extract_lora_v2_alpha(self):
        """
        Extract lora v2 alpha, set it as an additional graph input.

        Returns:
            Whatever ``LoraAlphaExtractor.get_lora_alpha_np_value()`` returns
            (presumably a numpy value — confirm against the extractor).
        """
        lora_alpha_extractor = LoraAlphaExtractor(self.graph_ir)
        lora_alpha_extractor.apply()
        return lora_alpha_extractor.get_lora_alpha_np_value()

    def apply_mha2sha(self, m2s_head_split_map=None):
        """
        Apply mha2sha on the graph.

        All the modifications should be equivalent transformations. Graph inputs/outputs
        are protected by ``ProtectIO`` while the MHA2SHA and layout rewriters run.

        Args:
            m2s_head_split_map: Forwarded unchanged to ``MHA2SHARewriter``.
        """
        io_protector = ProtectIO(self.graph_ir)
        io_protector.protect()
        MHA2SHARewriter(self.graph_ir, m2s_head_split_map).apply()
        LayoutOptRewriter(self.graph_ir).apply()
        io_protector.unprotect()

    def split(
        self,
        num_splits: int,
        split_embedding: bool = False,
        split_lm_head: bool = False,
        skip_verification: bool = True,
    ):
        """
        Split the onnx model based on the arguments.

        Args:
            num_splits (int): The number of splits to be made.
            split_embedding (bool): If True, splits the embeddings. Default is False.
            split_lm_head (bool): If True, splits the language model head. Default is False.
            skip_verification (bool): Forwarded to the splitter. Default is True.

        Returns:
            A list with one dict per split, holding the keys ``"onnx_proto"``,
            ``"named_encodings"`` (serialized), ``"named_safetensors"`` and
            ``"updatable_tensors"``.
        """
        splits = split(
            self.model_ir,
            num_splits=num_splits,
            split_embedding=split_embedding,
            split_lm_head=split_lm_head,
            skip_verification=skip_verification,
        )

        splits_info = []

        for _split in splits:
            onnx_proto = onnxscript.ir.serde.serialize_model(_split)
            named_graph_encodings, updatable_tensor_names = self._collect_split_encodings(
                _split.graph
            )
            named_safetensors = self._collect_split_safetensors(_split.graph)

            splits_info.append(
                {
                    "onnx_proto": onnx_proto,
                    "named_encodings": {
                        k: serialize_graph_encodings(v) for k, v in named_graph_encodings.items()
                    },
                    "named_safetensors": named_safetensors,
                    "updatable_tensors": updatable_tensor_names,
                }
            )

        return splits_info

    def _new_split_graph_encodings(self, encset_name):
        """Create an empty GraphEncodingInfo carrying the source set's quantizer_args (if any)."""
        graph_enc = GraphEncodingInfo()
        src_graph_enc = self.src_named_encodings.get(encset_name)
        if src_graph_enc is not None:
            # keep quantizer_args in split encodings, consistent with get_encodings()
            graph_enc.quantizer_args = copy.deepcopy(src_graph_enc.quantizer_args)
        return graph_enc

    def _collect_split_encodings(self, graph):
        """Gather per-set encodings and updatable tensor names from one split graph."""
        named_graph_encodings: Dict[str, GraphEncodingInfo] = {}
        updatable_tensor_names: List[str] = []

        for v in iter_all_values(graph):
            if "extra_info" not in v.meta:
                continue
            extra_info = v.meta["extra_info"]
            for encset_name, v_enc in extra_info.named_encodings.items():
                if encset_name not in named_graph_encodings:
                    named_graph_encodings[encset_name] = self._new_split_graph_encodings(
                        encset_name
                    )
                named_graph_encodings[encset_name].add_tensor_encodings(v.name, v_enc)

            if extra_info.is_updatable:
                updatable_tensor_names.append(v.name)

        return named_graph_encodings, updatable_tensor_names

    def _collect_split_safetensors(self, graph):
        """Gather, for each source safetensors set, the entries of one split's initializers."""
        named_safetensors = {}
        for encset_name in self.src_named_safetensors:
            named_safetensors[encset_name] = {
                v.name: v.meta["extra_info"].named_safetensors[encset_name]
                for v in graph.initializers.values()
                if "extra_info" in v.meta and encset_name in v.meta["extra_info"].named_safetensors
            }
        return named_safetensors

    def save_tracing_info(self, path: str, merged=False):
        """
        Save tracing information to a JSON file.

        Args:
            path: The path to save the tracing information.
            merged: Whether to merge the chainable transformations into one.
        """
        tracing_info_j = self.get_tracing_info(merged)

        with open(path, "w", encoding="utf-8") as f:
            json.dump(tracing_info_j, f, indent=4)

    def get_tracing_info(self, merged=False):
        """
        Get tracing information of all transformations recorded.

        Args:
            merged: Whether to merge the chainable transformations into one.

        Returns:
            A list of dicts: the one-to-one tracing entries first, then the
            subgraph tracing entries.
        """
        graph_extra_info = self.graph_ir.meta["extra_info"]
        one2one_tracing_info = graph_extra_info.one2one_tracing_info
        if merged:
            one2one_tracing_info = chain_m2s_tracing_info(one2one_tracing_info)

        tracing_info_j = [v.as_dict() for v in one2one_tracing_info.values()]
        tracing_info_j.extend(v.as_dict() for v in graph_extra_info.subgraph_tracing_info)
        return tracing_info_j
