# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides the rewrite pass that reorders
(InstanceNormalization->GroupSlice) into (GroupSlice->InstanceNormalization).
"""


from onnxscript import ir

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.base_rewriter import M2sBaseRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.utils import (
    FullGroupSliceAttrs,
    get_gslice_attrs,
    is_reorderable_group_slice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.utils import (
    convert_attr_to_py,
    has_static_shape_on_value,
    logger,
)


class M2sReorderInstancenormGroupslice(M2sBaseRewriter):
    """
    Move a GroupSlice in front of the InstanceNormalization that feeds it.

    Reorder subgraph:
        Subgraph(in_a) --> b,b0,b1,b2
        {
            b = InstanceNormalization(in_a)
            b0,b1,b2... = GroupSlice(b)
        }
    Into:
        Subgraph(in_a) --> b,b0,b1,b2
        {
            a0,a1,a2... = GroupSlice(in_a)

            b0 = InstanceNormalization(a0)
            b1 = InstanceNormalization(a1)
            b2 = InstanceNormalization(a2)
            ...

            # if possible
            b = Concat(b0,b1,b2,...)
            # or b = InstanceNormalization(in_a)
        }

    The pass only fires when the GroupSlice slices axis 0 (batch) or
    axis 1 (channel) and the InstanceNormalization data input has a static
    shape. Also, encodings are updated.

    State: ``match`` stores the matched InstanceNormalization node in
    ``self.op_node``; ``rewrite`` consumes it and resets it to ``None``.
    """

    # GroupSlice axes this pass can move in front of InstanceNormalization.
    _BATCH_AXIS = 0
    _CHANNEL_AXIS = 1

    def __init__(self, graph: ir.Graph):
        super().__init__(graph)
        self.support_op_types = ["InstanceNormalization"]
        # Producer node found by match(), consumed (and cleared) by rewrite().
        self.op_node: ir.Node | None = None

    def match(self, node: ir.Node) -> bool:
        """
        Check whether ``node`` is a reorderable GroupSlice fed by a supported op.

        Args:
            node: Candidate GroupSlice node.

        Returns:
            True if the pattern matches. The producing node is then stored in
            ``self.op_node`` for the following ``rewrite`` call.
        """
        gslice_node = node
        if not is_reorderable_group_slice(gslice_node):
            return False

        gslice_input = gslice_node.inputs[0]
        # A value without a producer (e.g. a graph input) has nothing to
        # reorder with; reject it explicitly instead of relying on an assert,
        # which is stripped under `python -O`.
        op_node = None if gslice_input is None else gslice_input.producer()
        if op_node is None or op_node.op_type not in self.support_op_types:
            return False

        # The GroupSlice is re-applied to this data input (in_a); a static
        # shape is required -- presumably to size the sliced branches.
        data_input = op_node.inputs[0]
        if data_input is None or not has_static_shape_on_value(data_input):
            return False

        gslice_axis = convert_attr_to_py(
            gslice_node.attributes["axis"], "as_int")
        if gslice_axis not in (self._BATCH_AXIS, self._CHANNEL_AXIS):
            return False

        self.op_node = op_node
        return True

    def rewrite(self, node: ir.Node) -> bool:
        """
        Move the GroupSlice ``node`` in front of ``self.op_node``.

        Args:
            node: The GroupSlice node previously accepted by ``match``.

        Returns:
            True once the rewrite has been applied.

        Raises:
            RuntimeError: if there is no matched node, or if the GroupSlice
                axis is neither 0 nor 1 (both cases are excluded by ``match``).
        """
        op_node = self.op_node
        if op_node is None:
            raise RuntimeError("rewrite() called without a successful match()")
        gslice_node = node

        try:
            gslice_axis = convert_attr_to_py(
                gslice_node.attributes["axis"], "as_int")
            # Slice attributes for the data input (in_a).
            a_groupslice_attrs = get_gslice_attrs(gslice_node)

            if gslice_axis == self._BATCH_AXIS:
                # Slice on batch: scale and bias do not depend on the batch
                # index, so every branch receives them whole
                # (FullGroupSliceAttrs -- assumed to mean "no slicing";
                # confirm in utils).
                other_inputs_gslice_attrs = [
                    FullGroupSliceAttrs(a_groupslice_attrs.num_outputs(),
                                        a_groupslice_attrs.head_slice_ids[:],
                                        a_groupslice_attrs.batch_slice_ids[:])
                    for _ in op_node.inputs[1:]  # scale and bias
                ]
            elif gslice_axis == self._CHANNEL_AXIS:
                # Slice on channel: scale and bias are split into the same
                # groups, but along their own axis 0. get_gslice_attrs is
                # called once per input because each result is mutated
                # (assumes it returns a new object on every call).
                other_inputs_gslice_attrs = []
                for _ in op_node.inputs[1:]:  # scale and bias
                    groupslice_attrs = get_gslice_attrs(gslice_node)
                    groupslice_attrs.axis = 0
                    other_inputs_gslice_attrs.append(groupslice_attrs)
            else:
                # Unreachable after match(). Raise rather than assert so the
                # pass can never report success without rewriting under -O.
                raise RuntimeError(
                    f"unsupported GroupSlice axis {gslice_axis} for "
                    f"'{op_node.name}'")

            self.rewrite_based_on_gslice_attrs(
                op_node, [gslice_node],
                inputs_gslice_attrs=[a_groupslice_attrs] +
                other_inputs_gslice_attrs
            )
            logger.debug("applied pass %s on '%s'",
                         self.get_curr_pass_name(), op_node.name)
        finally:
            # Never leave a stale match behind, even if the rewrite fails.
            self.op_node = None
        return True
