# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides the entry pass for layout optimization, which runs the layout sub-passes in a fixed order.
"""

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .base.rewriter import BaseGraphRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.cleaning.merge_reshape import MergeSequenceReshapeOps
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.cleaning.merge_transposes import MergeSequenceTransposeOps
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.cleaning.remove_dead_nodes import DeadCodeRemovalRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.cleaning.remove_unused_weights import DeadWeightRemovalRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.opt.layout_opt.layout_binelewise_rewriter import LayoutBinelewiseRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.opt.layout_opt.layout_concat_after_qkvmatmuls_rewriter import (
    LayoutConcatAfterQKVMatmulsRewriter,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.opt.layout_opt.protect_layout_sensitive_ops import (
    ProtectLayoutSensitiveOps,
    UnProtectLayoutSensitiveOps,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.opt.layout_opt.simplify_concat_transpose import SimplifyConcatTransposeRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.opt.layout_opt.simplify_reshape_transpose_seq import (
    SimplifyReshapeTransposeSeqRewriter,
)


class LayoutOptRewriter(BaseGraphRewriter):
    """
    Entry pass for layout optimization.

    Runs a fixed, ordered sequence of layout-related sub-passes over
    ``self.graph`` and reports the combined rewrite count.
    """

    def apply(self):
        """
        Run every layout-optimization sub-pass on ``self.graph`` in order.

        Each sub-pass is constructed with ``self.graph`` immediately before
        its ``apply()`` is called, so every pass operates on the graph as
        left by the pass that ran before it. The order below is significant
        and must not be changed casually.

        Returns:
            The sum of the values returned by each sub-pass's ``apply()``.
        """
        pipeline = (
            # Layout-sensitive ops are protected first and unprotected
            # again further down, after the layout rewrites have run.
            ProtectLayoutSensitiveOps,
            DeadCodeRemovalRewriter,
            # This reshape/transpose simplification has to run before
            # LayoutBinelewiseRewriter; otherwise binary elementwise ops
            # get floated upward when there is no need to.
            SimplifyReshapeTransposeSeqRewriter,
            DeadCodeRemovalRewriter,
            LayoutBinelewiseRewriter,
            DeadCodeRemovalRewriter,
            LayoutConcatAfterQKVMatmulsRewriter,
            DeadCodeRemovalRewriter,
            SimplifyReshapeTransposeSeqRewriter,
            DeadCodeRemovalRewriter,
            SimplifyConcatTransposeRewriter,
            DeadCodeRemovalRewriter,
            UnProtectLayoutSensitiveOps,
            DeadCodeRemovalRewriter,
            # Final merge passes followed by dead node / dead weight cleanup.
            MergeSequenceTransposeOps,
            MergeSequenceReshapeOps,
            DeadCodeRemovalRewriter,
            DeadWeightRemovalRewriter,
        )

        rewrites = 0
        for rewriter_cls in pipeline:
            # Instantiate lazily so each pass is built against the current graph.
            rewrites += rewriter_cls(self.graph).apply()
        return rewrites
