# ==============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All Rights Reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""Op-level micro-transforms for slicing logic in MHA-to-SHA transformation"""

import math

import numpy as np
import onnx_graphsurgeon as gs

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.graph_manager import GraphManager
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.mha2sha.patterns.base import BasePattern


class SlicePattern(BasePattern):
    """Default transform to replicate a node across multiple attention heads
    by slicing its inputs and outputs accordingly

                                   Input Tensor
                                   [B, 4, S, D]
                                        |
                                        v
                                    +--------+
                                    |  Node  |
                                    +--------+
                                        |
                                        v
                                   [B, 4, S, D]


                                        ⬇️ Split into 4 SHA Nodes


             [B, 1, S, D]    [B, 1, S, D]    [B, 1, S, D]    [B, 1, S, D]
                  |               |               |               |
                  v               v               v               v
              +--------+      +--------+      +--------+      +--------+
              | Node_0 |      | Node_1 |      | Node_2 |      | Node_3 |
              +--------+      +--------+      +--------+      +--------+
                  |               |               |               |
                  v               v               v               v
             [B, 1, S, D]    [B, 1, S, D]    [B, 1, S, D]    [B, 1, S, D]

    Where:
        B = Batch size
        4 = Number of heads
        S = Sequence length
        D = Head dimension

    Subclasses customise the behaviour by overriding ``capture``,
    ``get_node_inputs``, ``get_node_outputs`` and/or ``get_attrs``.
    """

    def capture(self, node: gs.Node, **kwargs) -> bool:
        """Check if node can be sliced across attention heads

        Returns:
            True when ``_get_num_groups(node)`` is non-zero.
        """
        return self._get_num_groups(node) != 0

    def replace(self, node: gs.Node, **kwargs) -> bool:
        """Replace node with 'num_heads' or 'num_groups' nodes across attention heads

        One new node named ``"<node.name>/:<num>"`` is appended to the graph for
        every group index, using the same op and domain as ``node``. The
        original node is then removed from the graph's node list.

        Returns:
            Always True.
        """
        for num in range(self._get_num_groups(node)):
            node_inputs = self.get_node_inputs(node, num)
            node_outputs = self.get_node_outputs(node, num)
            # Fetched per iteration so every sliced node owns its attribute dict.
            node_attrs = self.get_attrs(node)
            self.graph.nodes.append(
                gs.Node(
                    name=f"{node.name}/:{num}",
                    op=node.op,
                    inputs=node_inputs,
                    outputs=node_outputs,
                    attrs=node_attrs,
                    domain=node.domain,
                )
            )

        self.graph.nodes.remove(node)

        return True

    def get_node_inputs(self, node: gs.Node, num: int, **kwargs) -> list[gs.Tensor]:
        """Get sliced inputs for a specific head/group 'num'

        Constant tensors and graph inputs are passed through unchanged. Any
        other input is swapped for its per-group slice when the graph already
        has one for ``num``; otherwise the original tensor is kept.
        """
        node_inputs = []
        for inp in node.inputs:
            shared = GraphManager.is_constant_tensor(inp) or inp in self.graph.inputs
            if not shared and self.graph.has_slice_tensor(inp, num):
                node_inputs.append(self.graph.get_tensor_slice(inp, num))
            else:
                node_inputs.append(inp)
        return node_inputs

    def get_node_outputs(self, node: gs.Node, num: int, **kwargs) -> list[gs.Tensor]:
        """Get sliced outputs for a specific head/group 'num'

        Each output of ``node`` is mapped (by name) to the tensor returned by
        ``graph.get_tensor_slice`` for ``num``.
        """
        return [self.graph.get_tensor_slice(output.name, num) for output in node.outputs]

    def get_attrs(self, node: gs.Node, **kwargs) -> dict:
        """Get attributes of the sliced node based on 'node'

        A shallow copy is returned so that the sliced nodes do not alias one
        shared attribute mapping: editing an attribute on one head must not
        silently change the others (or the original node).
        """
        return node.attrs.copy()


