
# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides a pass that removes useless Concat nodes from the graph.
A Concat node is useless when it has exactly one input.
"""


from onnxscript import ir

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .base.rewriter import BasePredicateRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.utils import logger, safe_replace_all_uses_with


class NullConcatRemovalRewriter(BasePredicateRewriter):
    """
    A graph rewriter pass that removes useless concat from the graph.

    A ``Concat`` node with exactly one input produces that input unchanged,
    so every use of its output can be redirected to the input.
    """

    def __init__(self, graph: ir.Graph):
        super().__init__(graph)
        # NOTE(review): neither attribute is read by this class; they are kept
        # in case other code accesses them.
        self.cst_one_id = None
        self.input_id = None

    def match(self, node: ir.Node) -> bool:
        """
        Return True if ``node`` is a ``Concat`` with exactly one present input.

        A sole input slot that is absent (``None``) is rejected here, so such a
        node is skipped instead of failing the assertions in ``rewrite``.
        """
        return (
            node.op_type == "Concat"
            and len(node.inputs) == 1
            and node.inputs[0] is not None
        )

    def rewrite(self, node: ir.Node) -> bool:
        """
        Bypass a single-input ``Concat`` node matched by ``match``.

        All uses of the node's output are redirected to its sole input via
        ``safe_replace_all_uses_with``, the ``extra_info`` metadata of the two
        values is merged, and a debug message is logged.

        Returns:
            Always True (the rewrite is applied unconditionally once matched).
        """
        concat_in = node.inputs[0]
        concat_out = node.outputs[0]
        assert concat_in is not None  # check for mypy
        assert concat_out is not None  # check for mypy

        safe_replace_all_uses_with(self.graph, concat_out, concat_in)

        # NOTE(review): this merges into the *output* value, which has just
        # been replaced by ``concat_in``. If ``merge`` mutates its receiver,
        # the merged info may be lost; confirm whether the intended direction
        # is ``concat_in.meta["extra_info"].merge(concat_out...)``.
        concat_out.meta["extra_info"].merge(concat_in.meta["extra_info"])

        logger.debug("applied pass %s on '%s'",
                     self.get_curr_pass_name(), node.name)

        return True
