# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides the pass that reorders
(Concat -> GroupSlice) into (GroupSlice -> Concat).
"""
import bisect
import itertools
from typing import List

from onnxscript import ir

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.base_rewriter import M2sBaseRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.utils import (
    GroupSliceAttrs,
    get_gslice_attrs,
    is_reorderable_group_slice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.utils import (
    check_static_shape_of_node_io,
    convert_attr_to_py,
    get_value_numeric_shape,
    logger,
)


class M2sReorderConcatGroupslice(M2sBaseRewriter):
    '''
    Push a GroupSlice above the Concat that feeds it.

    Transform subgraph:
        Subgraph(in_a_1, in_a_2, in_a_3, ...) --> b,b0,b1,b2
        {
            b = Concat(in_a_1, in_a_2, in_a_3, ...)
            b0,b1,b2... = GroupSlice(b)
        }
    Into:
        Subgraph(in_a_1, in_a_2, in_a_3, ...) --> b,b0,b1,b2
        {

            in_a_1.0, in_a_1.1, in_a_1.2 ... = GroupSlice(in_a_1)
            in_a_2.0, in_a_2.1, in_a_2.2 ... = GroupSlice(in_a_2)
            in_a_3.0, in_a_3.1, in_a_3.2 ... = GroupSlice(in_a_3)
            ...

            b0 = Concat(in_a_1.0, in_a_2.0, in_a_3.0, ...)
            b1 = Concat(in_a_1.1, in_a_2.1, in_a_3.1, ...)
            b2 = Concat(in_a_1.2, in_a_2.2, in_a_3.2, ...)
            ...

            # if possible
            b = Concat(b0,b1,b2,...)
            # or b = Concat(in_a_1, in_a_2, in_a_3, ...)

        }

    Two cases are handled:
      * GroupSlice axis != Concat axis: every Concat input is sliced with the
        same GroupSlice attributes as the original GroupSlice.
      * GroupSlice axis == Concat axis: each output slice range [start, end)
        on the concatenated axis is mapped onto the part of every input it
        overlaps; inputs it does not overlap receive an empty [0, 0) slice.

    Also, encodings are updated
    '''

    def __init__(self, graph: ir.Graph):
        super().__init__(graph)
        # Concat node found by match(); consumed and cleared by rewrite().
        self.op_node: ir.Node | None = None
        # Raw "axis" attribute of that Concat (may be negative, per ONNX).
        self.concat_axis: int | None = None

    def match(self, node: ir.Node) -> bool:
        """Return True if ``node`` is a reorderable GroupSlice fed by a Concat.

        On success, the producing Concat node and its ``axis`` attribute are
        stored on ``self`` for the following :meth:`rewrite` call.
        """
        gslice_node = node
        if not is_reorderable_group_slice(gslice_node):
            return False
        # check for mypy, definitely true
        assert gslice_node.inputs[0] is not None
        op_node = gslice_node.inputs[0].producer()
        assert op_node is not None  # check for mypy, definitely true

        if op_node.op_type != "Concat":
            return False

        # NOTE(review): the return value is ignored here; presumably this
        # raises on non-static shapes -- confirm in utils.
        check_static_shape_of_node_io(op_node)

        self.concat_axis = convert_attr_to_py(op_node.attributes["axis"],
                                              "as_int")
        self.op_node = op_node
        return True

    def create_mini_pattern(self,      # pylint: disable=R0913,R0917
                            origin_op: ir.Node,
                            mini_inputs: List[ir.Value|None],
                            head_slice_id,
                            batch_slice_id,
                            slice_i
                            ) -> List[ir.Value|None]:
        """Build the per-slice Concat from the sliced pieces of the inputs.

        ``None`` entries in ``mini_inputs`` are dropped (presumably the base
        rewriter yields ``None`` for the empty [0, 0) slices produced by
        :meth:`rewrite` -- confirm in M2sBaseRewriter).
        * More than one piece left: delegate to the base implementation.
        * Exactly one piece left: return it directly; no Concat is needed.

        Raises:
            AssertionError: if every piece is ``None``. This is raised
                explicitly so it is not stripped under ``python -O``.
        """
        present_inputs: List[ir.Value | None] = [
            x for x in mini_inputs if x is not None]
        if len(present_inputs) > 1:
            return super().create_mini_pattern(
                origin_op, present_inputs, head_slice_id, batch_slice_id,
                slice_i)
        if len(present_inputs) == 1:
            # no need to concat
            return present_inputs

        raise AssertionError("should not happen")

    def _normalized_axes(self, gslice_node: ir.Node) -> tuple[int, int]:
        """Return ``(concat_axis, gslice_axis)`` with negative axes resolved.

        ONNX Concat accepts ``axis`` in [-r, r-1]. Comparing a raw negative
        axis against a non-negative one would treat the same dimension as
        two different axes, so both are mapped to [0, r-1] before comparing.
        The shape is only queried when an axis is actually negative.
        """
        assert self.op_node is not None  # set by match()
        assert self.concat_axis is not None  # set by match()
        concat_axis = self.concat_axis
        gslice_axis = convert_attr_to_py(gslice_node.attributes["axis"],
                                         "as_int")
        if concat_axis < 0 or gslice_axis < 0:
            # The GroupSlice consumes the Concat output, so both share its rank.
            rank = len(get_value_numeric_shape(self.op_node.outputs[0]))
            if concat_axis < 0:
                concat_axis += rank
            if gslice_axis < 0:
                gslice_axis += rank
        return concat_axis, gslice_axis

    @staticmethod
    def _map_slices_onto_concat_inputs(out_gslice_attrs: GroupSliceAttrs,
                                       concat_in_dims: List[int],
                                       axis: int) -> List[GroupSliceAttrs]:
        """Split each output slice range across the Concat inputs.

        Args:
            out_gslice_attrs: attributes of the original GroupSlice, whose
                starts/ends are offsets into the concatenated axis.
            concat_in_dims: size of every Concat input along ``axis``.
            axis: the shared Concat / GroupSlice axis.

        Returns:
            One GroupSliceAttrs per Concat input. Each has exactly
            ``out_gslice_attrs.num_outputs()`` slices, with offsets local to
            that input. Non-overlapping inputs get an empty [0, 0) slice.
        """
        # accum[i] is the offset where input i starts; accum[-1] is the total.
        accum = [0, *itertools.accumulate(concat_in_dims)]
        inputs_gslice_attrs = [GroupSliceAttrs(axis=axis)
                               for _ in concat_in_dims]

        for slice_i in range(out_gslice_attrs.num_outputs()):
            out_start = out_gslice_attrs.starts[slice_i]
            out_end = out_gslice_attrs.ends[slice_i]
            head_id = out_gslice_attrs.head_slice_ids[slice_i]
            batch_id = out_gslice_attrs.batch_slice_ids[slice_i]

            # Inputs in [first_in, stop_in) overlap [out_start, out_end).
            first_in = max(0, find_largest_element_smaller_than(
                accum, out_start))
            stop_in = find_smallest_element_larger_than(accum, out_end)

            for input_i, in_attrs in enumerate(inputs_gslice_attrs):
                if first_in <= input_i < stop_in:
                    offset = accum[input_i]
                    start = max(0, out_start - offset)
                    end = min(out_end, accum[input_i + 1]) - offset
                else:
                    # empty
                    start = end = 0
                in_attrs.starts.append(start)
                in_attrs.ends.append(end)
                in_attrs.head_slice_ids.append(head_id)
                in_attrs.batch_slice_ids.append(batch_id)

        return inputs_gslice_attrs

    def rewrite(self, node: ir.Node) -> bool:
        """Rewrite the matched Concat -> GroupSlice pair.

        Computes per-input GroupSlice attributes and hands them to the base
        class's ``rewrite_based_on_gslice_attrs``. Clears the state cached by
        :meth:`match` and returns True.
        """
        gslice_node = node
        assert self.op_node is not None  # check for mypy, definitely true

        out_gslice_attrs = get_gslice_attrs(gslice_node)
        concat_axis, gslice_axis = self._normalized_axes(gslice_node)

        if concat_axis != gslice_axis:
            # Every Concat input is sliced exactly like the original output.
            inputs_gslice_attrs = [
                out_gslice_attrs for _ in self.op_node.inputs]
        else:
            # Concat axis is the same as the GroupSlice axis: slice ranges
            # must be translated into each input's local offsets.
            concat_in_dims = [get_value_numeric_shape(x)[concat_axis]
                              for x in self.op_node.inputs]
            inputs_gslice_attrs = self._map_slices_onto_concat_inputs(
                out_gslice_attrs, concat_in_dims, concat_axis)

        _ = self.rewrite_based_on_gslice_attrs(
            self.op_node, [gslice_node],
            inputs_gslice_attrs=list(inputs_gslice_attrs)  # list() to make mypy happy
        )

        logger.debug("applied pass %s on '%s'",
                      self.get_curr_pass_name(), self.op_node.name)

        self.op_node = None
        self.concat_axis = None
        return True


def find_smallest_element_larger_than(ascending_list: list[int],
                                      value: int) -> int:
    '''
    Helper function.

    Return the INDEX of the first element of ``ascending_list`` that is
    greater than or equal to ``value`` (note: not strictly larger). Return
    ``len(ascending_list)`` when every element is smaller than ``value``.

    ``ascending_list`` must be sorted in non-decreasing order. The lookup is
    a binary search (``bisect.bisect_left``), so it takes O(log n).
    '''
    return bisect.bisect_left(ascending_list, value)


def find_largest_element_smaller_than(ascending_list: list[int],
                                      value: int) -> int:
    '''
    Helper function.

    Return the INDEX of the last element of ``ascending_list`` that is less
    than or equal to ``value`` (note: not strictly smaller). Return ``-1``
    when every element is larger than ``value``.

    ``ascending_list`` must be sorted in non-decreasing order. The lookup is
    a binary search (``bisect.bisect_right``), so it takes O(log n).
    '''
    return bisect.bisect_right(ascending_list, value) - 1
