# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides the pass for reordering 
(Expand->GroupSlice) -> (GroupSlice->Expand)
"""
import copy
from typing import Dict, List, Tuple

from onnxscript import ir

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.base_rewriter import M2sBaseRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.utils import (
    BroadcastHelper,
    FullGroupSliceAttrs,
    GroupSliceAttrs,
    get_gslice_attrs,
    is_reorderable_group_slice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.ir_extra_info import (
    VariableExtraInfo,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.utils import (
    check_static_shape_of_node_io,
    get_value_numeric_shape,
    logger,
    make_initializer,
)


class M2sReorderExpandGroupslice(M2sBaseRewriter):
    '''
    Reorder an Expand followed by a GroupSlice into per-slice Expands.

    Transform subgraph:
        Subgraph(in_a, in_b) --> c, c0,c1,c2...
        {
            c = Expand(in_a, in_b)
            c0,c1,c2... = GroupSlice(c)
        }
    Into:
        Subgraph(in_a, in_b) --> c, c0,c1,c2...
        {
            a0,a1,a2,... = GroupSlice(in_a)

            c0 = Expand(a0, shape_0)
            c1 = Expand(a1, shape_1)
            c2 = Expand(a2, shape_2)

            # if possible
            c = concat(c0,c1,c2,...)
            # or c = Expand(in_a, in_b)
        }

    Each shape_i is a newly created shape initializer holding the output
    shape of slice i, i.e. the Expand output shape with the sliced axis
    narrowed to that slice's length. Slices that need the same target shape
    share one initializer. If a slice of in_a already has its target shape,
    no Expand is created and the slice is used directly.

    The slicing of in_a is derived from the output GroupSlice through
    BroadcastHelper. The actual graph rewiring is done by the base class's
    rewrite_based_on_gslice_attrs. NOTE(review): broadcasting and encoding
    updates are expected to happen there as well; confirm in the base class.
    '''

    def __init__(self, graph: ir.Graph):
        super().__init__(graph)
        # Expand node found by match(), consumed (and cleared) by rewrite().
        self.op_node: ir.Node | None = None
        # Numeric output shape of the matched Expand (set in rewrite()).
        self.out_shape: List[int] | None = None
        # GroupSlice attributes of the consumer of the Expand output.
        self.output_gslice_attrs: GroupSliceAttrs | None = None
        # Per-rewrite cache: target shape -> shape initializer, so equal
        # per-slice shapes reuse a single constant.
        self.constant_cache: Dict[Tuple[int, ...], ir.Value] = {}

    def match(self, node: ir.Node) -> bool:
        """Return True if ``node`` is a reorderable GroupSlice fed by an Expand.

        On success the producing Expand node is stored in ``self.op_node``
        for the following :meth:`rewrite` call.

        Args:
            node: Candidate GroupSlice node.

        Returns:
            True when the Expand -> GroupSlice pattern matches, else False.
        """
        gslice_node = node
        if not is_reorderable_group_slice(gslice_node):
            return False

        # check for mypy, definitely true
        assert gslice_node.inputs[0] is not None
        op_node = gslice_node.inputs[0].producer()
        assert op_node is not None  # check for mypy, definitely true

        if op_node.op_type != "Expand":
            return False

        # NOTE(review): the return value is ignored, so this presumably
        # raises on non-static shapes rather than returning a flag; confirm.
        check_static_shape_of_node_io(op_node)
        self.op_node = op_node
        return True

    def create_mini_pattern(self,              # pylint: disable=R0913,R0917
                            origin_op: ir.Node,
                            mini_inputs: List[ir.Value|None],
                            head_slice_id,
                            batch_slice_id,
                            slice_i
                            ) -> List[ir.Value|None]:
        """Create the Expand that produces output slice ``slice_i``.

        Args:
            origin_op: The original Expand node.
            mini_inputs: Inputs for this slice. Index 0 is the data slice.
                Index 1 is overwritten in place with the per-slice shape
                initializer.
            head_slice_id: Head slice id, used to name the new node.
            batch_slice_id: Batch slice id, used to name the new node.
            slice_i: Index of this slice in the output GroupSlice attributes.

        Returns:
            ``[mini_inputs[0]]`` when the data slice already has the target
            shape. Otherwise, the outputs of a new Expand node inserted
            before ``origin_op``.
        """
        assert self.out_shape is not None  # check for mypy
        assert self.output_gslice_attrs is not None  # check for mypy
        gslice_attrs = self.output_gslice_attrs

        # Target shape: the full Expand output shape with the sliced axis
        # narrowed to the length of this slice.
        mini_out_shape_list = list(self.out_shape)
        mini_out_shape_list[gslice_attrs.axis] = \
            gslice_attrs.ends[slice_i] - gslice_attrs.starts[slice_i]
        mini_out_shape: Tuple[int, ...] = tuple(mini_out_shape_list)

        if mini_out_shape == tuple(get_value_numeric_shape(mini_inputs[0])):
            # no need to expand
            return [mini_inputs[0]]

        # in_b is the shape (integer), so it cannot have encodings,
        # so we can safely create a new b (shared across equal shapes).
        if mini_out_shape in self.constant_cache:
            in_b = self.constant_cache[mini_out_shape]
        else:
            assert origin_op.name is not None  # check for mypy, definitely true
            in_b = make_initializer(self.graph,
                                    self.graph.meta["extra_info"].get_unique_name_with_suffix(
                                        origin_op.name, ".sha.expand_constant"
                                    ),
                                    mini_out_shape,
                                    )
            self.constant_cache[mini_out_shape] = in_b
            in_b.meta["extra_info"] = VariableExtraInfo()

        mini_inputs[1] = in_b

        # create mini op
        mini_op = ir.Node(domain="", op_type=origin_op.op_type,
                          inputs=mini_inputs,
                          num_outputs=len(origin_op.outputs),
                          name=self.get_mini_op_name(
                              origin_op.name, head_slice_id, batch_slice_id
                          ))

        self.graph.insert_before(origin_op, mini_op)
        mini_op.attributes.update(copy.deepcopy(origin_op.attributes))
        return list(mini_op.outputs)

    def rewrite(self, node: ir.Node) -> bool:
        """Split the matched Expand according to the GroupSlice ``node``.

        Must be called after a successful :meth:`match`. ``self.op_node``
        is cleared afterwards, even if the rewrite raises.

        Args:
            node: The GroupSlice node that consumes the matched Expand.

        Returns:
            Always True.
        """
        gslice_node = node
        assert self.op_node is not None  # check for mypy, definitely true
        op_node = self.op_node

        try:
            in_shape = get_value_numeric_shape(op_node.inputs[0])
            out_shape = get_value_numeric_shape(op_node.outputs[0])

            output_gslice_attrs = get_gslice_attrs(gslice_node)

            self.out_shape = out_shape
            self.output_gslice_attrs = output_gslice_attrs
            self.constant_cache = {}

            # in_b holds a shape vector rather than data to broadcast, so the
            # Expand output shape stands in for it as the second operand.
            bc_helper = BroadcastHelper(
                in_shape, out_shape,
                get_value_numeric_shape(op_node.outputs[0]),
            )
            in_a_groupslice_attrs = bc_helper.get_input_group_attrs(
                0, output_gslice_attrs)

            # in_b is never sliced: every slice receives a full (-1, -1)
            # range. create_mini_pattern later replaces it per slice.
            num_slices = output_gslice_attrs.num_outputs()
            in_b_groupslice_attrs = FullGroupSliceAttrs(num_slices,
                                                        [-1] * num_slices,
                                                        [-1] * num_slices)

            _ = self.rewrite_based_on_gslice_attrs(
                op_node, [gslice_node],
                inputs_gslice_attrs=[in_a_groupslice_attrs, in_b_groupslice_attrs]
            )

            logger.debug("applied pass %s on '%s'",
                         self.get_curr_pass_name(), op_node.name)
        finally:
            # Never leave a stale match behind, even if the rewrite failed.
            self.op_node = None
        return True
