# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
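"""
This module provides a pass that removes useless Mul nodes from the graph.
A Mul is useless when one of its inputs is a constant whose elements are all
1.0 and the multiplication does not broadcast the other input.
"""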


from onnxscript import ir

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .base.rewriter import BasePredicateRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.utils import (
    get_constant_np,
    get_value_numeric_shape,
    have_static_shape_on_node_io,
    is_constant,
    logger,
    safe_replace_all_uses_with,
    scan_least_common_ancestor,
)


class NullMulRemovalRewriter(BasePredicateRewriter):
    """
    A graph rewriter pass that removes useless Mul nodes from the graph.

    A Mul is considered useless when one of its inputs is a constant whose
    elements are all 1.0 and the other (data) input already has the same
    static shape as the Mul output, i.e. the multiplication neither changes
    values nor broadcasts. Such a Mul is bypassed by rewiring the consumers of
    its output to the data input.

    The Mul is kept when the constant's ``extra_info`` reports it as an
    updatable weight, or when the Mul might be the affine Mul of a
    LayerNorm/RMSNorm (see ``is_potential_affine_mul_in_norm``).
    """

    def __init__(self, graph: ir.Graph):
        super().__init__(graph)
        # Index (0 or 1) of the all-ones constant input found by the last
        # successful ``match``; None when the last ``match`` failed.
        self.cst_one_id: int | None = None
        # Index of the other (data) input, which replaces the Mul output.
        self.input_id: int | None = None

    def match(self, node: ir.Node) -> bool:
        """
        Check whether ``node`` is a removable all-ones Mul.

        Each constant operand is tried as the all-ones candidate, so a Mul
        whose inputs are both constant is still recognised when only the
        second one is all 1.0. On success, ``cst_one_id`` and ``input_id``
        are set for use by ``rewrite``.

        Args:
            node (ir.Node): The node to inspect.

        Returns:
            bool: True if the node can be removed by ``rewrite``.
        """
        # Reset first so a failed match never leaves indices from an earlier node.
        self.cst_one_id = None
        self.input_id = None

        if node.op_type != "Mul" or len(node.inputs) != 2:
            return False
        if not have_static_shape_on_node_io(node):
            return False

        for cst_id in (0, 1):
            if self._is_removable_with_const_at(node, cst_id):
                self.cst_one_id = cst_id
                self.input_id = 1 - cst_id
                return True
        return False

    def _is_removable_with_const_at(self, node: ir.Node, cst_id: int) -> bool:
        """
        Check whether ``node`` is removable when input ``cst_id`` is the 1.0 constant.

        Args:
            node (ir.Node): A two-input Mul node with static I/O shapes.
            cst_id (int): Index (0 or 1) of the input to test as the
                all-ones constant; the other input is the data input.

        Returns:
            bool: True if input ``cst_id`` is a non-updatable constant whose
            elements are all 1.0, the Mul is not a potential norm affine Mul,
            and the data input's shape equals the output shape (no broadcast).
        """
        data_id = 1 - cst_id
        cst_v = node.inputs[cst_id]
        data_v = node.inputs[data_id]
        if cst_v is None or data_v is None or not is_constant(cst_v):
            return False

        if not (get_constant_np(cst_v) == 1.0).all():
            return False

        # Keep the Mul when the constant is flagged as an updatable weight
        # (its value presumably may change later, so it is not a fixed 1.0).
        if cst_v.meta["extra_info"].is_updatable_weight():
            return False

        # Skip this optimization if the Mul may be the affine part of a
        # LayerNorm/RMSNorm: removing it can break the matching rules in
        # Converter/Quantizer.
        if self.is_potential_affine_mul_in_norm(node, data_id):
            return False

        # Only a non-broadcasting Mul is an identity on the data input.
        return get_value_numeric_shape(data_v) == get_value_numeric_shape(node.outputs[0])

    def rewrite(self, node: ir.Node) -> bool:
        """
        Bypass the Mul accepted by ``match`` by rewiring its output to the data input.

        The ``extra_info`` metadata of the kept data input is merged into the
        ``extra_info`` of the Mul output.

        Args:
            node (ir.Node): The node previously accepted by ``match``.

        Returns:
            bool: Always True.
        """
        assert self.input_id is not None  # set by a successful match()
        data_v = node.inputs[self.input_id]
        out_v = node.outputs[0]
        assert data_v is not None  # check for mypy
        assert out_v is not None  # check for mypy

        safe_replace_all_uses_with(self.graph, out_v, data_v)
        # Merge from the kept data input rather than unconditionally from
        # inputs[0]: when the 1.0 constant is the first operand, inputs[0] is
        # the constant, not the value that now replaces the Mul output.
        out_v.meta["extra_info"].merge(data_v.meta["extra_info"])

        logger.debug("applied pass %s on '%s'",
                     self.get_curr_pass_name(), node.name)

        return True

    def is_potential_affine_mul_in_norm(self, node: ir.Node, non_cst_input_id: int) -> bool:
        """
        Check whether the given node could be part of a normalization (e.g. RMSNorm, LayerNorm).

        RMSNorm and LayerNorm typically contain a subgraph connected like this:
            OP-----------
            |           |
            |          Mul/Pow
            |           |
            |      ReduceMean
            |           |
            |         Sqrt
            |           |
            |          Add
            |           |
            Div--------/
            |
            Mul (the affine mul that this function concerns)
        The subgraph is deliberately not checked strictly, because it has many
        variations. Only the two-path connection is checked. The producer of
        the non-constant input must have exactly two inputs, and the least
        common ancestor of those two inputs (searched up to 10 layers) must
        itself be one of them.

        Note: This is not an exact check. If the node could possibly be part
        of a normalization, it returns True.

        Note: One of the node's inputs is expected to be constant; this
        function does not check that.

        Args:
            node (ir.Node): The node to check.
            non_cst_input_id (int): Index of the node's non-constant input.

        Returns:
            bool: True if the node is potentially part of a normalization,
            False otherwise.
        """
        non_cst_v = node.inputs[non_cst_input_id]
        if non_cst_v is None:  # check for mypy
            return False
        converge_node = non_cst_v.producer()
        # A value without a producing node (e.g. a graph input), or a producer
        # that is not binary, cannot be the converging node of the diagram.
        if converge_node is None or len(converge_node.inputs) != 2:
            return False
        lca_v = scan_least_common_ancestor(
            converge_node.inputs[0], converge_node.inputs[1],
            max_layers_to_traverse=10,  # RMSNorm/LayerNorm should be very small subgraph, so limit the traversal
        )
        # The common ancestor must feed the converging node directly
        # (the "OP -> Div" edge in the diagram above).
        return lca_v in converge_node.inputs
