# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides the pass for reordering
(Slice->GroupSlice) -> (GroupSlice->Slice)
"""

import copy
from typing import List

from onnxscript import ir

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.base_rewriter import M2sBaseRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.utils import (
    FullGroupSliceAttrs,
    get_gslice_attrs,
    is_reorderable_group_slice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.utils import (
    check_static_shape_of_node_io,
    convert_attr_to_py,
    get_constant_np,
    is_constant,
    logger,
)


class M2sReorderSliceGroupslice(M2sBaseRewriter):
    '''
    Move a GroupSlice above the Slice that feeds it.

    Transform subgraph:
        Subgraph(in_a) --> b,b0,b1,b2
        {
            b = slice(in_a)
            b0,b1,b2... = GroupSlice(b)
        }
    Into:
        Subgraph(in_a) --> b,b0,b1,b2
        {

            a0,a1,a2... = GroupSlice(in_a)

            b0 = slice(a0)
            b1 = slice(a1)
            b2 = slice(a2)
            ...

            # if possible
            b = concat(b0,b1,b2,...)
            # or b = slice(in_a)
        }

    When the Slice acts only on the GroupSlice axis (with step 1), no
    per-group Slice is emitted. The Slice start is folded into the GroupSlice
    bounds instead (see ``create_mini_pattern``). A Slice that touches the
    GroupSlice axis together with other axes is not supported.

    Also, encodings are updated
    '''

    def __init__(self, graph: ir.Graph):
        super().__init__(graph)
        # State cached by match() and consumed (then cleared) by rewrite().
        self.op_node: ir.Node | None = None
        self.slice_axes: List[int] | None = None
        self.gslice_axis: int | None = None

    def match(self, node: ir.Node) -> bool:
        """Check whether ``node`` is a reorderable GroupSlice fed by a Slice.

        On success, three things are cached on ``self`` for ``rewrite``:
        the producing Slice node, its axes (normalized to be non-negative),
        and the GroupSlice axis.

        Args:
            node: candidate GroupSlice node.

        Returns:
            True if the pattern matches, False otherwise.
        """
        gslice_node = node
        if not is_reorderable_group_slice(gslice_node):
            return False
        gslice_input = gslice_node.inputs[0]
        if gslice_input is None:
            return False
        op_node = gslice_input.producer()
        # A value without a producing node (e.g. a graph input) cannot be
        # reordered; treat it as a non-match instead of asserting.
        if op_node is None or op_node.op_type != "Slice":
            return False

        # NOTE(review): the return value is ignored. If this helper signals a
        # dynamic shape by returning False rather than raising, it should be
        # checked here. Confirm.
        check_static_shape_of_node_io(op_node)

        data_v = op_node.inputs[0]
        if data_v is None or data_v.shape is None:
            return False
        rank = len(data_v.shape)

        slice_axes = self._get_slice_axes(op_node)
        if slice_axes is None:
            return False
        # normalize negative axes into [0, rank)
        slice_axes = [axis + rank if axis < 0 else axis for axis in slice_axes]

        gslice_axis = convert_attr_to_py(gslice_node.attributes["axis"],
                                         "as_int")
        # Assumes a negative GroupSlice axis counts from the end, like ONNX
        # axes. It is normalized so it compares correctly with slice_axes.
        if gslice_axis < 0:
            gslice_axis += rank

        self.op_node = op_node
        self.slice_axes = slice_axes
        self.gslice_axis = gslice_axis
        return True

    @staticmethod
    def _get_slice_axes(op_node: ir.Node) -> List[int] | None:
        """Return the (possibly negative) axes a Slice node operates on.

        ONNX Slice defaults ``axes`` to ``[0, 1, ..., len(starts) - 1]`` when
        the optional input is omitted.

        Returns:
            The axes as Python ints, or None if they cannot be determined
            statically (non-constant ``axes`` or unknown ``starts`` length).
        """
        slice_inputs = op_node.inputs
        axes_v = slice_inputs[3] if len(slice_inputs) >= 4 else None
        if axes_v is not None:
            if not is_constant(axes_v):
                return None
            return [int(axis)
                    for axis in get_constant_np(axes_v).reshape(-1).tolist()]

        # axes omitted: one axis per entry of 'starts', in order
        starts_v = slice_inputs[1] if len(slice_inputs) >= 2 else None
        if starts_v is None:
            return None
        if is_constant(starts_v):
            num_axes = int(get_constant_np(starts_v).size)
        elif (starts_v.shape is not None and len(starts_v.shape) == 1
              and isinstance(starts_v.shape[0], int)):
            num_axes = starts_v.shape[0]
        else:
            return None
        return list(range(num_axes))

    def create_mini_pattern(self,        # pylint: disable=R0913,R0917
                            origin_op: ir.Node,
                            mini_inputs: List[ir.Value|None],
                            head_slice_id,
                            batch_slice_id,
                            slice_i
                            ) -> List[ir.Value|None]:
        """Build the per-group replacement for the Slice node.

        When the Slice acts only on the GroupSlice axis, its start offset has
        already been folded into the GroupSlice applied to the data input
        (see ``rewrite``). In that case the group's data input is returned
        as-is and no Slice is emitted. Otherwise the base implementation is
        used.
        """
        assert self.slice_axes is not None # check for mypy
        if tuple(self.slice_axes) == (self.gslice_axis,):
            # no need to slice anymore
            return [mini_inputs[0]]

        return super().create_mini_pattern(origin_op, mini_inputs,
                                           head_slice_id, batch_slice_id, slice_i)

    def rewrite(self, node: ir.Node) -> bool: # pylint: disable=R0911,R0912
        """Reorder the matched Slice -> GroupSlice pair.

        Must be called after a successful ``match`` on the same node. The
        state cached by ``match`` is cleared whether or not the rewrite is
        applied.

        Args:
            node: the GroupSlice node accepted by ``match``.

        Returns:
            True if the graph was rewritten, False if the pattern is not
            supported.
        """
        gslice_node = node
        assert self.op_node is not None  # set by match()
        assert self.gslice_axis is not None
        assert self.slice_axes is not None

        try:
            if self.gslice_axis not in self.slice_axes:
                # The Slice never touches the GroupSlice axis, so both ops
                # commute. The GroupSlice bounds apply unchanged to the Slice
                # input.
                data_gslice_attrs = get_gslice_attrs(gslice_node)
            elif tuple(self.slice_axes) == (self.gslice_axis,):
                # for example,
                #   Slice(start=0, end=384, axes=2)
                #   Groupslice(starts=[0,24],ends=[24,48], axes=2)
                data_gslice_attrs = self._fold_slice_into_gslice_attrs(
                    gslice_node)
                if data_gslice_attrs is None:
                    return False
            else:
                # Multi-axis Slice that includes the GroupSlice axis. It could
                # be split into a sequence of:
                # - a slice whose axes don't include self.gslice_axis
                # - a slice whose only axis is self.gslice_axis
                # This is complex and rare; support it in the future if
                # required.
                return False

            self._rewrite_with_data_gslice_attrs(gslice_node,
                                                 data_gslice_attrs)
            logger.debug("applied pass %s on '%s'",
                         self.get_curr_pass_name(), self.op_node.name)
            return True
        finally:
            self.op_node = None
            self.gslice_axis = None
            self.slice_axes = None

    def _fold_slice_into_gslice_attrs(self, gslice_node: ir.Node):
        """Express the GroupSlice bounds relative to the Slice input.

        Only handles a single-axis Slice on the GroupSlice axis with constant
        starts/ends and a step of 1. Each GroupSlice start/end is shifted by
        the (normalized) Slice start.

        Returns:
            A copy of the GroupSlice attributes with shifted starts/ends, or
            None if the Slice cannot be folded. That happens with non-constant
            bounds, a step other than 1, or a group running past the Slice
            end.
        """
        assert self.op_node is not None and self.gslice_axis is not None
        slice_inputs = self.op_node.inputs
        starts_v, ends_v = slice_inputs[1], slice_inputs[2]
        if not is_constant(starts_v) or not is_constant(ends_v):
            return None

        # 'steps' is optional; when it is omitted the step is 1
        steps_v = slice_inputs[4] if len(slice_inputs) >= 5 else None
        if steps_v is not None:
            if not is_constant(steps_v):
                return None
            if get_constant_np(steps_v)[0] != 1:
                return None

        slice_start = int(get_constant_np(starts_v)[0])
        slice_end = int(get_constant_np(ends_v)[0])

        # ONNX Slice: negative indices count from the end of the axis and are
        # then clamped to [0, dim]. Without a static dim, only non-negative
        # indices can be used as-is.
        data_v = slice_inputs[0]
        assert data_v is not None and data_v.shape is not None  # see match()
        dim = data_v.shape[self.gslice_axis]
        if isinstance(dim, int):
            slice_start = self._normalize_slice_index(slice_start, dim)
            slice_end = self._normalize_slice_index(slice_end, dim)
        elif slice_start < 0 or slice_end < 0:
            return None

        # Copy so the object returned by get_gslice_attrs is never mutated,
        # including when the fold is rejected below.
        data_gslice_attrs = copy.copy(get_gslice_attrs(gslice_node))
        data_gslice_attrs.starts = [
            x + slice_start for x in data_gslice_attrs.starts]
        data_gslice_attrs.ends = [
            x + slice_start for x in data_gslice_attrs.ends]

        # every group must stay inside the sliced window
        if any(end > slice_end for end in data_gslice_attrs.ends):
            return None
        return data_gslice_attrs

    @staticmethod
    def _normalize_slice_index(index: int, dim: int) -> int:
        """Resolve an ONNX Slice start/end (positive step) into [0, dim]."""
        if index < 0:
            index += dim
        return min(max(index, 0), dim)

    def _rewrite_with_data_gslice_attrs(self, gslice_node: ir.Node,
                                        data_gslice_attrs) -> None:
        """Run the base rewrite for the cached Slice node.

        The data input is group-sliced with ``data_gslice_attrs``. Every other
        Slice input (starts, ends, axes, steps) gets a full GroupSlice.
        """
        assert self.op_node is not None
        num_outputs = data_gslice_attrs.num_outputs()
        inputs_gslice_attrs = [data_gslice_attrs]
        for _ in self.op_node.inputs[1:]:
            # Full slices for the parameter inputs. Fresh lists are built for
            # each input so no two attrs share mutable state.
            inputs_gslice_attrs.append(
                FullGroupSliceAttrs(num_outputs,
                                    head_slice_ids=[-1] * num_outputs,
                                    batch_slice_ids=[-1] * num_outputs)
            )

        _ = self.rewrite_based_on_gslice_attrs(
            self.op_node, [gslice_node],
            inputs_gslice_attrs=inputs_gslice_attrs
        )
