# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides a pass that merges a sequence of consecutive Transpose ops
into a single Transpose op.
"""


from onnxscript import ir

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .base.rewriter import BasePredicateRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.ir_extra_info import VariableExtraInfo
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.utils import (
    get_value_numeric_shape,
    have_static_shape_on_node_io,
    logger,
    safe_replace_all_uses_with,
)


class MergeSequenceTransposeOps(BasePredicateRewriter):
    '''
    Merge a chain of consecutive Transpose ops into a single Transpose op.

    Transform subgraph:
        Subgraph(in_a) --> c
        {
            in_a1 = Transpose(in_a)
            in_a2 = Transpose(in_a1)
            ...
            c = Transpose(in_aX)
        }
    Into:
        Subgraph(in_a) --> c
        {
            c = Transpose(in_a)
        }

    When the composed permutation is the identity, no new node is created and
    every use of ``c`` is redirected straight to ``in_a``.

    The chain only grows through a Transpose whose output has exactly one
    consumer, and it stops where the encodings recorded in the outputs'
    ``extra_info`` meta conflict.
    '''
    def __init__(self, graph):
        super().__init__(graph)
        # Transpose nodes found by the last successful match(), ordered from
        # the chain head (first) to the final Transpose (last).
        self.transpose_seq = []
        # extra_info (encodings) that the merged output will carry.
        self.v_extra_info = VariableExtraInfo()

    @staticmethod
    def _get_perm(transpose_node: ir.Node, rank: int) -> list:
        """Return the ``perm`` of a Transpose node as a list of ints.

        The ONNX spec makes ``perm`` optional and defines its absence as
        reversing the axes, so that default is materialized here instead of
        raising ``KeyError``.
        """
        perm_attr = transpose_node.attributes.get("perm")
        if perm_attr is None:
            return list(range(rank))[::-1]
        return list(perm_attr.as_ints())

    def match(self, node: ir.Node) -> bool:
        """Collect the chain of single-consumer Transposes starting at ``node``.

        Stores the chain in ``self.transpose_seq`` and the encodings to carry
        over in ``self.v_extra_info``, for use by :meth:`rewrite`.

        Returns:
            True if the chain contains at least two Transpose nodes.
        """
        if node.op_type != "Transpose":
            return False
        if not have_static_shape_on_node_io(node):
            return False
        self.transpose_seq = [node]
        self.v_extra_info = node.outputs[0].meta["extra_info"]
        # Walk downstream while the current output feeds exactly one Transpose.
        curr_node = node
        while True:
            uses = list(curr_node.outputs[0].uses())
            if len(uses) != 1:
                break
            next_node = uses[0].node
            if next_node.op_type != "Transpose":
                break

            next_info = next_node.outputs[0].meta["extra_info"]
            if next_info.defined_encodings():
                if not self.v_extra_info.defined_encodings():
                    # Adopt the first encodings found along the chain.
                    self.v_extra_info = next_info
                elif self.v_extra_info != next_info:
                    # Different encodings cannot share one merged output.
                    break

            self.transpose_seq.append(next_node)
            curr_node = next_node

        return len(self.transpose_seq) > 1

    def rewrite(self, node: ir.Node) -> bool:
        """Replace the matched Transpose chain with one equivalent Transpose.

        Relies on ``self.transpose_seq`` and ``self.v_extra_info`` populated
        by :meth:`match`. Only the uses of the chain's final output are
        rewired; the original Transpose nodes are not removed here
        (presumably a later dead-code pass drops them - TODO confirm).

        Returns:
            Always True.
        """
        first_node = self.transpose_seq[0]
        last_node = self.transpose_seq[-1]
        last_output = last_node.outputs[0]
        rank = len(get_value_numeric_shape(node.inputs[0]))

        # Transpose semantics: out.shape[i] == in.shape[perm[i]], so applying
        # p1 and then p2 equals one transpose with perm[i] = p1[p2[i]].
        # Folding from the last node back to the first builds that composition.
        merged_perm = list(range(rank))
        for transpose_node in reversed(self.transpose_seq):
            curr_perm = self._get_perm(transpose_node, rank)
            merged_perm = [curr_perm[axis] for axis in merged_perm]

        if merged_perm == list(range(rank)):
            # The chain is a no-op: bypass it entirely.
            safe_replace_all_uses_with(self.graph, last_output,
                                       first_node.inputs[0])
        else:
            graph_info = self.graph.meta["extra_info"]
            node_name = graph_info.get_unique_name_with_suffix(
                node.name, ".merged")
            output_name = graph_info.get_unique_name_with_suffix(
                last_output.name, ".merged")
            new_node = ir.Node("", "Transpose", [node.inputs[0]], name=node_name)
            new_output = new_node.outputs[0]
            new_output.name = output_name
            # The merged output replaces the chain's final output, so it takes
            # that output's shape/dtype and the encodings chosen in match().
            new_output.shape = last_output.shape
            new_output.dtype = last_output.dtype
            new_node.attributes["perm"] = ir.AttrInt64s("perm", merged_perm)
            new_output.meta["extra_info"] = self.v_extra_info.copy(
                ignore_safetensors=True)
            self.graph.insert_before(last_node, new_node)
            graph_info.record_sharing_encodings(last_output.name,
                                                new_output.name,
                                                self.get_curr_pass_name())
            safe_replace_all_uses_with(self.graph, last_output, new_output)

        logger.debug("applied pass %s on '%s'",
                     self.get_curr_pass_name(), node.name)

        return True