class SliceInput(SlicePattern):
    """Transform that inserts explicit ``Slice`` nodes on not-yet-sliced inputs

                                       Input Tensor
                                       [B, 4, S, D]
                                            |
                                            v
                                        +--------+
                                        |  Node  |
                                        +--------+
                                            |
                                            v
                                       [B, 4, S, D]


                                            ⬇️ Split into 4 SHA Nodes


                                        Input Tensor
                                        [B, 4, S, D]

                                             |
                                             v
                 --------------------------------------------------------
                 |                  |                  |                |
                 v                  v                  v                v
            +-----------+     +-----------+     +-----------+     +-----------+
            |  Slice_0  |     |  Slice_1  |     |  Slice_2  |     |  Slice_3  |
            +-----------+     +-----------+     +-----------+     +-----------+
                 |                  |                  |                |
                 v                  v                  v                v
            [B, 1, S, D]       [B, 1, S, D]       [B, 1, S, D]      [B, 1, S, D]
                 |                  |                  |                |
                 v                  v                  v                v
            +-----------+     +-----------+     +-----------+       +-----------+
            |  Node_0   |     |  Node_1   |     |  Node_2   |       |  Node_3   |
            +-----------+     +-----------+     +-----------+       +-----------+
                 |                  |                  |                |
                 v                  v                  v                v
            [B, 1, S, D]       [B, 1, S, D]       [B, 1, S, D]      [B, 1, S, D]


    Where:
        B = Batch size
        4 = Number of heads
        S = Sequence length
        D = Head dimension
    """

    def capture(self, node: gs.Node, **kwargs) -> bool:
        """Record which inputs of ``node`` need a ``Slice`` and along which axis

        Populates ``self.slice_idx_map`` (input index -> slice axis). An input
        qualifies when it is not constant, has no slice for group 0 registered
        in the graph yet, has more than one dimension and its second dimension
        is not 1.

        Side effect: for ``Concat``/``ScatterElements`` nodes fed directly by a
        graph input, the first two entries of that input's shape are swapped in
        place and the slice axis is set to 0.

        Returns:
            True when at least one input needs slicing.
        """
        self.slice_idx_map = {}

        for idx, inp in enumerate(node.inputs):
            if GraphManager.is_constant_tensor(inp):
                continue
            # A slice for group 0 already exists -> this input was handled before.
            if self.graph._get_slice_name(inp.name, 0) in self.graph.tensors:
                continue
            if len(inp.shape) <= 1 or inp.shape[1] == 1:
                continue

            # NOTE: Hard-coding KV concat at batch dim
            if node.op in ["Concat", "ScatterElements"] and inp in self.graph.inputs:
                self.slice_idx_map[idx] = 0
                inp.shape[0], inp.shape[1] = inp.shape[1], inp.shape[0]
            else:
                self.slice_idx_map[idx] = 1

        return bool(self.slice_idx_map)

    def get_node_inputs(self, node: gs.Node, num: int, **kwargs) -> list[gs.Tensor]:
        """Get sliced inputs for the node based on slice index mapping

        Inputs listed in ``self.slice_idx_map`` are routed through a ``Slice``
        node selecting ``[num, num + 1)`` (step 1) on the recorded axis; the
        ``Slice`` node and its constant operands are created only when the
        slice output is not already present in the graph. All other inputs use
        their existing per-group slice, or one obtained from
        ``graph.get_tensor_slice``.
        """
        node_inputs = []
        for idx, inp in enumerate(node.inputs):
            if idx not in self.slice_idx_map:
                existing_name = self.graph._get_slice_name(inp.name, num)
                act_tensor = self.graph.tensors.get(existing_name)
                if not act_tensor:
                    act_tensor = self.graph.get_tensor_slice(inp, num)
                node_inputs.append(act_tensor)
                continue

            # TODO: Refactor
            output_base = f"{inp.name}_slice_output"
            slice_output = self.graph.tensors.get(self.graph._get_slice_name(output_base, num))
            if not slice_output:
                slice_output = self.graph.get_tensor_slice(output_base, num)

                axis = self.slice_idx_map[idx]

                starts = self.graph.get_tensor_slice(f"{inp.name}_slice_starts", num, np.array([num]))
                ends = self.graph.get_tensor_slice(f"{inp.name}_slice_ends", num, np.array([num + 1]))
                axes = self.graph.get_tensor_slice(f"{inp.name}_slice_axes", num, np.array([axis]))
                steps = self.graph.get_tensor_slice(f"{inp.name}_slice_steps", num, np.array([1]))

                slice_node = gs.Node(
                    name=f"{inp.name}_slice/:{num}",
                    op="Slice",
                    inputs=[inp, starts, ends, axes, steps],
                    outputs=[slice_output],
                )
                self.graph.nodes.append(slice_node)

            node_inputs.append(slice_output)

        return node_inputs


