# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides the pass that reorders
(binary_elewise_op -> GroupSlice) into (GroupSlice -> binary_elewise_op).
"""


from onnxscript import ir

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.base_rewriter import M2sBaseRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.utils import (
    BroadcastHelper,
    get_gslice_attrs,
    is_reorderable_group_slice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.utils import (
    check_static_shape_of_node_io,
    get_value_numeric_shape,
    logger,
)


class M2sReorderBinElewiseGroupslice(M2sBaseRewriter):
    '''
    Move a GroupSlice in front of the binary element-wise op that feeds it,
    so the op is applied per slice.

    Transform subgraph:
        Subgraph(in_a, in_b) --> c, c0,c1,c2...
        {
            c = BinElewiseOp(in_a, in_b)
            c0,c1,c2... = GroupSlice(c)
        }
    Into:
        Subgraph(in_a, in_b) --> c, c0,c1,c2...
        {
            a0,a1,a2,... = Slice(in_a)
            b0,b1,b2,... = Slice(in_b)

            c0 = BinElewiseOp(a0, b0)
            c1 = BinElewiseOp(a1, b1)
            c2 = BinElewiseOp(a2, b2)

            # if possible
            c = concat(c0,c1,c2,...)
            # or c = BinElewiseOp(in_a, in_b)

        }

    BinElewiseOp is one of the op types listed in ``support_op_types``.
    Broadcasting between in_a and in_b is accounted for when deriving the
    per-input slice attributes (via BroadcastHelper); the actual graph
    surgery, including encoding updates, is delegated to
    ``rewrite_based_on_gslice_attrs``.
    '''
    def __init__(self, graph: ir.Graph):
        super().__init__(graph)
        # Binary element-wise ONNX ops this pass knows how to reorder.
        self.support_op_types = ["Add", "Sub", "Mul", "Div", "Equal", "Pow"]
        # Op node found by the most recent successful match(); consumed and
        # cleared by rewrite().
        self.op_node :ir.Node|None = None

    def match(self, node: ir.Node) -> bool:
        """Check whether *node* is a reorderable GroupSlice fed by a supported binary op.

        Args:
            node: Candidate GroupSlice node.

        Returns:
            True when the pattern matches; the producing op node is then
            cached in ``self.op_node`` for :meth:`rewrite`. False otherwise,
            in which case ``self.op_node`` is ``None``.
        """
        # Never let a previous match's node survive a failed match.
        self.op_node = None

        gslice_node = node
        if not is_reorderable_group_slice(gslice_node):
            return False

        # Explicit checks rather than asserts: asserts are stripped under
        # `python -O`, which would turn a producer-less input into an
        # AttributeError on `op_node.op_type` below.
        sliced_value = gslice_node.inputs[0]
        if sliced_value is None:
            return False
        op_node = sliced_value.producer()
        if op_node is None or op_node.op_type not in self.support_op_types:
            return False

        # NOTE(review): the return value is ignored here; this presumably
        # raises when the op's I/O shapes are not static — confirm. If it
        # returns a bool instead, its result should gate the match.
        check_static_shape_of_node_io(op_node)
        self.op_node = op_node
        return True

    def rewrite(self, node: ir.Node) -> bool:
        """Apply the reorder to the subgraph found by the preceding :meth:`match`.

        Args:
            node: The GroupSlice node that was just passed to :meth:`match`.

        Returns:
            True once the rewrite has been applied.
        """
        gslice_node = node
        op_node = self.op_node
        # Precondition: match() succeeded immediately before this call
        # (also narrows the type for mypy).
        assert op_node is not None

        try:
            in_a, in_b = op_node.inputs

            # Translate the GroupSlice attributes on the op's output into
            # slice attributes for each input, honoring broadcasting between
            # the two inputs and the output.
            bc_helper = BroadcastHelper(
                get_value_numeric_shape(in_a), get_value_numeric_shape(in_b),
                get_value_numeric_shape(op_node.outputs[0]),
            )
            gslice_attrs = get_gslice_attrs(gslice_node)
            a_groupslice_attrs = bc_helper.get_input_group_attrs(0, gslice_attrs)
            b_groupslice_attrs = bc_helper.get_input_group_attrs(1, gslice_attrs)

            self.rewrite_based_on_gslice_attrs(
                op_node, [gslice_node],
                inputs_gslice_attrs=[a_groupslice_attrs, b_groupslice_attrs]
            )

            logger.debug("applied pass %s on '%s'", self.get_curr_pass_name(), op_node.name)
        finally:
            # Clear the cached node even if the rewrite raised, so it cannot
            # leak into a later match/rewrite cycle.
            self.op_node = None
        return True
        