# ==============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All Rights Reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""Base class for MHA pattern definition and op-level pattern matching"""

from abc import ABCMeta
from dataclasses import dataclass

import onnx_graphsurgeon as gs

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.graph_manager import GraphManager
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.pattern import Pattern


@dataclass
class MHA:
    """
    Container for the nodes that make up one Multi-Head Attention (MHA) subgraph

    Attributes:
        qk: Node performing the QK matmul
        qkv: Node performing the QKV matmul
        linears: Linear projection nodes (e.g., MatMul, Gemm, Conv) used for Q, K, V
        ordered_nodes: Nodes forming the MHA subgraph, listed in topological order
    """

    qk: gs.Node
    qkv: gs.Node
    linears: list[gs.Node]
    ordered_nodes: list[gs.Node]

    @property
    def num_heads(self):
        """Returns the head count, read from axis 1 of the QKV matmul's first output shape"""
        # NOTE(review): presumably the output is laid out as
        # (batch, heads, seq, head_dim) and already shape-inferred -- confirm
        qkv_out_shape = self.qkv.outputs[0].shape
        return qkv_out_shape[1]

    @property
    def head_dim(self):
        """Returns the per-head embedding size, read from the last axis of the QKV matmul's first output shape"""
        qkv_out_shape = self.qkv.outputs[0].shape
        return qkv_out_shape[-1]

class BasePattern(Pattern, metaclass=ABCMeta):
    """Common base for the op-level (micro) transforms used by MHA2SHA"""

    def __init__(self, graph: GraphManager, mha: MHA) -> None:
        """
        Args:
            graph: GraphManager instance wrapping the ONNX graph
            mha: MHA dataclass instance describing the attention subgraph to operate on
        """

        super().__init__(graph)
        self.mha = mha

    def _get_upstream_linear(self, node: gs.Node) -> gs.Node | None:
        """Locate the closest linear node at or above the given node

        The node itself is returned when the graph reports it as linear.
        Otherwise the search walks upstream and stops at any node that is not
        part of the MHA subgraph.

        Args:
            node: Node from which to begin the upstream search

        Returns:
            The node itself if it is linear, otherwise the result of the
            upstream search (documented by the original as None when no
            linear node is found)
        """

        if self.graph.is_linear(node):
            return node

        def _outside_mha(candidate):
            # Stop the walk at the first node that is not part of this MHA subgraph
            return candidate not in self.mha.ordered_nodes

        return self.graph.find_upstream_node(
            node,
            condition=self.graph.is_linear,
            hard_stop_condition=_outside_mha,
        )

    def _get_num_groups(self, node: gs.Node) -> int:
        """Determine the number of groups for Grouped-Query Attention
           For the non-GQA case, this is the same as the number of heads

        Args:
            node: Node for which to determine group count

        Returns:
            The head count when the node is the QK/QKV matmul or sits
            downstream of the QK matmul; otherwise the group count derived
            from the weight shape of the nearest upstream linear node
            Returns 0 if no such linear node is found
        """

        # The attention matmuls themselves always operate on all heads
        if node in (self.mha.qk, self.mha.qkv):
            return self.mha.num_heads

        # Same for nodes that reach the QK matmul upstream before hitting a linear node
        qk_upstream = self.graph.find_upstream_node(
            node,
            condition=lambda candidate: candidate is self.mha.qk,
            hard_stop_condition=self.graph.is_linear,
        )
        if qk_upstream:
            return self.mha.num_heads

        linear = self._get_upstream_linear(node)
        if not linear:
            return 0

        weight_shape = linear.inputs[1].shape
        # Conv weights and Gemm weights with transB == 1 are read on axis 0,
        # every other linear op on axis 1
        # NOTE(review): assumes inputs[1] is the weight tensor -- confirm
        read_axis_zero = linear.op == "Conv" or (linear.op == "Gemm" and linear.attrs.get("transB") == 1)
        out_axis = 0 if read_axis_zero else 1
        return weight_shape[out_axis] // self.mha.head_dim