class SliceMatmul(SlicePattern):
    """Transform for slicing QK and QKV matmul nodes"""

    def capture(self, node: gs.Node, **kwargs) -> bool:
        """Return True when ``node`` is the QK or the QKV matmul of the MHA"""
        return node in [self.mha.qk, self.mha.qkv]

    def get_node_inputs(self, node: gs.Node, num: int, **kwargs) -> list[gs.Tensor]:
        """Pick the per-group slices feeding head ``num`` of a QK/QKV matmul

        Each operand may be split into fewer groups than there are heads, so
        the head index is divided by that operand's group size
        (``num_heads // num_groups``) to find the slice to use.

        If an upstream ``Concat``/``ScatterElements`` node that consumes a graph
        input is found (searching only through ``Reshape``/``Expand``/``Mul``
        nodes), its first output replaces the second matmul operand.
        """
        lhs_groups = self._get_num_groups(node.i())
        rhs_groups = self._get_num_groups(node.i(1))

        lhs_group_size = self.mha.num_heads // lhs_groups
        rhs_group_size = self.mha.num_heads // rhs_groups

        def _merges_graph_input(_node: gs.Node) -> bool:
            # Target of the search: Concat/ScatterElements fed by a graph input.
            return _node.op in ["Concat", "ScatterElements"] and any(
                _inp in self.graph.inputs for _inp in _node.inputs
            )

        def _stops_search(_node: gs.Node) -> bool:
            # Only these ops are traversed while walking upstream.
            return _node.op not in ["Reshape", "Expand", "Mul"]

        concat = self.graph.find_upstream_node(
            node,
            condition=_merges_graph_input,
            hard_stop_condition=_stops_search,
        )
        second_input = concat.outputs[0] if concat else node.inputs[1]

        first_slice = self.graph.get_tensor_slice(node.inputs[0], num // lhs_group_size)
        second_slice = self.graph.get_tensor_slice(second_input, num // rhs_group_size)
        return [first_slice, second_slice]


class SliceProjection(SlicePattern):
    """Transform for slicing projection and its weights across attention heads

                                         Input Tensor
                                           [B, 128]
                                               |
                                               v
                                        +----------------+
                                        |    Linear      |
                                        +----------------+
                                        | W: [128, 128]  |
                                        | B: [128]       |
                                        +----------------+
                                               |
                                               v
                                             [B, 128]


                                                ⬇️ Split into 4 SHA Nodes


                                        Input Tensor
                                           [B, 128]
                                               |
                ---------------------------------------------------------------------
                |                         |                     |                   |
                v                         v                     v                   v
            +------------------+ +------------------+ +------------------+ +------------------+
            |    Linear_0      | |     Linear_1     | |     Linear_2     | |     Linear_3     |
            +------------------+ +------------------+ +------------------+ +------------------+
            | W: [128, 32]     | | W: [128, 32]     | | W: [128, 32]     | | W: [128, 32]     |
            | B: [32]          | | B: [32]          | | B: [32]          | | B: [32]          |
            +------------------+ +------------------+ +------------------+ +------------------+
                    |                    |                       |                  |
                    v                    v                       v                  v
                 [B, 32]              [B, 32]                 [B, 32]            [B, 32]


    Where:
        B   = Batch size
        4   = Number of heads
        128 = Full projection width
        32  = Per-head projection width (128 / 4)
    """

    def capture(self, node: gs.Node, **kwargs) -> bool:
        """Check if the node is one of the Q, K or V projections"""
        return node in self.mha.linears

    def get_node_inputs(self, node: gs.Node, num: int, **kwargs) -> list[gs.Tensor]:
        """Slice weights (and bias, if applicable) for the projection layer

        The weight is cut along its output-feature axis: axis 0 for ``Conv``
        and for ``Gemm`` with ``transB == 1``, axis 1 for ``MatMul`` and plain
        ``Gemm``. The chunk width is derived from that same axis, so
        non-square weights are split correctly. The activation input
        (``node.inputs[0]``) is passed through unsliced.

        Returns:
            ``[activation, weight_slice]`` plus ``bias_slice`` when the node
            has a third input.
        """
        weight_tensor = node.inputs[1]
        weights = self.graph.get_tensor_value(weight_tensor)

        out_features_first = node.op == "Conv" or (
            node.op == "Gemm" and node.attrs.get("transB") == 1
        )
        out_axis = 0 if out_features_first else 1  # else: MatMul or Gemm

        # Width of one head/group chunk, measured on the axis that is sliced.
        head_dim = math.ceil(weight_tensor.shape[out_axis] / self._get_num_groups(node))

        begin = num * head_dim
        end = begin + head_dim

        if out_axis == 0:
            weights = weights[begin:end]
        else:
            weights = weights[:, begin:end]

        # Create weights Constant
        weights_constant = self.graph.get_tensor_slice(weight_tensor.name, num, weights)

        node_inputs = [node.inputs[0], weights_constant]

        # Bias is optional; an explicit check avoids masking unrelated
        # IndexErrors raised while reading or slicing the bias.
        if len(node.inputs) > 2:
            bias_tensor = node.inputs[2]
            bias = self.graph.get_tensor_value(bias_tensor)[begin:end]
            node_inputs.append(self.graph.get_tensor_slice(bias_tensor.name, num, bias))

        return node_inputs


class SliceConst(SlicePattern):
    """Transform for slicing nodes and its constant input across attention heads


                                                |
                                                v
                                         +--------------+
                                         |     Node     |
                                         +--------------+
                                         | [B, 4, S, D] |
                                         +--------------+


                                             ⬇️ Split into 4 SHA Nodes


                    |              |                |                 |
                    v              v                v                 v
            +-------------+  +-------------+  +-------------+  +-------------+
            |   Node_0    |  |   Node_1    |  |   Node_2    |  |   Node_3    |
            +-------------+  +-------------+  +-------------+  +-------------+
            | [B, 1, S, D]|  | [B, 1, S, D]|  | [B, 1, S, D]|  | [B, 1, S, D]|
            +-------------+  +-------------+  +-------------+  +-------------+


    Where:
        B = Batch size
        4 = Number of heads
        S = Sequence length
        D = Head dimension
    """

    def capture(self, node: gs.Node, **kwargs) -> bool:
        """Find constant inputs of ``node`` that have a dimension to split

        Stores the group count in ``self.num_groups`` and fills
        ``self.index_map`` (input index -> axis). For every constant input the
        first axis whose size equals either the group count or the width of the
        upstream linear's weight (dim 0 for ``Conv``/``Gemm`` with
        ``transB == 1``, dim 1 otherwise) is recorded.

        Returns:
            True when at least one constant input has such an axis.
        """
        self.index_map = {}

        self.num_groups = self._get_num_groups(node)

        linear = self._get_upstream_linear(node)
        out_features_first = linear.op == "Conv" or (
            linear.op == "Gemm" and linear.attrs.get("transB") == 1
        )
        # MatMul and plain Gemm use dim 1 of the weight.
        match_shape = linear.inputs[1].shape[0 if out_features_first else 1]

        for idx, inp in enumerate(node.inputs):
            if not GraphManager.is_constant_tensor(inp):
                continue
            split_axis = next(
                (i for i, s in enumerate(inp.shape) if s == self.num_groups or s == match_shape),
                None,
            )
            if split_axis is not None:
                self.index_map[idx] = split_axis

        return bool(self.index_map)

    def get_node_inputs(self, node: gs.Node, num: int, **kwargs) -> list[gs.Tensor]:
        """Get inputs for group ``num``, cutting the recorded constants

        Inputs present in ``self.index_map`` have their value split into
        ``self.num_groups`` equal chunks along the recorded axis and the
        ``num``-th chunk is registered as the slice value. Every other input is
        resolved through ``graph.get_tensor_slice``.
        """
        node_inputs = []
        for idx, inp in enumerate(node.inputs):
            shape_idx = self.index_map.get(idx)
            if shape_idx is None:
                node_inputs.append(self.graph.get_tensor_slice(inp, num))
                continue

            const_value = self.graph.get_tensor_value(inp)
            # Number of elements along the split axis that belong to one group.
            skip = inp.shape[shape_idx] // self.num_groups
            first = skip * num
            const_slice_value = const_value.take(
                indices=range(first, first + skip), axis=shape_idx
            )
            node_inputs.append(self.graph.get_tensor_slice(inp, num, const_slice_value))

        return node_inputs


class SliceReshape(SlicePattern):
    """Transform for slicing Reshape nodes across attention heads

                                                  |
                                                  v
                                         +-----------------+
                                         |     Reshape     |
                                         +-----------------+
                                         | shape=(B,4,S,D) |
                                         +-----------------+


                                             ⬇️ Split into 4 SHA Nodes


                    |                   |                 |                 |
                    v                   v                 v                 v
            +---------------+  +---------------+  +---------------+  +---------------+
            |   Reshape_0   |  |   Reshape_1   |  |   Reshape_2   |  |   Reshape_3   |
            +---------------+  +---------------+  +---------------+  +---------------+
            |shape=(B,1,S,D)|  |shape=(B,1,S,D)|  |shape=(B,1,S,D)|  |shape=(B,1,S,D)|
            +---------------+  +---------------+  +---------------+  +---------------+

    Where:
        B = Batch size
        4 = Number of heads
        S = Sequence length
        D = Head dimension
    """

    def capture(self, node: gs.Node, **kwargs) -> bool:
        """Match ``Reshape`` nodes whose target shape contains the group count

        On a candidate node the target shape (value of the second input) is
        cached in ``self.shape`` for use by ``get_node_inputs``.

        Returns:
            True when ``node`` is a sliceable ``Reshape`` and the group count
            appears among the values of its target shape.
        """
        if node.op not in ["Reshape"] or not super().capture(node, **kwargs):
            return False

        self.shape = self.graph.get_tensor_value(node.inputs[1])
        return self._get_num_groups(node) in self.shape

    def get_node_inputs(self, node: gs.Node, num: int, **kwargs) -> list[gs.Tensor]:
        """Get the sliced data input and the per-head target shape

        The target shape is the cached shape with entries equal to the group
        count replaced by 1.

        NOTE(review): every entry equal to the group count is replaced, not
        only the head axis -- confirm no other dimension can take that value.
        """
        # Loop-invariant: look the group count up once, not once per dimension.
        num_groups = self._get_num_groups(node)
        shape = np.array([1 if s == num_groups else s for s in self.shape])

        reshape_input = self.graph.get_tensor_slice(node.inputs[0], num)
        shape_const = self.graph.get_tensor_slice(node.inputs[1], num, shape)

        return [reshape_input, shape_const]
