# ==============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All Rights Reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================

from collections import deque

import onnx_graphsurgeon as gs

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.mha2sha.patterns import (
    MHA,
    ConcatOutputs,
    SliceConst,
    SliceInput,
    SliceMatmul,
    SlicePattern,
    SliceProjection,
    SliceReshape,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.pattern import Pattern

# Op types consulted by MHAPattern.capture while following the consumer chain
# from the first attention MatMul (qk) towards the second one (qkv).
ALLOWED_OP_TYPES = [
    "Transpose",
    "Softmax",
    "Add",
    "Div",
    "Mul",
    "Where",
    "Cast",
]


class MHAPattern(Pattern):
    """Capture a multi-head-attention block and rewrite it with the mha2sha slice patterns.

    ``capture`` anchors on an activation-by-activation MatMul (``qk``), follows
    its first-consumer chain through the ops listed in ``ALLOWED_OP_TYPES`` to a
    second activation MatMul (``qkv``), collects the linear nodes feeding the
    attention operands, and records the nodes lying between those linears and
    ``qkv`` in ``self.mha``. ``replace`` then applies the slice patterns and
    ``ConcatOutputs`` to the recorded nodes.
    """

    def capture(self, node: gs.Node, **kwargs) -> bool:
        """Return True and populate ``self.mha`` if ``node`` starts an MHA block.

        Args:
            node: Candidate node. Only a MatMul with no constant input can
                start a match.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            True when a block anchored at ``node`` was found (``self.mha`` is
            set), False otherwise.
        """

        def _is_matmul(_node: gs.Node) -> bool:
            # A MatMul whose inputs are all non-constant (activation x activation),
            # as opposed to a MatMul against constant weights.
            return _node.op == "MatMul" and not any(
                self.graph.is_constant_tensor(inp) for inp in _node.inputs
            )

        def dfs_reverse(root: gs.Node):
            # Post-order walk from ``root`` over producers; stops at (and
            # excludes) the collected linear nodes.
            # NOTE(review): recursive; an extremely long chain could hit
            # Python's recursion limit.
            if root.name in reverse_visited or root in linears:
                return

            reverse_visited.add(root.name)

            for inp in root.inputs:
                for producer in inp.inputs:
                    dfs_reverse(producer)
            reverse_ordered_nodes.append(root)

        def dfs_forward(root: gs.Node):
            # Post-order walk from ``root`` over consumers; stops at (and
            # excludes) ``qkv``. Reversing the result gives a topological order.
            if root.name in visited or root == qkv:
                return

            visited.add(root.name)

            for output in root.outputs:
                for consumer in output.outputs:
                    dfs_forward(consumer)
            ordered_nodes.append(root)

        if not _is_matmul(node):
            return False

        try:
            prev, curr = node, node.o()
        except IndexError:  # The MatMul has no consumer (e.g. it is a graph output)
            return False

        # Follow the first consumer only through the whitelisted ops. (The
        # previous `or not _is_matmul(curr)` condition let the walk pass through
        # any op, making ALLOWED_OP_TYPES ineffective.)
        while curr.op in ALLOWED_OP_TYPES:
            prev = curr
            try:
                curr = curr.o()
            except IndexError:
                return False

        # The chain must end at another activation MatMul, with at least one
        # allowed op in between (MatMul -> MatMul directly is rejected).
        if prev == node or not _is_matmul(curr):
            return False

        qk = node
        qkv = curr

        # Producers of both qk operands and of the qkv operand that does not
        # come from the chain walked above.
        try:
            starting_nodes = [
                qk.i(),
                qk.i(1),
                qkv.i() if qkv.i(1) == prev else qkv.i(1),
            ]
        except IndexError:
            # An operand has no producing node (e.g. graph input), so nothing
            # can be traced back from it.
            return False

        # Breadth-first search over producers; each branch stops at the first
        # linear node it reaches. `expanded` (node ids) prevents re-expanding
        # nodes reachable by several paths; skipping on pop keeps the order in
        # which linears are first discovered.
        linears = []
        expanded = set()
        queue = deque(starting_nodes)
        while queue:
            curr = queue.popleft()
            if id(curr) in expanded:
                continue
            expanded.add(id(curr))

            if self.graph.is_linear(curr):
                if curr not in linears:
                    linears.append(curr)
            else:
                # TODO: Handle exception cases
                for inp in curr.inputs:
                    queue.extend(inp.inputs)

        if not linears:
            return False

        visited = set()
        reverse_visited = set()
        ordered_nodes = []
        reverse_ordered_nodes = []

        for start in linears:
            dfs_forward(start)

        dfs_reverse(qkv)

        ordered_nodes.reverse()
        ordered_nodes.append(qkv)

        # Linears first, then (in forward order) the nodes that are both
        # reachable from a linear and upstream of qkv, ending with qkv.
        ordered_nodes = linears + [n for n in ordered_nodes if n in reverse_ordered_nodes]

        self.mha = MHA(qk, qkv, linears, ordered_nodes)
        return True

    def replace(self, node: gs.Node, **kwargs) -> bool:
        """Apply the mha2sha slice patterns to the block recorded by ``capture``.

        Args:
            node: The node passed to ``capture``; the rewrite is driven by
                ``self.mha.ordered_nodes`` instead.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            Always True.
        """
        # Tried in this order for each node; the first pattern that captures wins.
        slice_patterns = (
            SliceProjection,
            SliceReshape,
            SliceConst,
            SliceInput,
            SliceMatmul,
            SlicePattern,
        )

        for curr in self.mha.ordered_nodes:
            for pattern in slice_patterns:
                matcher = pattern(self.graph, self.mha)
                if matcher.capture(curr):
                    matcher.replace(curr)
                    break

        # Second pass, after all slicing: a fresh ConcatOutputs matcher per node.
        for curr in self.mha.ordered_nodes:
            concat = ConcatOutputs(self.graph, self.mha)
            if concat.capture(curr):
                concat.replace(curr)

        return True
