# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides the base IR graph rewriter classes
"""
from abc import ABC, abstractmethod

from onnxscript import ir

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.ir_extra_info import GraphExtraInfo, VariableExtraInfo
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.utils import (
    get_shape_of_slice,
    get_value_numeric_shape,
    has_static_shape_on_value,
)


class BaseGraphRewriter(ABC):
    """
    Abstract foundation for rewriters that modify an ir graph.

    Holds the graph being rewritten plus a counter that is embedded in
    the name reported for the current pass, and offers helpers to record
    how a new value relates to an existing one (copy or slice).
    """

    def __init__(self, graph: ir.Graph):
        # Counter embedded into the pass name, starts from zero
        self.curr_pass_rewrite_uid = 0
        self.graph = graph

    @abstractmethod
    def apply(self):
        """
        Run this rewriter over the graph
        """
        # curr_pass_rewrite_uid should be incremented in the overriding method

    def get_curr_pass_name(self):
        """
        Build the name of the current pass: the class name followed by
        the current rewrite uid in square brackets
        """
        pass_cls = self.__class__
        return f"{pass_cls.__name__}[{self.curr_pass_rewrite_uid}]"

    def mark_value_as_copy(self, copy_from: ir.Value, value: ir.Value):
        """
        Register value as a copy of copy_from.

        The extra info of copy_from is duplicated onto value, the copy is
        recorded in the graph extra info under the current pass name, and
        the shape of copy_from (plus its dtype, when set) is assigned to value.

        Args:
            copy_from: the tensor copied from
            value: the tensor copied to
        """
        source_info = copy_from.meta["extra_info"]
        assert isinstance(source_info, VariableExtraInfo)
        graph_info = self.graph.meta["extra_info"]
        assert isinstance(graph_info, GraphExtraInfo)

        value.meta["extra_info"] = source_info.copy()
        graph_info.record_copy(copy_from.name, value.name,
                               self.get_curr_pass_name())

        # The shape is always taken over, the dtype only when the source has one
        value.shape = copy_from.shape
        source_dtype = copy_from.dtype
        if source_dtype is not None:
            value.dtype = source_dtype

    # pylint: disable=[too-many-arguments, too-many-positional-arguments]
    def mark_value_as_slice(self,
                            slice_from: ir.Value, value: ir.Value,
                            axis, start, end, batch_slice_id, head_slice_id):
        """
        Register value as a slice of slice_from.

        The extra info of value is derived by slicing the extra info of
        slice_from, and the slicing is recorded in the graph extra info
        under the current pass name. The shape of value is computed only
        when it is not set yet and slice_from has a static shape; the
        dtype is taken from slice_from when it is set.

        Args:
            slice_from: the tensor slice from
            value: the tensor sliced to
            axis: the slice axis
            start: start of the slice
            end: end of the slice
            batch_slice_id: the batch slice id
            head_slice_id: the head slice id
        """
        sliced_info = slice_from.meta["extra_info"].slice(axis, start, end)
        value.meta["extra_info"] = sliced_info

        graph_info = self.graph.meta["extra_info"]
        graph_info.record_slicing(slice_from.name,
                                  value.name,
                                  self.get_curr_pass_name(),
                                  axis, start, end, batch_slice_id, head_slice_id)

        # shape infer, an already assigned shape is left untouched
        shape_missing = value.shape is None
        if shape_missing and has_static_shape_on_value(slice_from):
            source_dims = get_value_numeric_shape(slice_from)
            sliced_dims = get_shape_of_slice(source_dims, [axis], [start], [end])
            value.shape = ir.Shape(sliced_dims)

        source_dtype = slice_from.dtype
        if source_dtype is not None:
            value.dtype = source_dtype


class BasePredicateRewriter(BaseGraphRewriter):
    """
    Rewriter skeleton following a match -> rewrite scheme:
    each node is first tested with match and, on success, handed to rewrite
    """

    @abstractmethod
    def match(self, node: ir.Node) -> bool:
        """
        Tell whether the node is a candidate for rewriting
        """

    @abstractmethod
    def rewrite(self, node: ir.Node) -> bool:
        """
        Try to rewrite the node, return True when it was rewritten
        and False when it was left as is
        """

    def apply(self):
        """
        Run match/rewrite over the nodes of the graph.

        The start hook runs first and the end hook runs last. Every
        successful rewrite increments curr_pass_rewrite_uid.

        Returns:
            number of nodes that were rewritten
        """
        self.start_whole_graph_analysis()

        # Iterate over a list of nodes taken before any rewrite runs
        candidates = list(self.graph)
        rewritten = 0
        for node in candidates:
            if not self.match(node):
                continue
            if not self.rewrite(node):
                continue
            rewritten += 1
            self.curr_pass_rewrite_uid += 1

        self.end_whole_graph_analysis()
        return rewritten

    def start_whole_graph_analysis(self):
        """
        Hook called by apply before any node is matched or rewritten
        """

    def end_whole_graph_analysis(self):
        """
        Hook called by apply once all nodes have been processed
        """
