# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides a pass to simplify Concat+Transpose to Concat
"""

from typing import List

from onnxscript import ir

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .base.rewriter import BasePredicateRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.reshape_transpose_analysis import ReshapeTransposeInfoSeq
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.utils import (
    convert_attr_to_py,
    get_value_numeric_shape,
    have_static_shape_on_node_io,
    logger,
    safe_replace_all_uses_with,
)


class SimplifyConcatTransposeRewriter(BasePredicateRewriter):
    """
    A graph rewriter pass that folds a Transpose following a Concat into the
    Concat itself.

    Simplify pattern
        Concat->Transpose
    To
        Concat

    The rewrite applies when the Transpose is the only consumer of the Concat
    output and, applied to each individual Concat input, the permutation is a
    no-op (unchanged shape, and judged removable by
    ``ReshapeTransposeInfoSeq.simplify_seq``). In that case transposing the
    concatenated tensor equals concatenating the original inputs along axis
    ``perm.index(concat_axis)``.

    For example, simplify
        Transpose(Concat(X0, X1, X2, axis=1), perm=[1, 0, 2, 3]),
            where X0, X1, X2 have shape [1, 1, 128, 64]
    To
        Concat(X0, X1, X2, axis=0)

    """

    def __init__(self, graph):
        super().__init__(graph)
        # Nodes/permutation found by the last successful match(), consumed by rewrite().
        self.concat_node: ir.Node | None = None
        self.transpose_node: ir.Node | None = None
        self.transpose_perm: List[int] | None = None

    def match(self, node):  # pylint: disable=[too-many-return-statements]
        """
        Check whether ``node`` is a Concat whose sole consumer is a Transpose
        that can be folded into the Concat axis.

        On success the matched Concat node, Transpose node and resolved
        permutation are cached on ``self`` for use by :meth:`rewrite`.

        Args:
            node: Candidate node (the Concat of the pattern).

        Returns:
            True if the pattern matches and can be rewritten, False otherwise.
        """
        if node.op_type != "Concat":
            return False
        concat_node = node
        concat_out = concat_node.outputs[0]

        # The Transpose must be the only consumer of the Concat output,
        # otherwise other users would still need the un-transposed tensor.
        uses = concat_out.uses()
        if len(uses) != 1:
            return False
        transpose_node = uses[0].node
        if transpose_node.op_type != "Transpose":
            return False

        if not have_static_shape_on_node_io(concat_node):
            return False
        if not have_static_shape_on_node_io(transpose_node):
            return False

        rank = len(get_value_numeric_shape(concat_out))
        perm_attr = transpose_node.attributes.get("perm")
        if perm_attr is None:
            # ONNX Transpose reverses the dimensions when "perm" is omitted.
            transpose_perm = list(reversed(range(rank)))
        else:
            transpose_perm = list(perm_attr.as_ints())
        if sorted(transpose_perm) != list(range(rank)):
            # Not a valid permutation of the Concat output axes; leave it alone.
            return False

        # Case 1: the Transpose can be eliminated if it is a no-op on every
        # individual Concat input, because then
        #   Transpose(Concat(Xi, axis=a), perm)
        #     == Concat(Transpose(Xi, perm), axis=perm.index(a))
        #     == Concat(Xi, axis=perm.index(a))
        for v in concat_node.inputs:
            in_shape = v.shape.numpy()
            # rewrite() feeds the original (un-transposed) inputs to the new
            # Concat, so the permuted shape must equal the input shape.
            if len(in_shape) != rank or [in_shape[p] for p in transpose_perm] != list(in_shape):
                return False
            pre_info_seq = ReshapeTransposeInfoSeq(
                [ReshapeTransposeInfoSeq.TransposeNodeInfo(transpose_perm)],
                in_shape,
            )
            pre_info_seq = pre_info_seq.simplify_seq()
            if len(pre_info_seq.seq) != 0:
                return False

        self.concat_node = concat_node
        self.transpose_node = transpose_node
        self.transpose_perm = transpose_perm

        return True

    def rewrite(self, node):
        """
        Replace the matched Concat->Transpose pair with a single Concat along
        the permuted axis.

        A new Concat node is inserted after the Transpose, and all uses of the
        Transpose output are redirected to the new Concat output. The original
        Concat and Transpose nodes are not removed here.

        Args:
            node: The node previously passed to :meth:`match`; unused, the
                nodes cached on ``self`` by :meth:`match` are used instead.

        Returns:
            True, signalling that the graph was modified.
        """
        # check for mypy
        assert self.concat_node is not None
        assert self.transpose_perm is not None
        assert self.transpose_node is not None

        concat_axis = convert_attr_to_py(self.concat_node.attributes["axis"], "as_int")
        if concat_axis < 0:
            concat_axis += len(get_value_numeric_shape(self.concat_node.inputs[0]))
        # Output dim j of a Transpose is input dim perm[j], so the original
        # Concat axis ends up at position perm.index(concat_axis).
        new_concat_axis = self.transpose_perm.index(concat_axis)

        extra_info = self.graph.meta["extra_info"]
        new_concat_node = ir.Node(
            "",
            "Concat",
            self.concat_node.inputs,
            attributes=[ir.AttrInt64("axis", new_concat_axis)],
            name=extra_info.get_unique_name_with_suffix(self.concat_node.name, "/simplified"),
        )
        new_concat_node.outputs[0].name = extra_info.get_unique_name_with_suffix(
            self.concat_node.outputs[0].name, "/simplified"
        )
        self.mark_value_as_copy(self.transpose_node.outputs[0], new_concat_node.outputs[0])
        self.graph.insert_after(self.transpose_node, new_concat_node)
        safe_replace_all_uses_with(
            self.graph, self.transpose_node.outputs[0], new_concat_node.outputs[0]
        )

        logger.debug(
            "applied pass %s on '%s'",
            self.get_curr_pass_name(), self.concat_node.name
        )
        return True
