# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides the pass that rewrites
(Split->GroupSlice) into (Slice->GroupSlice),
so that Split->GroupSlice reordering does not need its own implementation
and the existing Slice->GroupSlice reordering can be reused instead.
"""


from onnxscript import ir

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.base_rewriter import M2sBaseRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.utils import (
    is_reorderable_group_slice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.utils import (
    check_static_shape_of_node_io,
    convert_attr_to_py,
    get_value_numeric_shape,
    logger,
    make_initializer,
    safe_replace_all_uses_with,
)


class M2sReplaceSplitGroupslice2SliceGroupslice(M2sBaseRewriter):
    '''
    Replace the Split feeding a reorderable GroupSlice with one Slice per
    Split output.

    Transform subgraph:
        Subgraph(in_a) --> c, c00,c01,c02,c10,c11,c12,c20,c21,c22...
        {
            b1,b2,b3 = Split(in_a)
            c00,c01,c02... = GroupSlice(b1)
            c10,c11,c12... = GroupSlice(b2)
            c20,c21,c22... = GroupSlice(b3)
        }
    Into:
        Subgraph(in_a) --> c, c00,c01,c02,c10,c11,c12,c20,c21,c22...
        {
            b1 = Slice(in_a)
            b2 = Slice(in_a)
            b3 = Slice(in_a)
            c00,c01,c02... = GroupSlice(b1)
            c10,c11,c12... = GroupSlice(b2)
            c20,c21,c22... = GroupSlice(b3)
        }

    The GroupSlice nodes themselves are not touched here; they are handled
    by the reorder_slice_gslice pass. Every output of the Split is replaced,
    not only the one feeding the matched GroupSlice.

    Encodings are also updated: each new Slice output is registered through
    ``mark_value_as_copy`` as a copy of the Split output it replaces.
    '''

    def __init__(self, graph: ir.Graph):
        super().__init__(graph)
        # Split node found by the most recent successful match(); read by
        # rewrite().
        self.op_node: ir.Node | None = None

    def match(self, node: ir.Node) -> bool:
        """Return True if ``node`` is a reorderable GroupSlice fed by a Split.

        On success the producing Split node is stored in ``self.op_node``
        for use by :meth:`rewrite`.

        Args:
            node: Candidate GroupSlice node.

        Returns:
            True when the Split->GroupSlice pattern matches, else False.
        """
        gslice_node = node
        if not is_reorderable_group_slice(gslice_node):
            return False

        # Guards rather than asserts: a GroupSlice fed directly by a graph
        # input has no producer, and that should simply not match instead of
        # crashing (asserts are also stripped under ``python -O``).
        gslice_input = gslice_node.inputs[0]
        if gslice_input is None:
            return False
        op_node = gslice_input.producer()
        if op_node is None:
            return False
        # Only the standard ONNX Split; an op named "Split" from a custom
        # domain may have different semantics.
        if op_node.op_type != "Split" or op_node.domain not in ("", "ai.onnx"):
            return False

        # Slice bounds are derived from the Split output shapes, so they must
        # be static. NOTE(review): the return value is ignored; presumably
        # this helper raises on dynamic shapes -- confirm.
        check_static_shape_of_node_io(op_node)
        self.op_node = op_node
        return True

    def rewrite(self, node: ir.Node) -> bool:
        """Replace every output of the matched Split with an equivalent Slice.

        For output ``i`` of the Split, a Slice over the Split's data input is
        created with ``starts``/``ends`` equal to the cumulative split offsets
        along the split axis. The Slice is inserted after the Split and takes
        over all uses of that output. The Split itself is left in the graph
        with no remaining uses.

        Args:
            node: The GroupSlice node that was passed to the preceding
                successful :meth:`match`; used here only for logging.

        Returns:
            Always True.
        """
        gslice_node = node
        split_node = self.op_node
        assert split_node is not None  # set by a successful match()
        assert split_node.name is not None  # check for mypy, definitely true
        extra_info = self.graph.meta["extra_info"]

        # The ONNX Split ``axis`` attribute is optional and defaults to 0.
        axis_attr = split_node.attributes.get("axis")
        split_axis = (convert_attr_to_py(axis_attr, "as_int")
                      if axis_attr is not None else 0)
        # Take the sizes from the static output shapes. This works whichever
        # way the Split was parameterised (``split`` input/attribute or an
        # even split).
        split_sizes = [get_value_numeric_shape(v)[split_axis]
                       for v in split_node.outputs]

        start = 0
        for v_i, (split_out, size) in enumerate(
                zip(split_node.outputs, split_sizes)):
            end = start + size
            # Name generation order (node, .start, .end, .axes, output) is
            # kept stable so the produced names are deterministic.
            slice_node_name = extra_info.get_unique_name_with_suffix(
                split_node.name, "/slice_" + str(v_i))
            starts_init = make_initializer(
                self.graph,
                extra_info.get_unique_name_with_suffix(slice_node_name,
                                                       ".start"),
                [start])
            ends_init = make_initializer(
                self.graph,
                extra_info.get_unique_name_with_suffix(slice_node_name,
                                                       ".end"),
                [end])
            axes_init = make_initializer(
                self.graph,
                extra_info.get_unique_name_with_suffix(slice_node_name,
                                                       ".axes"),
                [split_axis])
            slice_node = ir.Node(
                "", op_type="Slice",
                inputs=[split_node.inputs[0], starts_init, ends_init,
                        axes_init],
                name=slice_node_name,
            )
            slice_node.outputs[0].name = extra_info.get_unique_name(
                split_out.name)
            # Carry the Split output's metadata/encodings over to the Slice.
            self.mark_value_as_copy(split_out, slice_node.outputs[0])

            self.graph.insert_after(split_node, slice_node)
            safe_replace_all_uses_with(self.graph, split_out,
                                       slice_node.outputs[0])
            start = end

        logger.debug("applied pass %s on '%s'",
                     self.get_curr_pass_name(), gslice_node.name)

        return True
