import copy
import math
from collections import deque

import rich
from onnxscript import ir
from onnxscript.optimizer import remove_unused_nodes

from qairt.utils import loggers
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.split.pretty_print import (
    PrettyPrintConstants,
    bold_text,
    create_rich_table,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.split.utils import validate_splits
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha.utils.encodings import (
    EncKind,
)

_logger = loggers.get_logger(name=__name__)


def split(
    model: ir.Model,
    num_splits: int,
    split_embedding: bool = False,
    split_lm_head: bool = False,
    skip_verification: bool = True,
    log_level: str = "info",
) -> list[ir.Model]:
    """
    Splits the given ONNX model into multiple sub-models.

    The sub-models are newly built ``ir.Model`` objects whose nodes and values are
    replicated from ``model``. The input model itself is not partitioned, but it can
    still be mutated by this function: the first two nodes may be re-ordered when
    ``split_embedding`` is set and the graph starts with a ``Pad`` node, and missing
    encodings of the boundary tensors are filled in on the original tensors.

    Args:
        model (ir.Model): The ONNX IR model to be split
        num_splits (int): The number of splits to be made. Values <= 1 disable splitting
        split_embedding (bool): If True, the embedding (everything up to the first Gather
            reached from the input ids) becomes a split of its own. Default is False
        split_lm_head (bool): If True, everything after the last usable residual add
            becomes a split of its own. Default is False
        skip_verification (bool): If True, skips the call to ``validate_splits``. Default is True
        log_level (str): When equal to "debug", a table with the inputs/outputs of every
            split is printed. Default is "info"

    Returns:
        list[ir.Model]: The split models in execution order. ``[model]`` is returned
        unchanged when ``num_splits <= 1``.

    Raises:
        ValueError: If the model does not expose enough split points for ``num_splits``,
            or if ``split_embedding``/``split_lm_head`` reserve every requested split.

    Model Topology:
               │ ←─────────  layers[0]  ────────────→ │       │ ←─────────  layers[-1]  ───────────-----─→ │
               │                                      │       │                                            │
    embed  ────┬─────────── add 0 ─┬────────── add 1 ──  ┄┄ ┄─┬─────────────── add(n-2) ─┬──────────── add(n-1) ─── lmhead
             ↑ └─ norm ─ attn ─┘   └─ norm ─ ffn ─┘   ↑       ↑ └─ norm ─ attn ─┘        └─ norm ─ ffn ─┘  ↑
             │                                        │       │                                            │
             │                                        │       │                                            │
            valid splitting points
    """

    if num_splits <= 1:
        _logger.debug("skip split because num_splits=%d", num_splits)
        return [model]

    original_num_splits = num_splits

    # Position of every node in the (topologically ordered) graph, keyed by node name
    node_idx = {node.name: i for i, node in enumerate(model.graph)}

    def can_visit(src: ir.Node, dst: ir.Node) -> bool:
        """
        Whether node 'dst' can be visited starting from node 'src'

        The direct edge 'src' -> 'dst' is ignored, so True means that there is a second,
        longer path between the two nodes. Only the first output of every node is followed,
        and the search gives up as soon as it dequeues a node placed after 'dst' in the graph.
        """
        dst_index = node_idx[dst.name]

        q = deque([src])
        # Nodes already queued. A node reached through several paths yields the same
        # answer every time, so expanding it once is enough.
        seen = {src}

        while q:
            curr = q.popleft()

            if node_idx[curr.name] > dst_index:
                return False

            consumers = curr.outputs[0].consumers()
            if curr is src:
                # Skip the direct (residual) connection
                consumers = [node for node in consumers if node is not dst]

            if dst in consumers:
                return True

            for consumer in consumers:
                if consumer not in seen:
                    seen.add(consumer)
                    q.append(consumer)

        return False

    def is_residual_add(node: ir.Node) -> bool:
        """
        Return True if 'node' is a residual add

        A residual add is an Add whose two inputs are both produced by non-Constant nodes,
        where one input is used more than once and the producer of that input can also
        reach the Add through another path (see 'can_visit').
        """

        def is_non_constant_input(value: ir.Value) -> bool:
            """
            Return True if 'value' is neither a constant input nor a graph input
            'value' is Constant if either 'value.const_value' is valid or if the producer op is Constant
            'value' is graph input if value.producer() is None
            """
            if value.const_value is not None:
                return False

            producer = value.producer()
            return producer is not None and producer.op_type != "Constant"

        if node.op_type != "Add":
            return False

        input1 = node.inputs[0]
        input2 = node.inputs[1]

        if input1 is None or input2 is None:
            return False

        if not is_non_constant_input(input1) or not is_non_constant_input(input2):
            return False

        # An essential condition for a residual add is that
        # One of its inputs is consumed by multiple nodes
        # NOTE(review): input1 is measured with consumers() and input2 with uses();
        # kept as in the original implementation - confirm the asymmetry is intended
        if len(input1.consumers()) == 1 and len(input2.uses()) > 1:
            # To satisfy mypy
            input2_producer = input2.producer()
            assert input2_producer is not None

            return can_visit(input2_producer, node)

        if len(input1.consumers()) > 1 and len(input2.uses()) == 1:
            # To satisfy mypy
            input1_producer = input1.producer()
            assert input1_producer is not None

            return can_visit(input1_producer, node)

        return False

    def copy_subgraph(start_tensor: ir.Value, end_tensor: ir.Value, name: str = "graph") -> ir.Graph:
        """
        Create a new ir.Graph based on the start_tensor end_tensor

        Given a start_tensor and an end tensor as the input and output of the subgraph,
        this function will replicate all the nodes, its inputs, outputs and initializers in the subgraph
        Any graph inputs/outputs between these nodes will also be the inputs/outputs of the subgraph
        """

        def create_new_value(value: ir.Value) -> ir.Value:
            """
            Create a new ir.Value instance from the passed 'value'
            This is required to create new ir.Value objects(tensors) in the new split graph
            as ir.Value cannot be used across graphs
            """
            new_value = ir.Value(name=value.name, shape=value.shape, type=value.type)

            # If the value is an initializer, copy the constant value
            if value.is_initializer():
                new_value.const_value = value.const_value

            # If the value object has some meta data, copy it to the new ir.Value object
            # The metadata might include
            # 1. named encodings (base encodings + encodings for each lora adapter)
            # 2. named safetensors (weights for each lora adapter)
            # 3. Whether the tensor is updatable (lora tensor)
            if "extra_info" in value.meta:
                new_value.meta["extra_info"] = value.meta["extra_info"].copy()

            return new_value

        def create_new_node(
            node: ir.Node, node_inputs: "list[ir.Value | None]", node_outputs: list[ir.Value]
        ):
            """
            Create a new ir.Node instance from the passed 'node' object and its inputs and outputs
            This is required to create new nodes in the new split graph
            as nodes cannot be used across graphs
            """
            new_node = ir.Node(
                domain=node.domain,
                inputs=node_inputs,
                outputs=node_outputs,
                op_type=node.op_type,
                attributes=node.attributes,
                version=node.version,
                name=node.name,
            )

            return new_node

        # Maps the name of an original tensor to its replica in the split graph
        value_map = {}

        start_index = node_idx[start_tensor.consumers()[0].name]

        # To satisfy mypy
        end_tensor_producer = end_tensor.producer()
        assert end_tensor_producer is not None

        end_index = node_idx[end_tensor_producer.name] + 1

        subgraph_nodes = []
        subgraph_inputs = []
        subgraph_outputs = []
        subgraph_initializers = []

        # Replicate the node and add it to 'subgraph_nodes' to create a split model
        for node in model.graph[start_index:end_index]:
            node_inputs = []
            for inp in node.inputs:
                if inp is None:
                    # Omitted optional input: keep the placeholder so that the
                    # following inputs stay at their original positions
                    node_inputs.append(None)
                    continue

                if inp.name not in value_map:
                    new_value = create_new_value(inp)

                    if inp.is_graph_input() or inp is start_tensor:
                        subgraph_inputs.append(new_value)
                    elif inp.is_initializer():
                        subgraph_initializers.append(new_value)

                    value_map[inp.name] = new_value

                node_inputs.append(value_map[inp.name])

            node_outputs = []
            for out in node.outputs:
                if out.name not in value_map:
                    new_value = create_new_value(out)

                    if out.is_graph_output() or out is end_tensor:
                        subgraph_outputs.append(new_value)

                    value_map[out.name] = new_value

                node_outputs.append(value_map[out.name])

            new_node = create_new_node(node, node_inputs, node_outputs)

            subgraph_nodes.append(new_node)

        # If an input to any node in the graph is
        #       1. Neither a graph or split input
        #       2. Nor a part of the split
        # Then we need to rebuild that subgraph in the current split
        # Examples of such instances are
        # 1. ['lora_alpha' -> Pad -> Reshape] -> Gather
        # 2. Alibi positional encodings

        visited = set()
        all_node_outputs = set()

        for node in subgraph_nodes:
            all_node_outputs.update([out.name for out in node.outputs])

        for node in model.graph[start_index:end_index]:
            for inp in node.inputs:
                if (
                    inp
                    and value_map[inp.name] not in subgraph_inputs
                    and inp.producer()
                    and inp.name not in all_node_outputs
                ):
                    # Do a breadth-first search in reverse to construct the subgraph
                    q = deque([inp.producer()])

                    while q:
                        curr = q.popleft()

                        if curr and curr not in visited:
                            visited.add(curr)

                            q.extend(
                                [value.producer() for value in curr.inputs if value and value.producer()]
                            )

                            # Replicate node 'curr' if it is not in subgraph_nodes
                            # but if its output is consumed by any node in subgraph_nodes list
                            node_inputs = []
                            for curr_inp in curr.inputs:
                                if curr_inp is None:
                                    # Keep omitted optional inputs in place (see above)
                                    node_inputs.append(None)
                                    continue

                                if curr_inp.name not in value_map:
                                    new_value = create_new_value(curr_inp)

                                    if curr_inp.is_graph_input():
                                        subgraph_inputs.append(new_value)
                                    elif curr_inp.is_initializer():
                                        subgraph_initializers.append(new_value)

                                    value_map[curr_inp.name] = new_value

                                node_inputs.append(value_map[curr_inp.name])

                            node_outputs = []
                            for out in curr.outputs:
                                if out is not inp and out.name not in value_map:
                                    new_value = create_new_value(out)

                                    if out.is_graph_output():
                                        subgraph_outputs.append(new_value)

                                    value_map[out.name] = new_value

                                node_outputs.append(value_map[out.name])

                            all_node_outputs.update([out.name for out in node_outputs])

                            new_node = create_new_node(curr, node_inputs, node_outputs)
                            subgraph_nodes.insert(0, new_node)

        # When constructing a subgraph, nodes are always inserted at index 0
        # This might result in nodes getting added in non-topological ordering
        # To preserve topological sorting, sort the list of nodes
        # 'node_idx' is rebuilt right before the first call of this helper and the
        # original graph is not modified afterwards, so it can be reused as the sort key
        subgraph_nodes.sort(key=lambda _node: node_idx[_node.name])

        split_graph = ir.Graph(
            name=name,
            inputs=subgraph_inputs,
            outputs=subgraph_outputs,
            nodes=subgraph_nodes,
            initializers=subgraph_initializers,
            opset_imports=model.graph.opset_imports,
        )

        return split_graph

    residual_adds = [node.outputs[0] for node in model.graph if is_residual_add(node)]

    # NOTE: Not all residual adds are valid split points
    # In an LLM, the usual structure is "Attention 1 -> Residual Add 1 -> FF network -> Residual Add 2 -> Attention 2"
    # Valid split tensor in the above structure is "Residual Add 2"
    residual_adds = residual_adds[1::2]

    n_possible_splits = len(residual_adds) + int(split_embedding) + 1

    if n_possible_splits < num_splits:
        raise ValueError(
            f"Not enough layers in the model to properly split. {n_possible_splits} possible splits, {num_splits} splits requested"
        )

    # The embedding and the lm head each take one of the requested splits. At least one
    # split has to remain for the layers in between, otherwise the interval computation
    # below divides by zero (or pops from an empty list of residual adds).
    # Checked here so that the model is not modified before the error is raised.
    reserved_splits = int(split_embedding) + int(split_lm_head)
    if num_splits - reserved_splits < 1:
        raise ValueError(
            f"num_splits={num_splits} is too small: split_embedding/split_lm_head reserve "
            f"{reserved_splits} split(s) and at least one split must remain for the layers in between"
        )

    input_ids = next((inp for inp in model.graph.inputs if inp.name == "input_ids"), None)
    if input_ids is None:
        # Fall back to a non-lora Gather that consumes a graph input
        # The loop does not stop at the first hit, so the last matching Gather wins
        # NOTE(review): inputs[0] is taken regardless of which operand is the graph input
        # (for ONNX Gather, inputs[0] is 'data' and inputs[1] is 'indices') - confirm intent
        for node in model.graph:
            if (
                node.op_type == "Gather"
                and not any(inp and inp.name and "lora" in inp.name for inp in node.inputs)
                and any(inp in model.graph.inputs for inp in node.inputs)
            ):
                input_ids = node.inputs[0]

    if input_ids is None:
        first_input = model.graph.inputs[0]
        _logger.warning(
            f"Unable to find 'input_ids' graph input. Using the first graph input: {first_input.name}"
        )
        input_ids = first_input

    embeddings = None
    if split_embedding:
        # Defensive check: 'input_ids' has been assigned by one of the fallbacks above
        if not input_ids:
            raise ValueError(
                "`split_embedding` set to True, but no input named 'input_ids' or no model input to a Gather op"
            )

        # Follow the first consumer of every first output until the embedding Gather is found
        input_ids_consumer = input_ids.consumers()[0]
        while input_ids_consumer.op_type != "Gather":
            input_ids_consumer = input_ids_consumer.outputs[0].consumers()[0]

        embeddings = input_ids_consumer.outputs[0]

        # NOTE: Hack to keep only the embedding in first split
        if model.graph[0].op_type == "Pad":
            model.graph.insert_before(model.graph[0], [model.graph[1]])
        num_splits -= 1

    lm_head = None
    if split_lm_head:
        # The last usable residual add is the boundary in front of the lm head
        lm_head = residual_adds.pop()
        num_splits -= 1

    # Spread the remaining splits evenly over the remaining residual adds
    interval = len(residual_adds) / num_splits
    split_tensors = [input_ids]

    if split_embedding:
        split_tensors.append(embeddings)

    for i in range(1, num_splits):
        idx = math.floor(i * interval)
        split_tensors.append(residual_adds[idx])

    if split_lm_head:
        split_tensors.append(lm_head)

    # Add the output of the last node
    split_tensors.append(model.graph[-1].outputs[0])

    # Fill missing encodings of boundary tensors
    for boundary_tensor in split_tensors[1:]:
        if "extra_info" not in boundary_tensor.meta:
            continue

        enc = boundary_tensor.meta["extra_info"].named_encodings
        start_tensor = boundary_tensor

        # Walk up the chain of first inputs until a tensor with encodings is found
        while not enc:
            producer = start_tensor.producer()

            if not producer:
                # Graph input reached
                break

            if not producer.inputs:
                # Node without inputs (e.g. Constant): nothing further up to copy from
                break

            start_tensor = producer.inputs[0]

            if start_tensor is None:
                # Omitted optional input: the chain ends here
                break

            if "extra_info" not in start_tensor.meta:
                # No extra_info, and hence, no encodings present for the tensor
                continue

            enc = start_tensor.meta["extra_info"].named_encodings

            if enc:
                for encset_name, v_enc in enc.items():
                    enc_copy = copy.deepcopy(v_enc)

                    # NOTE: In instances where param encodings are copied to an activation tensor
                    # it is important to set the enc_kind of the copied tensor to ACTIVATION
                    # Notable when copying the encodings of "Gather.weights" (param) to "Gather_output" (activation)
                    # When split_embedding is True
                    enc_copy.enc_kind = EncKind.ACTIVATION
                    boundary_tensor.meta["extra_info"].named_encodings[encset_name] = enc_copy

                _logger.info(f"Copied the encodings of {start_tensor.name} to {boundary_tensor.name}")
                break

    splits: list[ir.Model] = []

    # Re-build the node_idx cache if we have modified "lora_alpha" and "input_ids"
    node_idx = {node.name: i for i, node in enumerate(model.graph)}

    for i in range(len(split_tensors) - 1):
        subgraph_start_tensor = split_tensors[i]
        subgraph_end_tensor = split_tensors[i + 1]

        split_graph = copy_subgraph(
            subgraph_start_tensor, subgraph_end_tensor, name=f"{i + 1}_of_{original_num_splits}"
        )

        split_model = ir.Model(
            graph=split_graph,
            ir_version=model.ir_version,
            domain=model.domain,
            model_version=model.model_version,
        )

        # In instances where nodes are added to a split but they are not used
        # because they are not used by any nodes in the current split
        # But in other splits (in which case, such subgraphs are re-constructed in the other splits)
        # Remove such nodes in the current split
        remove_unused_nodes(split_model)

        splits.append(split_model)

    if log_level == "debug":
        # 'import rich' alone does not guarantee that the 'rich.console' submodule is loaded
        from rich.console import Console

        table = create_rich_table(
            title=bold_text("Model Splitting Results:", color=PrettyPrintConstants.Q_BLUE),
            headers=["Split Number", "New Inputs", "New Outputs"],
            positions=[0.18, 0.59, 1.0],
            alignment=["left", "left", "left"],
        )
        for i, _split in enumerate(splits):
            table.add_row(
                str(i + 1),
                ", ".join([_input.name for _input in _split.graph.inputs if _input and _input.name]),
                ", ".join([_output.name for _output in _split.graph.outputs if _output and _output.name]),
            )

        console = Console(highlight=True)
        console.print(table, overflow="fold")

    if not skip_verification:
        validate_splits(
            ir.serde.serialize_model(model), [ir.serde.serialize_model(s) for s in splits], logger=_logger
        )

    return splits
