# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides the pass for reordering 
(LayerNormalization->GroupSlice) -> (GroupSlice->LayerNormalization)
"""



from onnxscript import ir

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.base_rewriter import M2sBaseRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.utils import (
    FullGroupSliceAttrs,
    get_gslice_attrs,
    is_reorderable_group_slice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.utils import (
    convert_attr_to_py,
    get_attribute_with_default,
    get_constant_np,
    is_constant,
    logger,
)


class M2sReorderUnaryReduceGroupsliceIfAxesIsInput(M2sBaseRewriter):
    '''
    Reorder a Reduce op whose ``axes`` are passed as its second input
    (ReduceSum from opset 13, the other supported Reduce ops from opset 18).

    Reorder subgraph:
        Subgraph(in_a) --> b,b0,b1,b2
        {
            b = ReduceOp(in_a) # such as ReduceSum
            b0,b1,b2... = GroupSlice(b)
        }
    Into:
        Subgraph(in_a) --> b,b0,b1,b2
        {
            a0,a1,a2... = GroupSlice(in_a)

            b0 = ReduceOp(a0)
            b1 = ReduceOp(a1)
            b2 = ReduceOp(a2)
            ...

            # if possible
            b = concat(b0,b1,b2,...)
            # or b = ReduceOp(in_a)
        }

    The reorder is only valid when the GroupSlice axis is not reduced:
      * omitted or empty axes with ``noop_with_empty_axes=0`` reduce every
        axis and are rejected; omitted axes with ``noop_with_empty_axes=1``
        make the op an Identity and are accepted;
      * otherwise the axes input must be a constant that does not contain the
        GroupSlice axis (negative axes are normalized when the rank is known,
        and the match is rejected when it is not);
      * ``keepdims`` must be 1.

    Also, encodings are updated
    '''

    def __init__(self, graph: ir.Graph):
        super().__init__(graph)
        # op_type -> opset range in which ``axes`` is an input
        self.support_op_types = {
            "ReduceMin": {"min_op_ver": 18},
            "ReduceMean": {"min_op_ver": 18},
            "ReduceMax": {"min_op_ver": 18},
            "ReduceProd": {"min_op_ver": 18},
            "ReduceSum": {"min_op_ver": 13},
        }
        # Reduce node found by ``match``, consumed (and cleared) by ``rewrite``
        self.reduce_node: ir.Node | None = None

    @staticmethod
    def _get_rank(value: ir.Value | None) -> int | None:
        """Return the rank of ``value`` when its shape is known, else None."""
        if value is None or value.shape is None:
            return None
        return len(value.shape)

    @staticmethod
    def _normalize_axes(axes: list[int], rank: int | None) -> list[int] | None:
        """Map negative axes to their non-negative equivalents.

        Returns None when a negative axis is present but ``rank`` is unknown,
        because such an axis cannot be compared reliably.
        """
        normalized = []
        for axis in axes:
            axis = int(axis)
            if axis < 0:
                if rank is None:
                    return None
                axis += rank
            normalized.append(axis)
        return normalized

    def _get_opset_version(self, domain: str) -> int | None:
        """Return the graph's opset version for ``domain``, or None if absent.

        "" and "ai.onnx" both name the default ONNX domain, so either key is
        accepted for the other.
        """
        opset_imports = self.graph.opset_imports
        version = opset_imports.get(domain)
        if version is None and domain in ("", "ai.onnx"):
            version = opset_imports.get("ai.onnx" if domain == "" else "")
        return version

    def match(self, node: ir.Node) -> bool: # pylint: disable=R0911
        """Return True when ``node`` is a reorderable GroupSlice fed by a
        supported Reduce op whose (input) axes do not include the GroupSlice
        axis. On success the Reduce node is stored in ``self.reduce_node``.
        """
        gslice_node = node
        if not is_reorderable_group_slice(gslice_node):
            return False

        # check for mypy, definitely true
        assert gslice_node.inputs[0] is not None
        reduce_node = gslice_node.inputs[0].producer()
        assert reduce_node is not None  # check for mypy, definitely true
        if reduce_node.op_type not in self.support_op_types:
            return False

        if reduce_node.domain not in ["", "ai.onnx", "main"]:
            return False
        op_version = self._get_opset_version(reduce_node.domain)
        if op_version is None:
            # opset unknown: can't tell whether axes is an input or attribute
            return False
        version_range = self.support_op_types[reduce_node.op_type]
        if op_version < version_range["min_op_ver"]:
            return False
        if "max_op_ver" in version_range and op_version > version_range["max_op_ver"]:
            return False

        gslice_axis = convert_attr_to_py(gslice_node.attributes["axis"],
                                         "as_int")
        noop_with_empty_axes = get_attribute_with_default(
            reduce_node, "noop_with_empty_axes", 0)

        # an omitted optional input may also show up as a trailing None
        axes_value = reduce_node.inputs[1] if len(reduce_node.inputs) >= 2 else None
        if axes_value is None:
            # no axes
            if not noop_with_empty_axes:
                # reduce on all axes, can't reorder groupslice
                return False
            # act as an Identity; record the node so rewrite() can use it
            self.reduce_node = reduce_node
            return True

        if not is_constant(axes_value):
            return False

        reduce_axes = get_constant_np(axes_value).tolist()
        if not isinstance(reduce_axes, list):
            # a 0-d constant yields a bare scalar from tolist()
            reduce_axes = [reduce_axes]
        if not reduce_axes and not noop_with_empty_axes:
            # empty axes behave like omitted axes: reduce on all axes
            return False

        # axes refer to the Reduce data input; the GroupSlice axis to its input
        norm_gslice_axes = self._normalize_axes(
            [gslice_axis], self._get_rank(gslice_node.inputs[0]))
        norm_reduce_axes = self._normalize_axes(
            reduce_axes, self._get_rank(reduce_node.inputs[0]))
        if norm_gslice_axes is None or norm_reduce_axes is None:
            # negative axis with unknown rank: can't prove the axes differ
            return False
        if norm_gslice_axes[0] in norm_reduce_axes:
            return False

        keepdims = get_attribute_with_default(reduce_node, "keepdims", 1)
        if not keepdims:
            # TODO, add support for keepdims == 0
            logger.warning("current we only support to handle %s with keepdims=1, skip at %s",
                            reduce_node.op_type, reduce_node.name)
            return False

        self.reduce_node = reduce_node
        return True

    def rewrite(self, node: ir.Node) -> bool:
        """Split the Reduce node recorded by ``match`` across the GroupSlice.

        The data input uses the same GroupSlice attributes as ``node``. When
        an axes input is present, it is paired with ``FullGroupSliceAttrs``
        (presumably every slice reuses the whole axes tensor; TODO confirm).
        Only as many attributes as the Reduce node has inputs are passed.
        """
        assert self.reduce_node is not None  # check for mypy, definitely true
        gslice_node = node

        # when reduce_axis != gslice_axis
        # input gslice attrs should be the same on output gslice
        a_groupslice_attrs = get_gslice_attrs(gslice_node)
        inputs_gslice_attrs = [a_groupslice_attrs]
        if len(self.reduce_node.inputs) >= 2:
            inputs_gslice_attrs.append(
                FullGroupSliceAttrs(a_groupslice_attrs.num_outputs(),
                                    a_groupslice_attrs.head_slice_ids,
                                    a_groupslice_attrs.batch_slice_ids))
        _ = self.rewrite_based_on_gslice_attrs(
            self.reduce_node, [gslice_node],
            inputs_gslice_attrs=inputs_gslice_attrs
        )

        logger.debug("applied pass %s on '%s'",
                      self.get_curr_pass_name(), self.reduce_node.name)

        self.reduce_node = None
        return True


class M2sReorderUnaryReduceGroupsliceIfAxesIsAttr(M2sBaseRewriter):
    '''
    Reorder a Reduce op whose ``axes`` are given as a node attribute.

    Reorder subgraph:
        Subgraph(in_a) --> b,b0,b1,b2
        {
            b = ReduceOp(in_a) # such as ReduceSum
            b0,b1,b2... = GroupSlice(b)
        }
    Into:
        Subgraph(in_a) --> b,b0,b1,b2
        {
            a0,a1,a2... = GroupSlice(in_a)

            b0 = ReduceOp(a0)
            b1 = ReduceOp(a1)
            b2 = ReduceOp(a2)
            ...

            # if possible
            b = concat(b0,b1,b2,...)
            # or b = ReduceOp(in_a)
        }

    The reorder is only valid when the GroupSlice axis is not reduced:
      * a missing or empty ``axes`` attribute is rejected (no attribute means
        axes are an input or every axis is reduced; an empty list reduces
        every axis);
      * the ``axes`` attribute must not contain the GroupSlice axis (negative
        axes are normalized when the rank is known, and the match is rejected
        when it is not);
      * ``keepdims`` must be 1.

    Also, encodings are updated
    '''

    def __init__(self, graph: ir.Graph):
        super().__init__(graph)
        self.support_op_types = ["ReduceMin", "ReduceMean",
                                 "ReduceMax", "ReduceProd",
                                 "ReduceSum"
                                 ]
        # Reduce node found by ``match``, consumed (and cleared) by ``rewrite``
        self.reduce_node: ir.Node | None = None

    @staticmethod
    def _get_rank(value: ir.Value | None) -> int | None:
        """Return the rank of ``value`` when its shape is known, else None."""
        if value is None or value.shape is None:
            return None
        return len(value.shape)

    @staticmethod
    def _normalize_axes(axes: list[int], rank: int | None) -> list[int] | None:
        """Map negative axes to their non-negative equivalents.

        Returns None when a negative axis is present but ``rank`` is unknown,
        because such an axis cannot be compared reliably.
        """
        normalized = []
        for axis in axes:
            axis = int(axis)
            if axis < 0:
                if rank is None:
                    return None
                axis += rank
            normalized.append(axis)
        return normalized

    def match(self, node: ir.Node) -> bool: # pylint: disable=R0911
        """Return True when ``node`` is a reorderable GroupSlice fed by a
        supported Reduce op whose ``axes`` attribute does not include the
        GroupSlice axis. On success the Reduce node is stored in
        ``self.reduce_node``.
        """
        gslice_node = node
        if not is_reorderable_group_slice(gslice_node):
            return False

        # check for mypy, definitely true
        assert gslice_node.inputs[0] is not None
        reduce_node = gslice_node.inputs[0].producer()
        assert reduce_node is not None  # check for mypy, definitely true
        if reduce_node.op_type not in self.support_op_types:
            return False
        if reduce_node.domain not in ["", "ai.onnx", "main"]:
            return False

        axes_attr = reduce_node.attributes.get("axes")
        if axes_attr is None:
            # no axes attribute: axes come as an input or all axes are reduced
            return False

        gslice_axis = convert_attr_to_py(gslice_node.attributes["axis"],
                                         "as_int")
        reduce_axes = list(convert_attr_to_py(axes_attr, "as_ints"))
        if not reduce_axes:
            # an empty axes list reduces over all axes, can't reorder groupslice
            return False

        # axes refer to the Reduce data input; the GroupSlice axis to its input
        norm_gslice_axes = self._normalize_axes(
            [gslice_axis], self._get_rank(gslice_node.inputs[0]))
        norm_reduce_axes = self._normalize_axes(
            reduce_axes, self._get_rank(reduce_node.inputs[0]))
        if norm_gslice_axes is None or norm_reduce_axes is None:
            # negative axis with unknown rank: can't prove the axes differ
            return False
        if norm_gslice_axes[0] in norm_reduce_axes:
            return False

        keepdims = get_attribute_with_default(reduce_node, "keepdims", 1)
        if keepdims != 1:
            # TODO, add support for keepdims == 0
            logger.warning("current we only support to handle %s with keepdims=1, skip at %s",
                            reduce_node.op_type, reduce_node.name)
            return False
        self.reduce_node = reduce_node
        return True

    def rewrite(self, node: ir.Node) -> bool:
        """Split the Reduce node recorded by ``match`` across the GroupSlice.

        The single data input is sliced with the same GroupSlice attributes
        as ``node``.
        """
        assert self.reduce_node is not None  # check for mypy, definitely true
        gslice_node = node

        # when reduce_axis != gslice_axis
        # input gslice attrs should be the same on output gslice
        a_groupslice_attrs = get_gslice_attrs(gslice_node)

        _ = self.rewrite_based_on_gslice_attrs(
            self.reduce_node, [gslice_node],
            inputs_gslice_attrs=[a_groupslice_attrs]
        )

        logger.debug("applied pass %s on '%s'",
                      self.get_curr_pass_name(), self.reduce_node.name)
        self.reduce_node = None
        return True


# Passes exported by this module, in application order. Keep the order:
# the input-axes variant (gated on the opset version) is listed first and
# the attribute-axes variant is the default that follows it.
M2sReorderUnaryReduceGroupslice_Passes = [
    M2sReorderUnaryReduceGroupsliceIfAxesIsInput,
    M2sReorderUnaryReduceGroupsliceIfAxesIsAttr,
]
