# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides the pass that merges a sequence of consecutive Reshape ops
into a single Reshape op.
"""


from onnxscript import ir

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .base.rewriter import BasePredicateRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.ir_extra_info import VariableExtraInfo
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.utils import (
    get_value_numeric_shape,
    have_static_shape_on_node_io,
    logger,
    safe_replace_all_uses_with,
)


class MergeSequenceReshapeOps(BasePredicateRewriter):
    '''
    Transform subgraph:
        Subgraph(in_a) --> c
        {
            # allow in_a/in_a1/in_a2... have multiple users

            in_a1 = Reshape(in_a)
            in_a2 = Reshape(in_a1)
            ...
            c = Reshape(in_aX)
        }
    Into:
        Subgraph(in_a) --> c
        {
            c = Reshape(in_a)
        }

    If the shape of ``in_a`` already equals the shape of ``c``, the final
    Reshape is bypassed instead: users of ``c`` (other than the Reshape nodes
    of the chain) are redirected to ``in_a``.

    The intermediate Reshape nodes are not deleted by this pass; only the
    final Reshape (or its users) is rewired.

    A chain is not extended past a Reshape whose output encodings differ from
    the encodings collected so far. The final Reshape is not merged when its
    target shape is known to contain a ``0`` entry that copies a dimension
    from its data input (ONNX ``allowzero=0``), since rewiring its data input
    would change what that entry resolves to.
    '''
    def __init__(self, graph):
        super().__init__(graph)
        # Reshape nodes of the current match, ordered producer -> consumer:
        # index 0 reads in_a, index -1 produces c.
        self.reshape_seq = []
        # Extra info (encodings) to merge onto the final Reshape's output.
        self.v_extra_info = VariableExtraInfo()

    @staticmethod
    def _may_copy_input_dims(reshape: ir.Node) -> bool:
        """Return True if ``reshape``'s target shape is known to contain a 0
        entry that copies a dimension from its data input.

        Per the ONNX Reshape spec, with ``allowzero`` unset or 0, a 0 in the
        shape tensor means "copy this dim from the input". With
        ``allowzero=1`` it is a literal zero-sized dim and does not depend on
        the input.

        Returns False when the shape tensor cannot be read statically (neither
        ``const_value`` nor a ``Constant`` producer's ``value`` attribute). In
        that case callers keep the previous, unchecked behaviour.
        """
        allowzero = reshape.attributes.get("allowzero")
        if allowzero is not None and getattr(allowzero, "value", 0) == 1:
            return False
        if len(reshape.inputs) < 2 or reshape.inputs[1] is None:
            return False
        shape_value = reshape.inputs[1]
        shape_tensor = shape_value.const_value
        if shape_tensor is None:
            # Shape may come from a Constant node that was not folded into
            # an initializer.
            producer = shape_value.producer()
            if producer is not None and producer.op_type == "Constant":
                shape_tensor = getattr(
                    producer.attributes.get("value"), "value", None
                )
        if shape_tensor is None:
            return False
        return 0 in shape_tensor.numpy().reshape(-1).tolist()

    def match(self, node: ir.Node) -> bool:
        """Collect the chain of consecutive Reshape ops that ends at ``node``.

        Starting at ``node``, follow the producer of each Reshape's data
        input upward while that producer is also a Reshape. Stop at a value
        without a producer, at a non-Reshape producer, or at a Reshape whose
        output encodings conflict with the ones collected so far.

        Side effects: stores the chain (producer -> consumer order) in
        ``self.reshape_seq`` and the encodings to keep in
        ``self.v_extra_info``. Both are read by ``rewrite``.

        Returns:
            True if at least two Reshape ops were collected.
        """
        if node.op_type != "Reshape":
            return False
        if not have_static_shape_on_node_io(node):
            return False
        if self._may_copy_input_dims(node):
            # Rewiring ``node`` onto in_a would change what its 0 entries
            # resolve to, so merging is unsafe.
            return False

        chain = [node]
        self.v_extra_info = node.outputs[0].meta["extra_info"]
        curr_node = node
        # find bottom-up
        while True:
            curr_node_input = curr_node.inputs[0]

            # check for mypy, definitely true, since curr_node can only be Reshape
            assert curr_node_input is not None

            producer = curr_node_input.producer()
            if producer is None or producer.op_type != "Reshape":
                break

            producer_info = producer.outputs[0].meta["extra_info"]
            if producer_info.defined_encodings():
                if self.v_extra_info.defined_encodings():
                    # do not merge reshape if encodings are not equal
                    if self.v_extra_info != producer_info:
                        break
                else:
                    # adopt the nearest defined encodings along the chain
                    self.v_extra_info = producer_info

            chain.append(producer)
            curr_node = producer

        # chain was built consumer -> producer; store it producer -> consumer
        chain.reverse()
        self.reshape_seq = chain
        return len(chain) > 1

    def rewrite(self, node: ir.Node) -> bool:
        """Collapse the Reshape chain found by ``match``.

        If the chain's input and output numeric shapes are equal, users of
        the chain output (except the chain's own Reshapes) are redirected to
        the chain input. Otherwise the final Reshape's data input is replaced
        with the chain input. In both cases ``self.v_extra_info`` is merged
        into the final Reshape output's extra info.

        Returns:
            Always True.
        """
        first_reshape = self.reshape_seq[0]
        last_reshape = self.reshape_seq[-1]
        chain_input = first_reshape.inputs[0]
        chain_output = last_reshape.outputs[0]

        input_shape = get_value_numeric_shape(chain_input)
        output_shape = get_value_numeric_shape(chain_output)
        if tuple(input_shape) == tuple(output_shape):
            # reshape is useless, remove it
            safe_replace_all_uses_with(
                self.graph,
                chain_output,
                chain_input,
                except_users=self.reshape_seq,
            )
        else:
            last_reshape.replace_input_with(0, chain_input)

        # NOTE(review): in the identity branch above, users now read
        # ``chain_input``, yet the encodings are merged onto ``chain_output``.
        # Confirm whether they should also be carried onto ``chain_input``.
        chain_output.meta["extra_info"].merge(self.v_extra_info)
        logger.debug("applied pass %s on '%s'",
                     self.get_curr_pass_name(), node.name)

        return True
