# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides the pass to protect/unprotect the input/output names of the graph
"""
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .base.rewriter import BaseGraphRewriter


class ProtectIO(BaseGraphRewriter):
    """
    This pass protects the graph outputs from being removed/renamed by other passes.

    Usage is two-phased: call ``protect()`` before running other passes and
    ``unprotect()`` afterwards. ``apply()`` is intentionally not supported.
    """

    def __init__(self, graph):
        super().__init__(graph)
        # Original output names, in graph-output order, saved by protect()
        # and consumed by unprotect().
        self.origin_output_names = []

    def apply(self):
        """
        Not supported for this pass; call ``protect()`` / ``unprotect()`` instead.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "ProtectIO does not support apply(); call protect()/unprotect() instead")

    def protect(self):
        """
        Rename all outputs of the graph to protect them.

        Each output gets a unique name derived from its current name with the
        ".protect" suffix. The original names are remembered in
        ``self.origin_output_names`` (reset on every call) and each rename is
        recorded in ``graph.meta["extra_info"]`` via ``record_copy``.
        """
        self.origin_output_names = []
        for v in self.graph.outputs:
            original_name = v.name
            new_name = self.graph.meta["extra_info"].get_unique_name_with_suffix(
                original_name, ".protect")
            self.origin_output_names.append(original_name)
            self.graph.meta["extra_info"].record_copy(
                original_name, new_name, self.get_curr_pass_name())
            v.name = new_name

    def unprotect(self):
        """
        Rename all outputs of the graph to their original names.

        Outputs are matched to the saved names by position, so the number and
        order of graph outputs must be the same as when ``protect()`` was called.

        Raises:
            AssertionError: if the number of graph outputs differs from the
                number of names saved by ``protect()``.
        """
        outputs = self.graph.outputs
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if len(outputs) != len(self.origin_output_names):
            raise AssertionError(
                "Number of graph outputs ({}) does not match the number of "
                "protected output names ({})".format(
                    len(outputs), len(self.origin_output_names)))
        for v, original_name in zip(outputs, self.origin_output_names):
            # Capture the protected name before overwriting it; otherwise the
            # recorded copy would map the original name onto itself.
            protected_name = v.name
            v.name = original_name
            # NOTE(review): argument order (old name, new name, pass name) is
            # assumed to mirror the record_copy call in protect() - confirm.
            self.graph.meta["extra_info"].record_copy(
                protected_name, original_name, self.get_curr_pass_name())
