# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
"""
This module provides M2sInsertMHASliceAfterQKVMatmul pass for mha2sha ir modification
"""

from typing import Dict, Optional

from onnxscript import ir

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.base_rewriter import M2sBaseRewriter
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .passes.mha2sha.utils import GroupSliceAttrs
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha\
    .utils.utils import (
    ConditionOnValueProducer,
    check_static_shape,
    get_value_numeric_shape,
    logger,
    scan_previous_nearest_candidate,
)


class M2sInsertMHASliceAfterQKVMatmul(M2sBaseRewriter):
    """
    Pass to insert (GroupSlice -> Concat) after the QKVMatmul of every attention block

    A node is treated as a QKV matmul when it is a ``MatMul`` that has a
    ``Softmax`` upstream of one of its inputs, with only ``Reshape`` /
    ``Transpose`` nodes allowed in between, and whose output has rank 3 or 4.
    ``match`` records the matched nodes on the instance; ``rewrite`` then
    consumes that state, so ``rewrite`` must only be called after a successful
    ``match`` on the same node.
    """
    def __init__(self,
                 graph: ir.Graph,
                 m2s_head_split_map: Dict[int, int] | None = None,
                 out_batch_size: int | None = None):
        """
        Args:
            graph: graph handed to the base rewriter.
            m2s_head_split_map: maps a head count to the number of heads kept in
                each output slice. The key ``-1`` acts as the fallback entry for
                head counts that are not listed. When no entry applies, every
                slice holds a single head. ``None`` is treated as an empty map.
            out_batch_size: stored on the instance; not used by this pass yet
                (NOTE(review): presumably reserved for future batch splitting).
        """
        super().__init__(graph)
        self.out_batch_size = out_batch_size
        if m2s_head_split_map is None:
            m2s_head_split_map = {}
        self.m2s_head_split_map = m2s_head_split_map
        # State handed over from match() to rewrite().
        self.qkv_matmul: Optional[ir.Node] = None
        self.softmax: Optional[ir.Node] = None

    def match(self, node: ir.Node) -> bool:
        """
        Decide whether ``node`` is a QKV matmul this pass should rewrite.

        On success the matched MatMul and its Softmax are stored in
        ``self.qkv_matmul`` and ``self.softmax``.

        Returns:
            True when ``node`` is a MatMul with an upstream Softmax (Reshape and
            Transpose are skipped while searching) and a rank-3 or rank-4
            output; False otherwise.
        """
        if node.op_type != 'MatMul':
            return False

        # find softmax--(Reshape/Transpose)--> qkv_matmul
        # Transpose/Reshape are allowed between softmax and qkv_matmul

        # search bottom-up, starting from all inputs of the MatMul
        qkv_matmul = node
        softmax = None
        for candidate_v in scan_previous_nearest_candidate(
            start_values=list(node.inputs),
            check_fn=ConditionOnValueProducer(["Softmax"]),
            ignore_fn=ConditionOnValueProducer(["Reshape", "Transpose"])
        ):
            # If several candidates are yielded, the last one wins.
            softmax = candidate_v.producer()

        if softmax is None:
            return False

        # check qkv_matmul output shape
        check_static_shape(qkv_matmul.outputs[0])
        assert qkv_matmul.outputs[0].shape is not None
        if qkv_matmul.outputs[0].shape.rank() not in (3, 4):
            return False

        self.qkv_matmul = qkv_matmul
        self.softmax = softmax
        return True

    def rewrite(self, node: ir.Node) -> bool:
        """
        Insert GroupSlice -> Concat on the output of the matched QKV matmul.

        The head axis is 1 for a rank-4 output and 0 for a rank-3 output. The
        head dimension is cut into consecutive ranges of ``out_head_size``
        heads (the last range may be shorter), where ``out_head_size`` is
        looked up in ``m2s_head_split_map`` by head count, then by ``-1``, and
        defaults to 1.

        Args:
            node: the node being rewritten; the method works on the state saved
                by ``match`` rather than on this argument.

        Returns:
            True once the slice/concat nodes have been inserted.

        Raises:
            AssertionError: if the matmul output rank is neither 3 nor 4.
            ValueError: if the configured head group size is not positive.
        """
        assert self.qkv_matmul is not None  # check for mypy, definitely true

        # get output shape
        value = self.qkv_matmul.outputs[0]
        qkv_matmul_out_shape = get_value_numeric_shape(value)
        out_rank = len(qkv_matmul_out_shape)

        # NOTE: for batch splitting in the future, rank 4 would use batch axis 0
        # and rank 3 has no batch axis.
        if out_rank == 4:
            head_axis = 1
        elif out_rank == 3:
            head_axis = 0
        else:
            # Raised explicitly (not `assert False`) so the guard survives
            # `python -O` instead of leaving head_axis unbound.
            raise AssertionError(
                f"unsupported QKV matmul output rank {out_rank}; expected 3 or 4")
        head_num = qkv_matmul_out_shape[head_axis]

        # split on head
        head_gslice_attrs = GroupSliceAttrs(axis=head_axis)

        if head_num in self.m2s_head_split_map:
            out_head_size = self.m2s_head_split_map[head_num]
        elif -1 in self.m2s_head_split_map:
            out_head_size = self.m2s_head_split_map[-1]
        else:
            out_head_size = 1

        # A zero step makes range() fail with an unrelated message and a
        # negative step silently yields no slices; reject both up front,
        # before the graph is modified.
        if out_head_size <= 0:
            raise ValueError(
                f"invalid head split size {out_head_size} for head_num={head_num}; "
                "m2s_head_split_map values must be positive")

        for i, start_i in enumerate(range(0, head_num, out_head_size)):
            # Clamp so the final slice does not run past the head dimension.
            end_i = min(start_i + out_head_size, head_num)
            head_gslice_attrs.starts.append(start_i)
            head_gslice_attrs.ends.append(end_i)
            head_gslice_attrs.head_slice_ids.append(i)
            # NOTE(review): -1 presumably marks "no batch split" - confirm
            # against GroupSliceAttrs.
            head_gslice_attrs.batch_slice_ids.append(-1)

        assert self.qkv_matmul.name is not None  # check for mypy
        # insert group slice and concat
        _, _ = self.gslice_then_concat(
            value, head_gslice_attrs, self.qkv_matmul.name + "/head_gslice_node")

        assert self.softmax is not None  # check for mypy
        logger.debug(
            "found attention softmax '%s', added GroupSlice->Concat to it",
            self.softmax.name
        )
        return True
