# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides a graph-rewriter pass that removes dead weights (unreferenced initializers) from the graph.
"""
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .base.rewriter import BaseGraphRewriter


class DeadWeightRemovalRewriter(BaseGraphRewriter):
    """
    A graph rewriter pass that removes dead weight from the graph.

    An initializer counts as dead when its name is not used as an input by
    any node obtained by iterating ``self.graph``.
    """

    def apply(self):
        """
        Removes dead weight from the graph.

        Collects the names of every non-None node input in the graph. It then
        deletes each entry of ``self.graph.initializers`` whose key is not
        among those names.

        Returns:
            int: The number of weights removed.
        """
        # Names of all values consumed by at least one node. ``None`` inputs
        # are skipped (presumably placeholders for omitted optional inputs).
        used_names = set()
        for node in self.graph:
            used_names.update(
                value.name for value in node.inputs if value is not None
            )

        # Compute the dead set up front so the initializer mapping is not
        # mutated while it is being iterated.
        dead_names = set(self.graph.initializers) - used_names

        # NOTE(review): only node inputs count as uses here. An initializer
        # referenced solely as a graph output, or only from a nested subgraph,
        # would also be removed. Confirm this is intended.
        for name in dead_names:
            del self.graph.initializers[name]

        return len(dead_names)
