# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides a pass to permute kv cache input/output 
to make head dim as the first fim
"""
import copy
import re
from typing import Dict

import numpy as np
from onnxscript import ir

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .base.rewriter import BaseGraphRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.utils import get_value_numeric_shape, logger, safe_replace_all_uses_with


class PermuteKVCacheRewriter(BaseGraphRewriter):
    '''
    Pass to swap the first two axes (batch dim and head dim) of kv cache
    graph inputs/outputs.

    Graph inputs/outputs whose names match ``key_name_pattern`` or
    ``value_name_pattern`` (via ``re.match``, i.e. anchored at the start of
    the name) are rewritten:

    * inputs take the permuted shape, and a ``Transpose`` inserted at the
      start of the graph restores the original layout for existing consumers;
    * outputs get a trailing ``Transpose`` into the permuted layout, and the
      transposed tensor takes over the original output name.

    ``preproc_inputs`` / ``postproc_outputs`` apply the same axis swap to
    runtime data, so the rewritten graph can be fed and compared against the
    original one.
    '''

    def __init__(self, model, key_name_pattern, value_name_pattern):
        super().__init__(model)
        self.key_name_pattern = re.compile(key_name_pattern)
        self.value_name_pattern = re.compile(value_name_pattern)

    def _is_kv_cache_name(self, name) -> bool:
        """Return True if *name* matches the key or the value kv cache pattern.

        Unnamed values (``None``) never match, instead of making ``re.match``
        raise a ``TypeError``.
        """
        if name is None:
            return False
        return bool(self.key_name_pattern.match(name)
                    or self.value_name_pattern.match(name))

    @staticmethod
    def _swap_first_two_axes_perm(rank: int, name) -> list:
        """Return the permutation swapping axes 0 and 1 of a rank-*rank* tensor.

        The permutation is its own inverse, so the same list is used both to
        go into and to come back out of the permuted layout.

        Args:
            rank: number of dimensions of the tensor.
            name: tensor name, used only in the error message.

        Raises:
            ValueError: if ``rank < 2`` (there are no two axes to swap). A
                plain ``perm[0:2] = [1, 0]`` would silently produce an
                invalid permutation in that case.
        """
        if rank < 2:
            raise ValueError(
                f"cannot permute kv cache '{name}': expected rank >= 2, "
                f"got rank {rank}")
        perm = list(range(rank))
        perm[0], perm[1] = 1, 0
        return perm

    def _permute_input_kv(self, value: ir.Value):
        """Rewrite graph input *value* into the axes-0/1-swapped layout.

        The input's declared shape becomes the permuted shape. A ``Transpose``
        node inserted before the first graph node converts it back to the
        original layout, and every previous consumer of the input is rewired
        to read that transposed tensor instead.

        Returns:
            the same (re-shaped) input value.

        Raises:
            ValueError: if the input has rank < 2. This is checked before the
                graph is modified.
        """
        origin_shape = get_value_numeric_shape(value)
        perm = self._swap_first_two_axes_perm(len(origin_shape), value.name)
        new_shape = np.array(origin_shape)[perm].tolist()
        value.shape = ir.Shape(new_shape)
        transpose_node = ir.Node("",
                                 "Transpose",
                                 inputs=[value],
                                 attributes=[ir.AttrInt64s("perm", perm)],
                                 name=self.graph.meta["extra_info"].get_unique_name_with_suffix(
                                     value.name, "/permute")
                                 )
        transpose_node.outputs[0].name = self.graph.meta["extra_info"].get_unique_name_with_suffix(
            value.name, "/origin_layout")
        self.graph.insert_before(self.graph[0], transpose_node)
        # Rewire all consumers except the new Transpose itself; otherwise it
        # would end up consuming its own output.
        safe_replace_all_uses_with(self.graph, value, transpose_node.outputs[0],
                                   except_users=[transpose_node])
        # The Transpose output carries the original (un-permuted) layout.
        transpose_node.outputs[0].shape = ir.Shape(origin_shape)
        transpose_node.outputs[0].type = copy.deepcopy(value.type)
        transpose_node.outputs[0].meta["extra_info"] = value.meta["extra_info"].copy(
            ignore_safetensors=True)

        logger.debug("permute input kv cache on '%s'", value.name)
        return value

    def _permute_output_kv(self, value: ir.Value):
        """Rewrite graph output *value* into the axes-0/1-swapped layout.

        A ``Transpose`` node is appended after the last graph node. The
        original value is renamed with a ``/permute_layout`` suffix, and the
        Transpose output takes over the original name and replaces *value* in
        ``graph.outputs``.

        Returns:
            the original (now internal, renamed) value.

        Raises:
            ValueError: if the output has rank < 2. This is checked before the
                graph is modified.
        """
        assert value.name is not None  # check for mypy
        value_name = value.name
        value_shape = get_value_numeric_shape(value)
        perm = self._swap_first_two_axes_perm(len(value_shape), value_name)
        transpose_node = ir.Node("",
                                 "Transpose",
                                 inputs=[value],
                                 attributes=[ir.AttrInt64s("perm", perm)],
                                 name=self.graph.meta["extra_info"].get_unique_name_with_suffix(
                                     value.name, "/permute")
                                 )
        self.graph.insert_after(self.graph[-1], transpose_node)
        # Free up the original name so the new graph output can reuse it.
        value.name = self.graph.meta["extra_info"].get_unique_name_with_suffix(
            value.name, "/permute_layout")
        transpose_node.outputs[0].name = value_name
        transpose_node.outputs[0].shape = ir.Shape(
            np.array(value_shape)[perm].tolist())
        transpose_node.outputs[0].type = copy.deepcopy(value.type)
        transpose_node.outputs[0].meta["extra_info"] = value.meta["extra_info"].copy(
            ignore_safetensors=True)

        graph_outputs = self.graph.outputs
        graph_outputs[graph_outputs.index(value)] = transpose_node.outputs[0]

        logger.debug("permute output kv cache on '%s'", value_name)
        return value

    def apply(self) -> int:
        """Permute every kv cache graph input and output.

        Returns:
            number of graph inputs plus outputs that were rewritten.
        """
        rewrite_count = 0
        # Iterate over snapshots: the rewrites mutate graph inputs/outputs.
        for input_v in list(self.graph.inputs):
            if self._is_kv_cache_name(input_v.name):
                self._permute_input_kv(input_v)
                rewrite_count += 1
        for output_v in list(self.graph.outputs):
            if self._is_kv_cache_name(output_v.name):
                self._permute_output_kv(output_v)
                rewrite_count += 1
        return rewrite_count

    def preproc_inputs(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """
        Preprocess the inputs data of the original graph and
        return the corresponding inputs data of the kvcache-permuted graph.

        Note:
            ``inputs`` is updated in place: kv cache entries are replaced by
            views with axes 0 and 1 swapped. The same dict is returned.

        Args:
            inputs: inputs of the original graph
        Returns:
            corresponding inputs of the kvcache-permuted graph
        Raises:
            ValueError: if a kv cache entry has rank < 2.
        """
        for name, value in list(inputs.items()):
            if self._is_kv_cache_name(name):
                perm = self._swap_first_two_axes_perm(len(value.shape), name)
                inputs[name] = value.transpose(*perm)
        return inputs

    def postproc_outputs(self, outputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """
        Postprocess the outputs data of the kvcache-permuted graph and
        return the corresponding outputs data of the original graph.

        Note:
            ``outputs`` is updated in place: kv cache entries are replaced by
            views with axes 0 and 1 swapped back. The same dict is returned.

        Args:
            outputs: outputs of the kvcache-permuted graph
        Returns:
            corresponding outputs of the original graph
        Raises:
            ValueError: if a kv cache entry has rank < 2.
        """
        for name, value in list(outputs.items()):
            if self._is_kv_cache_name(name):
                perm = self._swap_first_two_axes_perm(len(value.shape), name)
                outputs[name] = value.transpose(*perm)
        return outputs
