# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides the pass for reordering 
(ScatterElements->GroupSlice) -> (GroupSlice->ScatterElements)
"""



from onnxscript import ir

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.base_rewriter import M2sBaseRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.utils import (
    get_gslice_attrs,
    is_reorderable_group_slice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.utils import (
    check_static_shape_of_node_io,
    convert_attr_to_py,
    logger,
)


class M2sReorderScatterElementsGroupslice(M2sBaseRewriter):
    '''
    Move a GroupSlice that consumes a ScatterElements output above the
    ScatterElements, so that each slice gets its own ScatterElements.

    Transform subgraph:
        Subgraph(data, indices, updates) --> c, c0,c1,c2...
        {
            c = ScatterElements(data, indices, updates)
            c0,c1,c2... = GroupSlice(c)
        }
    Into:
        Subgraph(data, indices, updates) --> c, c0,c1,c2...
        {
            data0,data1,data2,... = GroupSlice(data)
            indices0,indices1,indices2,... = GroupSlice(indices)
            updates0,updates1,updates2,... = GroupSlice(updates)

            c0 = ScatterElements(data0, indices0, updates0)
            c1 = ScatterElements(data1, indices1, updates1)
            c2 = ScatterElements(data2, indices2, updates2)

            # if possible
            c = concat(c0,c1,c2,...)
            # or c = ScatterElements(data, indices, updates)

        }

    The reorder is only applied when the GroupSlice axis differs from the
    ScatterElements axis. Index values address positions along the scatter
    axis, so slicing along that axis would change their meaning.

    Also, encodings are updated
    '''

    def __init__(self, graph: ir.Graph):
        super().__init__(graph)
        # ScatterElements node found by the most recent successful match();
        # consumed and cleared by rewrite().
        self.op_node: ir.Node | None = None

    def match(self, node: ir.Node) -> bool:
        """Return True if ``node`` is a GroupSlice that can be hoisted above
        the ScatterElements producing its input.

        On success the ScatterElements node is stored in ``self.op_node`` for
        the following ``rewrite`` call. On failure ``self.op_node`` is None.
        """
        # Never let state from a previous (possibly failed) match leak into
        # rewrite().
        self.op_node = None

        gslice_node = node
        if not is_reorderable_group_slice(gslice_node):
            return False

        gslice_input = gslice_node.inputs[0]
        if gslice_input is None:
            return False
        op_node = gslice_input.producer()
        # Graph inputs / initializers have no producer node: nothing to reorder.
        if op_node is None or op_node.op_type != "ScatterElements":
            return False

        # NOTE(review): the return value is ignored here, as in the original;
        # presumably this raises on non-static shapes -- confirm.
        check_static_shape_of_node_io(op_node)
        data = op_node.inputs[0]
        indices = op_node.inputs[1]
        assert data is not None  # check for mypy
        assert data.shape is not None  # check for mypy
        rank = len(data.shape)

        if "axis" in op_node.attributes:
            scatter_axis = convert_attr_to_py(op_node.attributes["axis"],
                                              "as_int")
        else:
            scatter_axis = 0  # ONNX default axis for ScatterElements
        if scatter_axis < 0:
            scatter_axis += rank

        # Normalize the GroupSlice axis too, so a negative axis cannot hide
        # the fact that both ops work on the same dimension.
        gslice_axis = get_gslice_attrs(gslice_node).axis
        if gslice_axis < 0:
            gslice_axis += rank

        if gslice_axis == scatter_axis:
            return False

        # rewrite() slices data/indices/updates with identical GroupSlice
        # attributes. That is only valid when indices (and therefore updates,
        # which shares its shape) spans the same extent as data along the
        # slice axis. ONNX only guarantees equal rank.
        if (indices is None or indices.shape is None
                or indices.shape[gslice_axis] != data.shape[gslice_axis]):
            return False

        self.op_node = op_node
        return True

    def rewrite(self, node: ir.Node) -> bool:
        """Apply the reorder to the GroupSlice ``node`` accepted by ``match``.

        Returns True once the subgraph has been rewritten. ``self.op_node`` is
        cleared afterwards.
        """
        gslice_node = node
        assert self.op_node is not None  # set by a successful match()

        # According to the definition of ScatterElements,
        # data/indices/updates have the same rank and
        # indices/updates have the same shape.

        # Because gslice_axis != scatter_axis (checked in match()), the
        # GroupSlice can be moved onto every input using the same attributes.
        # Each input gets its own attrs object (separate calls), in case the
        # helper below mutates them.
        data_groupslice_attrs = get_gslice_attrs(gslice_node)
        indices_groupslice_attrs = get_gslice_attrs(gslice_node)
        updates_groupslice_attrs = get_gslice_attrs(gslice_node)

        _ = self.rewrite_based_on_gslice_attrs(
            self.op_node, [gslice_node],
            inputs_gslice_attrs=[data_groupslice_attrs,
                                 indices_groupslice_attrs,
                                 updates_groupslice_attrs]
        )

        logger.debug("applied pass %s on '%s'",
                     self.get_curr_pass_name(), self.op_node.name)
        self.op_node = None
        return True
