# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides the pass that removes all dead nodes from the graph.
"""
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .base.rewriter import BaseGraphRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.utils import is_used, logger


class DeadCodeRemovalRewriter(BaseGraphRewriter):
    """
    A graph rewriter pass that removes dead code from the graph.

    A node is considered dead when none of its outputs is used, as decided
    by ``is_used``. A node with no outputs at all therefore also counts as
    dead. Removing a dead consumer can make its producers dead in turn, so
    the pass is intended to strip every node that does not contribute to
    the graph's outputs. Whether a graph output counts as a "use" is
    determined by ``is_used`` (NOTE(review): confirm it does, otherwise
    nodes feeding graph outputs would be removed).
    """

    def apply(self):
        """
        Removes dead nodes from the graph.

        Nodes are visited in reverse order. On a topologically sorted graph
        this lets a dead consumer be removed before its producers are
        inspected, so a whole dead chain disappears in one sweep (assuming
        ``remove(..., safe=True)`` detaches the node from its inputs, so
        that ``is_used`` no longer reports them as used — TODO confirm).
        Sweeps are repeated until one removes nothing. This keeps the
        result complete even if the graph is not topologically sorted. On
        a sorted graph the extra sweep removes nothing.

        Returns:
            int: the number of nodes removed.
        """
        removed_count = 0
        while True:
            removed_this_sweep = 0
            # Iterate over a snapshot so that removing nodes cannot disturb
            # the traversal, whatever container type ``self.graph`` is.
            for node in reversed(list(self.graph)):
                if any(is_used(v) for v in node.outputs):
                    continue
                self.graph.remove(node, safe=True)

                logger.debug("Removed dead node '%s'", node.name)
                removed_this_sweep += 1
            # Each productive sweep removes at least one node from a finite
            # graph, so this loop always terminates.
            if not removed_this_sweep:
                break
            removed_count += removed_this_sweep
        return removed_count
