# ==============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All Rights Reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================

import copy
import math
from collections import deque

import onnx
import onnx_graphsurgeon as gs
import rich
from numpy import append
from transformers.models.rag.retrieval_rag import Index

from qti.aisw.tools.core.utilities.framework.frameworks.onnx.onnx_model_helper import OnnxModelHelper
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.graph_manager import GraphManager
from qti.aisw.tools.core.utilities.qairt_logging import LogAreas, QAIRTLogger

from .pretty_print import (
    PrettyPrintConstants,
    bold_text,
    create_rich_table,
)
from .utils import validate_splits


def split_onnx(
    model: onnx.ModelProto,
    *,
    num_splits: int,
    split_embedding: bool = False,
    split_lm_head: bool = False,
    skip_verification: bool = True,
    log_level: str = "info",
) -> list[onnx.ModelProto]:
    """
    Split the given ONNX model into multiple sub-models.

    Candidate split points are the residual ``Add`` nodes detected in the graph (for the
    topology shown below, one per layer). Optionally, the embedding lookup and the language
    model head can each be isolated in a sub-model of their own. The input ``ModelProto`` is
    not modified; the sub-models are built from a graphsurgeon import of a shape-inferred copy.

    Args:
        model (onnx.ModelProto): The ONNX model to be split.
    Keyword-only Args:
        num_splits (int): Total number of sub-models to produce, including the dedicated
            embedding / lm_head sub-models when those are requested.
        split_embedding (bool): If True, the embedding gets its own (first) sub-model. Default is False.
        split_lm_head (bool): If True, the language model head gets its own (last) sub-model. Default is False.
        skip_verification (bool): If True, skips ONNX Runtime verification. Default is True.
        log_level (str): The logging level to be used during the splitting. Default is "info".

    Returns:
        list[onnx.ModelProto]: The exported sub-models, in execution order.

    Raises:
        ValueError: If the model has too few layers for ``num_splits``; if ``num_splits`` is too
            small to hold the requested embedding / lm_head splits; if ``split_embedding`` is set
            but no embedding node can be located; or if the computed split points are not in
            increasing graph order.

    Model Topology:
               │ ←─────────  layers[0]  ────────────→ │       │ ←─────────  layers[-1]  ─────────────────→ │
               │                                      │       │                                            │
    embed  ────┬─────────── add 0 ─┬────────── add 1 ──  ┄┄ ┄─┬─────────────── add(n-2) ─┬──────────── add(n-1) ─── lmhead
             ↑ └─ norm ─ attn ─┘   └─ norm ─ ffn ─┘   ↑       ↑ └─ norm ─ attn ─┘        └─ norm ─ ffn ─┘  ↑
             │                                        │       │                                            │
             │                                        │       │                                            │
            valid splitting points
    """

    splitter_log_area = LogAreas.register_log_area("onnx_splitter")

    logger = QAIRTLogger.register_area_logger(
        splitter_log_area, level=log_level, formatter_val="extended", handler_list=["dev_console"]
    )

    try:
        model = OnnxModelHelper.symbolic_shape_inference(model)
    except (ImportError, onnx.checker.ValidationError):
        # Symbolic inference is unavailable when onnxruntime is missing (ImportError) or when the
        # model's external data has not been loaded (onnx.checker.ValidationError); fall back to
        # plain ONNX shape inference.
        model = OnnxModelHelper.shape_inference(model)

    graph = gs.import_onnx(model)
    # Position of every node in graph.nodes, keyed by node name. graph.nodes is assumed to be
    # topologically sorted (as the ONNX spec requires), so a smaller index means "earlier".
    node_idx = {node.name: i for i, node in enumerate(graph.nodes)}

    def can_visit(src, dst):
        """Return True if ``dst`` is reachable from ``src`` by following producer -> consumer edges."""
        dst_idx = node_idx[dst.name]
        if node_idx[src.name] > dst_idx:
            return False

        queue = deque([src])
        # Track visited nodes by identity so nodes reachable through several paths are expanded once.
        seen = {id(src)}
        while queue:
            curr = queue.popleft()
            if curr == dst:
                return True

            for output in curr.outputs:
                for consumer in output.outputs:
                    # Nodes placed after dst in topological order can never lead back to dst.
                    if id(consumer) in seen or node_idx[consumer.name] > dst_idx:
                        continue
                    if consumer == dst:
                        return True
                    seen.add(id(consumer))
                    queue.append(consumer)

        return False

    def get_residual_add(node):
        """
        Return the ``Add`` on ``node``'s skip branch, or None if ``node`` is not a residual ``Add``.

        ``node`` qualifies when it is an ``Add`` with two produced inputs, one of which comes from
        another ``Add`` from which the other input's producer is reachable.
        """
        if node.op != "Add":
            return None
        try:
            a, b = node.i(), node.i(1)
        except IndexError:
            # Fewer than two inputs, or an input that has no producer node.
            return None
        if a.op == "Add" and can_visit(a, b):
            return a
        if b.op == "Add" and can_visit(b, a):
            return b
        return None

    residual_adds = []
    for node in graph.nodes:
        add0 = get_residual_add(node)
        # Record ``node`` unless its skip input is the add recorded just before it. For the topology
        # in the docstring this keeps one residual add per layer (the one closing the layer).
        if add0 is not None and (not residual_adds or add0 != residual_adds[-1]):
            residual_adds.append(node)

    # lm_head split is a residual add
    n_possible_splits = len(residual_adds) + int(split_embedding) + 1
    if n_possible_splits < num_splits:
        raise ValueError(
            f"Not enough layers in the model to properly split. {n_possible_splits} possible splits, {num_splits} splits requested"
        )

    # The embedding and lm_head sub-models each consume one split; at least one sub-model must
    # remain for the layers (otherwise the interval computation below divides by zero).
    n_dedicated_splits = int(split_embedding) + int(split_lm_head)
    if num_splits < n_dedicated_splits + 1:
        raise ValueError(
            f"num_splits must be at least {n_dedicated_splits + 1} when split_embedding={split_embedding} "
            f"and split_lm_head={split_lm_head}, got {num_splits}"
        )

    embedding = []
    border_nodes = []
    graph_inputs = {inp.name: inp for inp in graph.inputs}
    if split_embedding:
        embedding_idx = None
        input_ids = graph_inputs.get("input_ids")
        if input_ids is not None and input_ids.outputs and input_ids.outputs[0].name in node_idx:
            # The first consumer of ``input_ids`` is taken to be the embedding lookup.
            embedding_idx = node_idx[input_ids.outputs[0].name]
            # NOTE: Hack to keep only embedding in first split. A leading Pad node would otherwise be
            # pulled into the embedding sub-model. Moving the embedding to the front is only
            # topologically safe when none of its inputs is produced by another node.
            if embedding_idx != 0 and graph.nodes[0].op == "Pad":
                embedding_node = graph.nodes[embedding_idx]
                if all(not tensor.inputs for tensor in embedding_node.inputs):
                    graph.nodes.insert(0, graph.nodes.pop(embedding_idx))
                    # Keep the name -> position map consistent with the reordered node list.
                    node_idx.update({node.name: i for i, node in enumerate(graph.nodes)})
                    embedding_idx = 0
        else:
            # No usable ``input_ids``: use the first non-LoRA Gather that reads a model input.
            for node in graph.nodes:
                if (
                    node.op == "Gather"
                    and not any("lora" in inp.name for inp in node.inputs)
                    and any(inp in graph.inputs for inp in node.inputs)
                ):
                    embedding_idx = node_idx[node.name]
                    break

        if embedding_idx is None:
            raise ValueError(
                "`split_embedding` set to True, but no input named input_ids or no model input to a Gather op"
            )
        num_splits -= 1

        embedding.append(embedding_idx)
        border_nodes.append(graph.nodes[embedding_idx])

    # Every residual add (including the one popped for lm_head below) is a border node, so all of
    # its outputs are exported from the sub-model that contains it.
    border_nodes.extend(residual_adds)

    lm_head = []
    if split_lm_head:
        lm_head.append(node_idx[residual_adds.pop().name])
        num_splits -= 1

    # Not counting split_lm_head as it is the last residual add
    interval = len(residual_adds) / num_splits

    split_nodes = []
    for i in range(1, num_splits):
        idx = math.floor(i * interval)
        split_nodes.append(node_idx[residual_adds[idx].name])

    split_nodes = embedding + split_nodes + lm_head
    # Each split point must come strictly after the previous one, otherwise the slices below
    # would produce overlapping or empty sub-graphs.
    if any(later <= earlier for earlier, later in zip(split_nodes, split_nodes[1:])):
        raise ValueError(f"Computed split points are not in increasing graph order: {split_nodes}")

    split_graphs = []

    for i in range(len(split_nodes) + 1):
        # Sub-model i covers graph.nodes[start_idx:end_idx]: everything after the previous split
        # point up to and including the current one (or to the end of the graph for the last one).
        start_idx = 0 if i == 0 else split_nodes[i - 1] + 1
        end_idx = len(graph.nodes) if i == len(split_nodes) else split_nodes[i] + 1
        split_model_nodes = list(graph.nodes[start_idx:end_idx])

        split_inputs, split_outputs = [], []

        for node in split_model_nodes:
            # Graph inputs consumed directly by this sub-model become its inputs.
            for inp in node.inputs:
                if inp in graph.inputs and inp not in split_inputs:
                    split_inputs.append(inp)

            # Graph outputs produced here, and every output of a border node, become its outputs.
            for out in node.outputs:
                if (out in graph.outputs or node in border_nodes) and out not in split_outputs:
                    split_outputs.append(out)

        split = gs.Graph(
            nodes=split_model_nodes,
            inputs=split_inputs,
            outputs=split_outputs,
            name=f"split_{i + 1}",
            import_domains=graph.import_domains,
        )

        # Names of non-constant tensors produced by nodes that belong to earlier sub-models.
        producer_tensors = {
            out.name
            for node in graph.nodes[:start_idx]
            for out in node.outputs
            if not GraphManager.is_constant_tensor(out)
        }

        # Resolve inputs that were computed in an earlier sub-model: border tensors become sub-model
        # inputs; anything else is recomputed by copying its producer chain into this sub-model.
        for node in split.nodes.copy():
            for inp in [_inp for _inp in node.inputs if _inp.name in producer_tensors]:
                try:
                    start_tensor = inp
                    producer_node = start_tensor.inputs[0]

                    if GraphManager.is_constant_tensor(inp):
                        # NOTE(review): producer_tensors excludes tensors for which is_constant_tensor
                        # is True, so this branch looks unreachable -- confirm before relying on it.
                        split.nodes.insert(0, producer_node)

                    elif producer_node in border_nodes:
                        if start_tensor not in split.inputs:
                            split.inputs.append(start_tensor)
                    else:
                        # Walk up through the first non-constant input of each producer, prepending the
                        # producers, until a graph input is reached (or a producer has no such input).
                        while start_tensor not in graph.inputs:
                            producer_node = start_tensor.inputs[0]

                            if producer_node not in split.nodes:
                                split.nodes.insert(0, producer_node)

                            act_parent_idx = 0
                            try:
                                while GraphManager.is_constant_tensor(start_tensor.i(act_parent_idx)):
                                    act_parent_idx += 1
                                start_tensor = start_tensor.i(act_parent_idx)
                            except IndexError:
                                break

                        if start_tensor in graph.inputs and start_tensor not in split.inputs:
                            split.inputs.append(start_tensor)

                except IndexError:
                    # ``inp`` has no producer node; nothing to pull in.
                    continue

        split_graphs.append(split)

    if str(log_level).lower() == "debug":
        from rich.console import Console

        table = create_rich_table(
            title=bold_text("Model Splitting Results:", color=PrettyPrintConstants.Q_BLUE),
            headers=["Split Number", "New Inputs", "New Outputs"],
            positions=[0.18, 0.59, 1.0],
            alignment=["left", "left", "left"],
        )
        for split_number, split in enumerate(split_graphs, start=1):
            table.add_row(
                str(split_number),
                ", ".join([tensor.name for tensor in split.inputs]),
                ", ".join([tensor.name for tensor in split.outputs]),
            )

        console = Console(highlight=True)
        console.print(table, overflow="fold")

    split_models = [gs.export_onnx(split) for split in split_graphs]

    if not skip_verification:
        validate_splits(model, split_models, logger=logger)

    return split_models
