# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides the entry pass for mha2sha optimization
"""
from typing import Dict, List

from onnxscript import ir

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .base.rewriter import BaseGraphRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.cleaning.remove_dead_nodes import (
    DeadCodeRemovalRewriter,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.cleaning.remove_null_concat import (
    NullConcatRemovalRewriter,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.cleaning.remove_null_mul import (
    NullMulRemovalRewriter,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.cleaning.remove_unused_weights import (
    DeadWeightRemovalRewriter,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.fold_init_gslice import (
    M2sFoldInitGroupSlice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.insert_mhaslices_after_qkv_matmul import (
    M2sInsertMHASliceAfterQKVMatmul,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.reorder_binelewise_gslice import (
    M2sReorderBinElewiseGroupslice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.reorder_clip_gslice import (
    M2sReorderClipGroupslice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.reorder_concat_gslice import (
    M2sReorderConcatGroupslice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.reorder_conv_gslice import (
    M2sReorderConvGroupslice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.reorder_expand_gslice import (
    M2sReorderExpandGroupslice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.reorder_instancenorm_gslice import (
    M2sReorderInstancenormGroupslice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.reorder_layernorm_gslice import (
    M2sReorderLayernormGroupslice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.reorder_fasthadamardtransform_gslice import (
    M2sReorderFastHadamardTransformGroupslice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.reorder_matmul_gslice import (
    M2sReorderMatmulGroupslice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.reorder_reduce_gslice import (
    M2sReorderUnaryReduceGroupslice_Passes,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.reorder_reshape_gslice import (
    M2sReorderReshapeGroupSlice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.reorder_scatterelements_gslice import (
    M2sReorderScatterElementsGroupslice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.reorder_slice_gslice import (
    M2sReorderSliceGroupslice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.reorder_softmax_gslice import (
    M2sReorderSoftmaxGroupslice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.reorder_spacetodepth_gslice import (
    M2sReorderSpaceToDepthGroupslice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.reorder_transpose_gslice import (
    M2sReorderTransposeGroupslice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.reorder_unary_gslice import (
    M2sReorderUnaryGroupslice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.reorder_where_gslice import (
    M2sReorderWhereGroupslice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.replace_gslice import (
    M2sReplaceGroupSlice,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.replace_split_gslice_with_slice_gslice import (
    M2sReplaceSplitGroupslice2SliceGroupslice,
)


class MHA2SHARewriter(BaseGraphRewriter):
    """
    The entry pass for mha2sha optimization.

    It runs three groups of rewriters, in order, against the same graph:

    - pre_stage:
        - Insert GroupSlice->Concat after QKV_Matmul of every attention block
    - proc_stage:
        - Reorder every possible (X->GroupSlice) pattern to (GroupSlice->X).
          The full list of proc rewriters is swept repeatedly until one sweep
          performs no rewrite, or ``max_proc_iterations`` sweeps have run.
    - post_stage:
        - clean the graph
    """

    # Upper bound on the number of full proc-stage sweeps before the pass is
    # considered stuck. Class attribute so it can be overridden per subclass
    # or per instance without changing the constructor signature.
    max_proc_iterations: int = 10000

    def __init__(self, graph: ir.Graph, m2s_head_split_map: Dict[int, int] | None = None):
        """
        Build the stage lists for ``graph``.

        Args:
            graph: The graph every stage rewriter operates on.
            m2s_head_split_map: Optional mapping forwarded unchanged to
                ``M2sInsertMHASliceAfterQKVMatmul``; also kept on ``self``.
        """
        super().__init__(graph)
        self.m2s_head_split_map = m2s_head_split_map
        self.pre_stages: List[BaseGraphRewriter] = [
            M2sInsertMHASliceAfterQKVMatmul(self.graph, m2s_head_split_map)
        ]
        self.proc_stages: List[BaseGraphRewriter] = [
            M2sReorderMatmulGroupslice(self.graph),
            M2sReorderSoftmaxGroupslice(self.graph),
            M2sReorderLayernormGroupslice(self.graph),
            M2sReorderInstancenormGroupslice(self.graph),
            # M2sReorderUnaryReduceGroupslice_Passes is a collection of pass
            # classes; instantiate each one against this graph.
            *[pass_class(self.graph)
              for pass_class in M2sReorderUnaryReduceGroupslice_Passes],
            M2sReorderBinElewiseGroupslice(self.graph),
            M2sReorderConcatGroupslice(self.graph),
            M2sReorderSliceGroupslice(self.graph),
            M2sReorderTransposeGroupslice(self.graph),
            M2sReorderReshapeGroupSlice(self.graph),
            M2sReorderConvGroupslice(self.graph),
            M2sReorderUnaryGroupslice(self.graph),
            M2sReplaceSplitGroupslice2SliceGroupslice(self.graph),
            M2sReorderExpandGroupslice(self.graph),
            M2sReorderWhereGroupslice(self.graph),
            M2sReorderScatterElementsGroupslice(self.graph),
            M2sReorderSpaceToDepthGroupslice(self.graph),
            M2sReorderClipGroupslice(self.graph),
            M2sReorderFastHadamardTransformGroupslice(self.graph),
        ]
        # Dead-code removal runs both before and after the GroupSlice
        # folding/replacement and null-op removal passes.
        self.post_stages: List[BaseGraphRewriter] = [
            DeadCodeRemovalRewriter(self.graph),
            M2sFoldInitGroupSlice(self.graph),
            M2sReplaceGroupSlice(self.graph),
            NullMulRemovalRewriter(self.graph),
            NullConcatRemovalRewriter(self.graph),
            DeadCodeRemovalRewriter(self.graph),
            DeadWeightRemovalRewriter(self.graph),
        ]

    def apply(self) -> int:
        """
        Run the pre, proc and post stages in order.

        Returns:
            The sum of the rewrite counts returned by every ``apply()`` call
            made on the stage rewriters.

        Raises:
            AssertionError: if the proc stage still performed rewrites on its
                last allowed sweep (``max_proc_iterations``), i.e. it did not
                reach a sweep with zero rewrites.
        """
        total_rewrite_count = 0
        for rewriter in self.pre_stages:
            total_rewrite_count += rewriter.apply()

        # Sweep all proc rewriters repeatedly until a sweep makes no rewrite.
        # Start at 1 so the loop body runs at least once.
        curr_rewrite_count = 1
        loop_count = 0
        while curr_rewrite_count > 0 and loop_count < self.max_proc_iterations:
            curr_rewrite_count = 0
            for proc_rewriter in self.proc_stages:
                curr_rewrite_count += proc_rewriter.apply()
            loop_count += 1
            total_rewrite_count += curr_rewrite_count

        # Fail only if the final sweep still rewrote something. Converging on
        # exactly the last allowed sweep is a success. Raise explicitly (not
        # `assert`) so the check is not stripped under `python -O`.
        if curr_rewrite_count > 0:
            raise AssertionError("dead loop, something wrong")

        for post_rewriter in self.post_stages:
            total_rewrite_count += post_rewriter.apply()

        return total_rewrite_count
