# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides the pass for reordering 
(LayerNormalization->GroupSlice) -> (GroupSlice->LayerNormalization)
"""

import copy

from onnxscript import ir

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.base_rewriter import M2sBaseRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.utils import (
    BroadcastHelper,
    FullGroupSliceAttrs,
    GroupSliceAttrs,
    get_gslice_attrs,
    get_value_numeric_shape,
    is_reorderable_group_slice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.utils import (
    check_static_shape_of_node_io,
    logger,
)


class M2sReorderMatmulGroupslice(M2sBaseRewriter):
    '''
    Reorder (MatMul -> GroupSlice) into (GroupSlice -> MatMul).

    Transform subgraph:
        Subgraph(in_a, in_b) --> c,c0,c1,c2,...
        {
            c = MatMul(in_a, in_b)
            c0,c1,c2... = GroupSlice(c)

        }
    to:
        Subgraph(in_a, in_b) --> c,c0,c1,c2,...
        {

            a0,a1,a2... = GroupSlice(in_a) # for some case, no groupslice is required
            b0,b1,b2... = GroupSlice(in_b)

            c0 = MatMul(a0, b0)
            c1 = MatMul(a1, b1)
            c2 = MatMul(a2, b2)

            # if possible
            c = concat(c0,c1,c2,...)
            # or c = MatMul(in_a, in_b)
        }

    How each MatMul operand gets its group-slice attributes depends on which
    axis of the MatMul output (of rank ``rank``) the GroupSlice splits:
      * axis == rank - 1: ``in_b`` is sliced on its last axis; ``in_a`` gets
        ``FullGroupSliceAttrs`` (presumably reused whole by every group).
      * axis == rank - 2: ``in_a`` is sliced on its second-to-last axis;
        ``in_b`` gets ``FullGroupSliceAttrs``.
      * axis < rank - 2: a leading (broadcastable) axis; per-operand
        attributes come from ``BroadcastHelper``.

    MatMuls with a 1-D operand are not matched, because numpy-style
    promotion/squeezing breaks the axis correspondence used above.

    Also, encodings are updated
    '''

    def __init__(self, graph: ir.Graph):
        super().__init__(graph)
        self.matmul_node: ir.Node | None = None
        self.rank: int | None = None
        self.gslice_attrs: GroupSliceAttrs | None = None

    def match(self, node: ir.Node) -> bool:
        """Return True if ``node`` is a reorderable GroupSlice fed by a MatMul.

        On success, the producing MatMul node, the rank of its output and the
        GroupSlice attributes are stored on ``self`` for use by ``rewrite``.

        Args:
            node: candidate GroupSlice node.

        Returns:
            True if the (MatMul -> GroupSlice) pair can be reordered.
        """
        if not is_reorderable_group_slice(node):
            return False
        assert node.inputs[0] is not None  # check for mypy, definitely true
        matmul_node = node.inputs[0].producer()
        assert matmul_node is not None  # check for mypy, definitely true
        if matmul_node.op_type != "MatMul":
            return False

        check_static_shape_of_node_io(matmul_node)
        # check for mypy, definitely true
        assert matmul_node.outputs[0] is not None
        # check for mypy, definitely true
        assert matmul_node.outputs[0].shape is not None

        # MatMul follows numpy semantics: a 1-D operand is promoted and the
        # added axis is squeezed from the output, so the output's last two
        # axes no longer map onto the operands' last two axes as rewrite()
        # assumes. Decline such MatMuls (and ones with unknown shapes).
        for operand in matmul_node.inputs:
            if (operand is None or operand.shape is None
                    or operand.shape.rank() < 2):
                return False

        self.matmul_node = matmul_node
        self.rank = matmul_node.outputs[0].shape.rank()
        self.gslice_attrs = get_gslice_attrs(node)

        # Supported cases: last axis (rank - 1), second-to-last axis
        # (rank - 2), or any leading broadcastable axis (< rank - 2).
        # Together they cover every axis below the output rank.
        # NOTE(review): assumes get_gslice_attrs returns a non-negative axis;
        # confirm, otherwise negative axes fall into the broadcast case.
        return self.gslice_attrs.axis < self.rank

    def rewrite(self, node: ir.Node) -> bool:
        """Move the GroupSlice ``node`` above the MatMul recorded by ``match``.

        Args:
            node: the GroupSlice node of the preceding successful ``match``.

        Returns:
            True once the rewrite has been applied.

        Raises:
            ValueError: if the recorded GroupSlice axis is not one of the
                supported cases (not expected after a successful ``match``).
        """
        assert self.matmul_node is not None  # check for mypy, definitely true
        assert self.gslice_attrs is not None
        assert self.rank is not None

        gslice_node = node
        in_a, in_b = self.matmul_node.inputs
        axis = self.gslice_attrs.axis

        if axis == self.rank - 1:
            # Output last axis comes from in_b's last axis: slice in_b there,
            # in_a gets full (non-sliced) group attributes.
            a_groupslice_attrs_full = FullGroupSliceAttrs(
                self.gslice_attrs.num_outputs(),
                head_slice_ids=self.gslice_attrs.head_slice_ids,
                batch_slice_ids=self.gslice_attrs.batch_slice_ids,
            )
            b_groupslice_attrs = copy.deepcopy(self.gslice_attrs)
            b_groupslice_attrs.axis = len(get_value_numeric_shape(in_b)) - 1

            self.rewrite_based_on_gslice_attrs(
                self.matmul_node, [gslice_node],
                inputs_gslice_attrs=[a_groupslice_attrs_full, b_groupslice_attrs]
            )
        elif axis == self.rank - 2:
            # Output second-to-last axis comes from in_a's second-to-last axis:
            # slice in_a there, in_b gets full (non-sliced) group attributes.
            b_groupslice_attrs_full = FullGroupSliceAttrs(
                self.gslice_attrs.num_outputs(),
                head_slice_ids=self.gslice_attrs.head_slice_ids,
                batch_slice_ids=self.gslice_attrs.batch_slice_ids,
            )
            a_groupslice_attrs = copy.deepcopy(self.gslice_attrs)
            a_groupslice_attrs.axis = len(get_value_numeric_shape(in_a)) - 2

            self.rewrite_based_on_gslice_attrs(
                self.matmul_node, [gslice_node],
                inputs_gslice_attrs=[a_groupslice_attrs, b_groupslice_attrs_full]
            )
        elif axis < self.rank - 2:
            # Leading (batch) axis: operands may broadcast, so derive each
            # operand's slice attributes from the broadcast relationship,
            # ignoring the two trailing matrix axes.
            bc_helper = BroadcastHelper(
                get_value_numeric_shape(in_a), get_value_numeric_shape(in_b),
                get_value_numeric_shape(self.matmul_node.outputs[0]),
                ignore_last_dim_num=2
            )

            a_groupslice_attrs = bc_helper.get_input_group_attrs(
                0, self.gslice_attrs)
            b_groupslice_attrs = bc_helper.get_input_group_attrs(
                1, self.gslice_attrs)

            self.rewrite_based_on_gslice_attrs(
                self.matmul_node, [gslice_node],
                inputs_gslice_attrs=[a_groupslice_attrs, b_groupslice_attrs]
            )
        else:
            raise ValueError(
                f"unsupported GroupSlice axis {axis} for MatMul output "
                f"rank {self.rank}"
            )

        logger.debug("applied pass %s on '%s'",
                     self.get_curr_pass_name(), self.matmul_node.name)
        self.matmul_node = None
        return True
