# ==============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All Rights Reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""Implements GraphManager class, a utility class implementing common graph-traversal algorithms"""

from collections import defaultdict, deque
from typing import Callable

import numpy as np
import onnx
import onnx_graphsurgeon as gs


class GraphManager:
    """A utility class implementing common graph-traversal algorithms over an onnx-graphsurgeon graph.

    Besides traversal helpers, it keeps two per-instance caches:
      * `tensor_value_cache`: constant tensor values (see `get_tensor_value`), and
      * `tensors` / `tensor_mapping`: "slice" tensors derived from original tensors (see `get_tensor_slice`).
    """

    def __init__(self, model: onnx.ModelProto) -> None:
        self.graph = gs.import_onnx(model)

        # tensor name -> np.ndarray (at least 1-D) holding that tensor's constant value
        self.tensor_value_cache = {}
        # slice tensor name -> gs.Constant / gs.Variable created by `get_tensor_slice`
        self.tensors = {}
        # original tensor name -> names of the slice tensors created from it, in creation order
        self.tensor_mapping = defaultdict(list)

    @property
    def nodes(self):
        """Nodes of the underlying gs.Graph."""
        return self.graph.nodes

    @property
    def inputs(self):
        """Input tensors of the underlying gs.Graph."""
        return self.graph.inputs

    @property
    def outputs(self):
        """Output tensors of the underlying gs.Graph."""
        return self.graph.outputs

    def get_tensor_value(self, tensor: gs.Constant | gs.Variable) -> np.ndarray:
        """
        Get the value of `tensor`.
        Values are cached in `tensor_value_cache` (keyed by tensor name) because evaluating `tensor.values` is expensive.

        Args:
            tensor - Either a gs.Constant input or an output (gs.Variable) of a Node of op_type "Constant"

        Returns:
            A numpy array which is at least 1d (never a 0-d scalar array)

        Raises:
            TypeError - If type of tensor is invalid
            AttributeError - If tensor is neither a gs.Constant object nor is the output of a Constant node

        """
        name = tensor.name
        if name not in self.tensor_value_cache:
            if isinstance(tensor, gs.Variable):
                producers = tensor.inputs
                # A Variable without a producer (e.g. a graph input) is not constant either; check for that
                # first so it raises the documented AttributeError instead of an IndexError.
                if not producers or producers[0].op != "Constant":
                    raise AttributeError(
                        f"tensor {tensor.name} is neither a gs.Constant object nor is the output of a Constant node"
                    )
                values = producers[0].attrs["value"].values
            elif isinstance(tensor, gs.Constant):
                values = tensor.values
            else:
                raise TypeError(
                    f"Invalid type for `tensor`: {type(tensor)}. Expected one of: gs.Constant or gs.Variable"
                )
            self.tensor_value_cache[name] = np.atleast_1d(values)
        return self.tensor_value_cache[name]

    @staticmethod
    def _get_slice_name(tensor: str | gs.Constant | gs.Variable, num: int) -> str:
        """Get the name of the sliced tensor based on the original tensor `tensor` and slice number `num`.

        Args:
            tensor - Base tensor (or its name) to create the slice tensor from
            num - Slice number
        Returns:
            "<tensor name>/:<num>"
        """
        name = tensor if isinstance(tensor, str) else tensor.name
        return f"{name}/:{num}"

    def get_tensor_slice(
        self, tensor: str | gs.Constant | gs.Variable, slice_num: int, values: np.ndarray | None = None
    ) -> gs.Constant | gs.Variable:
        """Get slice number `slice_num` of the original tensor `tensor`, creating it on first request.

        Slices are memoised in `self.tensors` by slice name. Each newly created slice name is also recorded in
        `self.tensor_mapping[<original tensor name>]`. Once a slice exists, later calls return it unchanged and
        `values` is ignored.

        Args:
            tensor - Base tensor (or its name) to create the slice tensor from
            slice_num - Slice number
            values (Optional) - Numpy array to create a gs.Constant from
        Returns:
            The cached slice if it already exists; otherwise a new gs.Constant if 'values' is passed, else a new
            gs.Variable
        """
        name = tensor if isinstance(tensor, str) else tensor.name
        slice_name = self._get_slice_name(name, slice_num)

        tensor_slice = self.tensors.get(slice_name)
        # Explicit None check: a cached slice must never be treated as a miss because of its truthiness.
        if tensor_slice is None:
            if values is not None:
                tensor_slice = gs.Constant(name=slice_name, values=values)
            else:
                tensor_slice = gs.Variable(name=slice_name)
            self.tensors[slice_name] = tensor_slice
            self.tensor_mapping[name].append(slice_name)

        return tensor_slice

    def has_slice_tensor(self, tensor: str | gs.Constant | gs.Variable, num: int) -> bool:
        """Return True if slice `num` of `tensor` has already been created via `get_tensor_slice`."""
        return self._get_slice_name(tensor, num) in self.tensors

    @staticmethod
    def _producers(node: gs.Node) -> list:
        """Nodes producing any input tensor of `node`, in input order (a node may appear more than once)."""
        return [producer for tensor in node.inputs for producer in tensor.inputs]

    @staticmethod
    def _consumers(node: gs.Node) -> list:
        """Nodes consuming any output tensor of `node`, in output order (a node may appear more than once)."""
        return [consumer for tensor in node.outputs for consumer in tensor.outputs]

    @staticmethod
    def _bfs_first_match(
        start: gs.Node,
        neighbours: Callable,
        condition: Callable,
        hard_stop_condition: Callable | None,
    ) -> gs.Node | None:
        """Breadth-first search from the neighbours of `start`; return the first node satisfying `condition`.

        `start` itself is not tested. Nodes for which `hard_stop_condition` is True (and `condition` is False)
        are not expanded further.
        """
        queue = deque(neighbours(start))
        # Visited set keyed by identity: in a DAG with diamonds (e.g. residual branches) the same node is
        # reachable along many paths, and re-expanding it each time can grow the queue exponentially.
        # Skipping repeats does not change which node matches first, because BFS tests each node at its
        # first occurrence.
        visited = set()

        while queue:
            curr_node = queue.popleft()
            if id(curr_node) in visited:
                continue
            visited.add(id(curr_node))

            if condition(curr_node):
                return curr_node
            if hard_stop_condition is not None and hard_stop_condition(curr_node):
                continue
            queue.extend(neighbours(curr_node))

        return None

    @staticmethod
    def find_upstream_node(
        node: gs.Node, condition: Callable, hard_stop_condition: Callable | None = None
    ) -> gs.Node | None:
        """Starting from node 'node', find the nearest producer-side node (towards the graph inputs) satisfying 'condition'

        The search is breadth-first over producers and does not test 'node' itself.
        If a node returns False for 'condition' but True for 'hard_stop_condition',
        no more producers of that node are added to the traversal queue.

        Args:
            node: Starting node
            condition:
                A Callable that accepts a gs.Node argument and returns a boolean
                If condition returns True, the node in consideration is returned
            hard_stop_condition:
                A Callable that accepts a gs.Node argument and returns a boolean
                If hard_stop_condition returns True, the producers are not added to the traversal queue

        Returns:
            The first node (in BFS order) that satisfies the condition 'condition'
            If no nodes satisfy the given condition, returns None

        """
        return GraphManager._bfs_first_match(node, GraphManager._producers, condition, hard_stop_condition)

    @staticmethod
    def find_downstream_node(
        node: gs.Node, condition: Callable, hard_stop_condition: Callable | None = None
    ) -> gs.Node | None:
        """Starting from node 'node', find the nearest consumer-side node (towards the graph outputs) satisfying 'condition'

        The search is breadth-first over consumers and does not test 'node' itself.
        If a node returns False for 'condition' but True for 'hard_stop_condition',
        no more consumers of that node are added to the traversal queue.

        Args:
            node: Starting node
            condition:
                A Callable that accepts a gs.Node argument and returns a boolean
                If condition returns True, the node in consideration is returned
            hard_stop_condition:
                A Callable that accepts a gs.Node argument and returns a boolean
                If hard_stop_condition returns True, the consumers are not added to the traversal queue

        Returns:
            The first node (in BFS order) that satisfies the condition 'condition'
            If no nodes satisfy the given condition, returns None

        """
        return GraphManager._bfs_first_match(node, GraphManager._consumers, condition, hard_stop_condition)

    @staticmethod
    def is_constant_tensor(tensor: gs.Constant | gs.Variable) -> bool:
        """Return True if `tensor` is a gs.Constant or the output of a node whose op is "Constant"."""
        if isinstance(tensor, gs.Constant):
            return True

        try:
            return tensor.inputs[0].op == "Constant"
        except IndexError:
            # No producer (e.g. a graph input): not constant.
            return False

    @staticmethod
    def is_linear(node: gs.Node) -> bool:
        """
        Whether 'node' is a Linear node

        Node is Linear if both the conditions are True:
                1. node is either a MatMul, Gemm, or Conv, and
                2. node's second input (the projection weights) is a constant tensor

        Args:
            node: Node to check

        Returns:
            True, if 'node' is a Linear node, else False
        """
        return (
            node.op in ("MatMul", "Gemm", "Conv")
            # Guard against malformed nodes with a single input instead of raising IndexError.
            and len(node.inputs) > 1
            and GraphManager.is_constant_tensor(node.inputs[1])
        )

    def can_visit(self, src: gs.Node, dst: gs.Node) -> bool:
        """
        Returns True if dst can be visited from src in the graph (following producer -> consumer edges)

        Args:
            src: The starting node
            dst: Whether this node can be reached starting from src

        Returns:
            True, if dst can be reached starting from src, else False

        Raises:
            KeyError - if src, dst, or a traversed consumer is not in `self.graph.nodes`
        """
        # Index nodes by identity, not name: ONNX node names are optional, and several unnamed nodes would
        # otherwise collide on "" and corrupt the position-based pruning below.
        node_idx = {id(node): idx for idx, node in enumerate(self.graph.nodes)}

        # Quick check / pruning: both rely on `self.graph.nodes` being in topological order
        # (NOTE(review): true for freshly imported ONNX models; confirm after graph mutations).
        dst_idx = node_idx[id(dst)]
        if node_idx[id(src)] > dst_idx:
            return False

        queue = deque([src])
        visited = {id(src)}
        while queue:
            curr = queue.popleft()
            if curr is dst:
                return True

            for consumer in self._consumers(curr):
                if consumer is dst:
                    return True
                key = id(consumer)
                # Nodes positioned after dst cannot lie on a path to it in a topologically sorted list.
                if key in visited or node_idx[key] > dst_idx:
                    continue
                visited.add(key)
                queue.append(consumer)

        return False

    def is_residual_add(self, node: gs.Node) -> bool:
        """
        Whether `node` is a residual Add node

        A residual add node is an add node that adds the output of previous "layer" with the output of the current layer
        before feeding it as input to the next layer.

        Residual connections were popularized by ResNet in 2015 to solve the vanishing gradient problem.
        For us, they give a unique pattern to identify the exact boundary between two "layers" in the graph

                    | layer_1_output
                    |
                   Add (l1: layer_1_output + l0)
              _ _ _ v
              |  .. layer 2 ..
              |
              |     |
              |     | layer_2_output
              |     |
              |     v
              |- - Add  (l2: layer_2_output + l1)
                    |
              _ _ _ v
              | .. layer 3 ..


        This is particularly useful for splitting transformations like MHA2SHA and PD splitter,
        where the splits can be between the boundaries of these residual add nodes

        Both operands of the Add are examined: Add is commutative, so the answer must not depend on
        operand order or on whether one operand has no producer (e.g. a constant or a graph input).

        Args:
            node: Node to check

        Returns:
            True, if node is residual add, else False

        """
        if node.op != "Add":
            return False

        for inp in node.inputs:
            if not inp.inputs:
                # Constant initializer or graph input: no producer to inspect.
                continue
            producer = inp.inputs[0]
            # NOTE(review): `producer` directly feeds `node`, so this reachability check looks trivially True
            # for a topologically sorted graph. Confirm whether the intent was reachability between the two
            # Add operands (skip branch -> layer branch).
            if producer.op in ("Add", "Gather") and self.can_visit(producer, node):
                return True

        return False
