# ==============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All Rights Reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""Op-level micro-transforms for concat logic in MHA-to-SHA transformation"""

import onnx_graphsurgeon as gs

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.mha2sha.patterns.base import BasePattern


class ConcatOutputs(BasePattern):
    """Transform to concat the outputs of split SHA nodes when
            1. The output is consumed by a node outside the MHA subgraph (QKV MatMul)
            2. The output is a graph output (past key/value)

                                   Input Tensor
                                   [B, 4, S, D]
                                        |
                                        v
                                  +------------+
                                  |   Node     |
                                  +------------+
                                        |
                                        v
                                   Output: [B, 4, S, D]


                                        ⬇️ Split into 4 SHA Nodes


          [B, 1, S, D]    [B, 1, S, D]    [B, 1, S, D]    [B, 1, S, D]

                |               |              |               |
                v               v              v               v
            +-----------+  +-----------+  +-----------+   +-----------+
            |  Node_0   |  |  Node_1   |  |  Node_2   |   |  Node_3   |
            +-----------+  +-----------+  +-----------+   +-----------+
                |              |              |               |
                v              v              v               v
            [B, 1, S, D]   [B, 1, S, D]   [B, 1, S, D]    [B, 1, S, D]
                \               |             |               /
                 \______________|_____________|______________/
                                       |
                                       |
                                       v
                            +----------------------+
                            |     Concat (axis=1)  | (Concat at axis=0 for past key/value output)
                            +----------------------+
                                       |
                                       v
                                  Output: [B, 4, S, D]

    Where:
        B = Batch size
        4 = Number of heads
        S = Sequence length
        D = Head dimension

    ``capture`` stores the chosen axis on ``self.concat_axis``; ``replace``
    reads it back, so ``replace`` is expected to run only after a successful
    ``capture`` on the same node.
    """

    def capture(self, node: gs.Node, **kwargs) -> bool:
        """Decide whether the first output of ``node`` needs a Concat.

        Only ``node.outputs[0]`` is inspected. On a match the concat axis is
        recorded on ``self.concat_axis``:

        * ``0`` when the output is one of ``self.graph.outputs``;
        * ``1`` when at least one consumer of the output is not in
          ``self.mha.ordered_nodes``.

        The graph-output check takes precedence when both conditions hold.

        Args:
            node: Node whose first output is examined.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            True if one of the two conditions holds, False otherwise.
        """
        out = node.outputs[0]

        if out in self.graph.outputs:
            self.concat_axis = 0
            return True

        if any(consumer not in self.mha.ordered_nodes for consumer in out.outputs):
            self.concat_axis = 1
            return True

        return False

    def replace(self, node: gs.Node, **kwargs) -> bool:
        """Insert a Concat node that joins the per-slice tensors of ``node``'s output.

        A new variable carrying the same name and dtype as ``node.outputs[0]``
        becomes the Concat output. Consumers that are not in
        ``self.mha.ordered_nodes`` are rewired to it, and if the original
        tensor is a graph output it is swapped for the new variable at the
        same position, so the order of graph outputs is unchanged.

        Args:
            node: Node whose first output is being re-assembled.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            Always True.
        """
        node_output = node.outputs[0]

        # NOTE(review): presumably get_tensor_slice returns the tensor of the
        # num-th split (per-head) copy and _get_num_groups the number of
        # splits; both are defined outside this file - confirm.
        outputs = [self.graph.get_tensor_slice(node_output, num) for num in range(self._get_num_groups(node))]

        # The concat output intentionally reuses the original tensor's name.
        concat_out = gs.Variable(node_output.name, dtype=node_output.dtype)

        concat = gs.Node(
            name=f"{node.name}",
            op="Concat",
            inputs=outputs,
            outputs=[concat_out],
            attrs={"axis": self.concat_axis},
        )

        # Consumers outside MHA structure (example, output of QKV, consumer of anchor network etc).
        # Snapshot them first: rewiring an input below may mutate node_output.outputs.
        consumers = [consumer for consumer in node_output.outputs if consumer not in self.mha.ordered_nodes]
        for consumer in consumers:
            for index, inp in enumerate(consumer.inputs):
                if inp == node_output:
                    consumer.inputs[index] = concat_out

        # Swap the graph output in place rather than remove + append, so the
        # position of this output among the graph outputs is preserved.
        graph_outputs = self.graph.outputs
        if node_output in graph_outputs:
            graph_outputs[graph_outputs.index(node_output)] = concat_out

        self.graph.nodes.append(concat)

        return True
