# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================

"""
Encapsulate an ONNX model and its associated encodings in a wrapper OnnxModel class to
1. Run transformations
2. Split the onnx model
3. Export the model after running transformations

Example 1 - MHA2SHA

    from qti.aisw.tools.core.utilities.framework.onnx import OnnxModel

    onnx_model = OnnxModel.load(model_path="model.onnx", encodings_path="model.encodings")
    onnx_model.mha2sha()
    onnx_model.export(path="/path/to/export/dir", prefix="model_sha")


Example 2 - Split API

    from qti.aisw.tools.core.utilities.framework.onnx import OnnxModel
    from qairt.api.configs.common import BackendType
    from qairt.api.transformer.model_transformer_config import QuantizationStage

    onnx_model = OnnxModel.load(model_path="model.onnx", encodings_path="model.encodings")
    splits = onnx_model.split(
        num_splits=3,
        split_embedding=False,
        split_lm_head=False,
    )

    for idx, _split in enumerate(splits):
        _split.export(path="/path/to/export/dir", prefix=f"model_sha_{idx+1}_of_{len(splits)}")

"""

import copy
import json
import os
import pathlib
import tempfile
import traceback
from typing import Any, Dict, List, Optional, Union

import numpy as np
import onnx_graphsurgeon as gs
import yaml
from pydantic import BaseModel, FilePath
from safetensors.numpy import load_file, save_file
from tqdm import tqdm

import onnx
from qairt.api.configs.common import BackendType
from qairt.api.transforms.model_transformer_config import QuantizationStage
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.encodings.encodings import AimetEncodingsFactory
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.onnx_model_helper import OnnxModelHelper
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.graph_manager import GraphManager
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.mha2sha.mha2sha import MHAPattern
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.split.split import split_onnx
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha.utils.utils import autocomplete_opset
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha import (
    apply_mha2sha_optimization_in_memory,
)
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.v2.mha2sha.optimizer import (
    GraphOptimizer,
)
from qti.aisw.tools.core.utilities.qairt_logging import LogAreas, QAIRTLogger


class LoraAdapter:
    def __init__(self, name: str, weights: dict[str, np.ndarray], encodings: dict | None = None) -> None:
        self.name = name
        self.weights = weights
        self._aimet_encodings = AimetEncodingsFactory.from_dict(encodings) if encodings else None

    @property
    def encodings(self):
        if self._aimet_encodings:
            return self._aimet_encodings.encodings
        return None

    @classmethod
    def load_adapters_from_yaml(cls, lora_adapters_path: str | os.PathLike | dict):
        """Parse lora_importer_config.yaml"""
        if isinstance(lora_adapters_path, dict):
            lora_adapter_list = lora_adapters_path["use_case"]
        else:
            with open(lora_adapters_path, "r") as file:
                lora_adapter_list = yaml.safe_load(file)["use_case"]

        adapters = []
        for adapter in lora_adapter_list:
            try:
                with open(adapter["quant_overrides"]) as f:
                    encodings = json.load(f)
            except KeyError:
                encodings = None
            lora_adapter = cls(adapter["name"], load_file(adapter["lora_weights"]), encodings)
            adapters.append(lora_adapter)

        return adapters

    def map(self, original: str, slices: list[str], model_tensors: set[str] | list[str]):
        """Map the safetensor and encoding of the 'original' tensor to multiple slices

        Args:
            original: Original safetensor name
            slices: List of slice names
            model_tensors: A list or set of tensor names to use as a filter
        """
        if original in self.weights:
            original_weights = self.weights[original]
            n_slices = len(slices)

            for slice in slices:
                n = int(slice.split(":")[-1])

                # ndim == 1: Bias (slice axis 0)
                # ndim == 2: Gemm/MatMul weights (slice axis 0)
                # ndim == 4: Conv (slice axis 1)
                slice_axis = 1 if original_weights.ndim == 4 else 0

                n_dim = original_weights.shape[slice_axis]
                head_dim = n_dim // n_slices
                start = n * head_dim
                end = min((n + 1) * head_dim, n_dim)

                # Slice from start:end at the given axis
                slice_weights = original_weights.take(range(start, end), axis=slice_axis)

                # Map slice tensor to its safetensor weight
                self.set(slice, slice_weights)
                if original not in model_tensors:
                    self.delete(original)

        if self._aimet_encodings:
            self._aimet_encodings._map_slices(original, slices)
            if original not in model_tensors:
                self._aimet_encodings.delete(original)

    def get(self, name: str):
        return self.weights.get(name)

    def set(self, name: str, weights: np.ndarray):
        self.weights[name] = weights

    def delete(self, name):
        try:
            del self.weights[name]
        except KeyError:
            pass  # NoOp


class ExportedUseCase(BaseModel):
    """Files exported for a single LoRA adapter (use-case)"""

    # Adapter (use-case) name
    name: str
    # Safetensors file holding the adapter weights
    safetensors: FilePath
    # AIMET encodings file of the adapter; None when the adapter has no encodings
    encodings: Optional[FilePath] = None


class ExportedFiles(BaseModel):
    """Paths of the artifacts written by OnnxModel.export"""

    # Exported ONNX model file
    onnx_path: FilePath
    # External data file the model is saved with
    data_path: FilePath
    # AIMET encodings file of the model; None when the model has no encodings
    encodings_path: Optional[FilePath] = None
    # One entry per exported LoRA adapter
    use_cases: list[ExportedUseCase] = []
    # Text file listing the updatable LoRA tensor names, one per line
    lora_tensor_names: Optional[FilePath] = None
    # Generated lora_importer_config.yaml
    lora_importer_config: Optional[FilePath] = None


class OnnxModel:
    """Manages an Onnx model and optionally, associated encodings

    Attributes:
        model: ONNX model
        encodings: AIMET encodings dictionary

    """

    log_area = LogAreas.register_log_area("OnnxModel")
    logger = QAIRTLogger.register_area_logger(
        log_area, level="INFO", formatter_val="extended", handler_list=["dev_console"]
    )

    def __init__(
        self,
        model: onnx.ModelProto,
        encodings: dict | None = None,
        lora_adapters: list[LoraAdapter] | None = None,
        lora_tensor_names: list[str] | None = None,
        model_base_dir: str | None = None,
        **kwargs,
    ) -> None:
        """Wrap an in-memory ONNX model together with its optional artifacts

        Args:
            model: ONNX model
            encodings (Optional): AIMET encodings dict
            lora_adapters (Optional): LoRA adapters of the model. Defaults to an empty list
            lora_tensor_names (Optional): Names of the updatable LoRA tensors.
                Defaults to an empty list
            model_base_dir (Optional): Directory the model was loaded from

        Keyword Args:
            log_level (str): Severity of logging level. Default is INFO
            skip_shape_inference (bool): If True, shape inference is not run on the model.
                Default is False
        """
        requested_level = kwargs.get("log_level")
        if requested_level:
            try:
                self.logger.setLevel(requested_level.upper())
            except (ValueError, TypeError):
                # A bad level is reported, the logger keeps its current level
                traceback.print_exc()

        # When the model comes from a path, shape inference is much faster before the
        # weights are loaded. load() does that itself and asks to skip it here.
        shape_inference_requested = not kwargs.get("skip_shape_inference", False)
        if shape_inference_requested:
            try:
                model = OnnxModelHelper.symbolic_shape_inference(model)
            except Exception:
                # If either onnxruntime is not available(ImportError)
                # Or if external data is not loaded for the model (onnx.checker.ValidationError)
                model = OnnxModelHelper.shape_inference(model)

        self._graph = GraphManager(model)

        self._aimet_encodings = None
        if encodings:
            self._aimet_encodings = AimetEncodingsFactory.from_dict(encodings)

        self.lora_adapters = [] if lora_adapters is None else lora_adapters
        self.lora_tensor_names = [] if lora_tensor_names is None else lora_tensor_names
        self.model_base_dir = model_base_dir

        # only applicable for mha2sha_v2
        self.tracing_info: Dict[str, Any] = {}
        self.special_sha_inputs: Dict[str, Any] = {}

    @property
    def model(self) -> onnx.ModelProto:
        """ONNX model exported from the managed graph after cleanup and topological sort"""
        sorted_graph = self._graph.graph.cleanup().toposort()
        return gs.export_onnx(sorted_graph)

    @property
    def encodings(self) -> dict | None:
        """AIMET encodings dictionary, or None if the model has no encodings"""
        if self._aimet_encodings:
            return self._aimet_encodings.encodings
        return None

    @classmethod
    def load(
        cls,
        *,
        model_path: str | os.PathLike,
        encodings_path: str | os.PathLike | None = None,
        lora_adapters_path: str | os.PathLike | dict | None = None,
        lora_tensor_names_path: str | os.PathLike | None = None,
        **kwargs,
    ):
        """Load the model and encodings from file and initialize an instance of OnnxModel

        Shape inference is run before the external data is loaded, so the constructor
        is asked to skip it.

        Args:
            model_path: Path to ONNX model
            encodings_path (Optional) : Path to AIMET encodings file. Supported versions are v0.6.1 and v1.0.0
            lora_adapters_path (Optional) : Path to lora adapters yaml file (lora_importer_config),
                or the already parsed config dict

                The schema for the yaml file should be as follows:
                # Start config
                use_case:  # List of use-cases
                    - name:                 <usecase_1/adapter_1 name>
                      lora_weights:         <path to safetensor file for adapter_1>
                      quant_overrides:      <path to AIMET encodings file for adapter_1>
                    - name:                 <usecase_2/adapter_2 name>
                      lora_weights:         <path to safetensor file for adapter_2>
                      quant_overrides:      <path to AIMET encodings file for adapter_2>
                    ...
            lora_tensor_names_path (Optional): Path to .txt file with updatable lora tensor names,
                one name per line. Blank lines are ignored.

        Keyword Args:
            log_level (str): Severity of logging level. Default is INFO

        Returns:
            An instance of the OnnxModel class

        """
        model_path = pathlib.Path(model_path)

        if encodings_path:
            encodings_path = pathlib.Path(encodings_path)
            with open(encodings_path, "r") as f:
                encodings = json.load(f)
        else:
            encodings = None

        cls.logger.debug("Loading model")
        model = onnx.load(model_path, load_external_data=False)

        # the opset version should >=1
        autocomplete_opset(model, {"qti.aisw": 1})

        model = onnx.shape_inference.infer_shapes(model)

        model_base_dir = str(model_path.parent)

        # TODO we should remove this to accelerate the processing
        onnx.load_external_data_for_model(model, model_base_dir)

        if lora_adapters_path:
            lora_adapters = LoraAdapter.load_adapters_from_yaml(lora_adapters_path)
        else:
            lora_adapters = []

        if lora_tensor_names_path:
            with open(lora_tensor_names_path) as f:
                # Skip blank lines (e.g. a trailing newline) so that no empty
                # string ends up being treated as a tensor name
                lora_tensor_names = [name for name in (line.strip() for line in f) if name]
        else:
            lora_tensor_names = []

        return cls(
            model=model,
            encodings=encodings,
            lora_adapters=lora_adapters,
            lora_tensor_names=lora_tensor_names,
            skip_shape_inference=True,
            model_base_dir=model_base_dir,
            **kwargs,
        )

    def export(self, path: str | os.PathLike, prefix: str = "model") -> ExportedFiles:
        """Write the model and every artifact attached to it into a directory

        The model is saved with external data, shape inference is run on the saved
        file and the ONNX checker is run on it (a checker failure is only reported).
        Encodings, LoRA adapters and the list of LoRA tensor names are written when
        the model has them.

        Args:
            path: Directory where the artifacts are to be saved. Created if missing
            prefix: Prefix to model and artifact file names. Defaults to "model"

        Returns:
            The paths of all exported files

        Raises:
            OSError: If 'path' exists but is not a directory
        """
        out_dir = pathlib.Path(path)

        if out_dir.exists():
            if not out_dir.is_dir():
                raise OSError(f"{out_dir} is not a directory")
        else:
            os.makedirs(out_dir, exist_ok=True)

        weights_name = f"{prefix}.data"
        onnx_file = out_dir / f"{prefix}.onnx"
        weights_file = out_dir / weights_name

        # Remove the data file of a previous export before saving
        if weights_file.exists():
            os.remove(weights_file)

        onnx.save(self.model, onnx_file, save_as_external_data=True, location=weights_name)
        self.logger.info(f"Model saved at {onnx_file.absolute()}")
        result: ExportedFiles = ExportedFiles(onnx_path=onnx_file, data_path=weights_file)

        onnx.shape_inference.infer_shapes_path(onnx_file, onnx_file)

        try:
            onnx.checker.check_model(onnx_file, full_check=True)
            self.logger.debug("ONNX Checker passed!")
        except Exception:
            # Best effort: the export continues, the user is asked to re-check
            self.logger.warning("ONNX checker failed!")
            traceback.print_exc()
            self.logger.warning("Please re-check the model before using")

        if self._aimet_encodings:
            base_encodings_file = out_dir / f"{prefix}.encodings"
            self.logger.info(f"Encodings saved at {base_encodings_file.absolute()}")
            with open(base_encodings_file, "w") as f:
                f.write(json.dumps(self._aimet_encodings.encodings))
            result.encodings_path = base_encodings_file

        # Export LoRA artifacts if present
        if self.lora_adapters:
            config_entries = []

            for lora in self.lora_adapters:
                weights_path = out_dir / f"{lora.name}.safetensors"
                save_file(lora.weights, weights_path)
                exported: ExportedUseCase = ExportedUseCase(name=lora.name, safetensors=weights_path)
                self.logger.debug(f"Safetensors of adapter '{lora.name}' saved at {weights_path.absolute()}")

                entry = {
                    "name": lora.name,
                    "model_name": str(onnx_file),
                    "lora_weights": str(weights_path),
                    "output_path": "../lora_output",
                }

                # Encodings is optional
                adapter_encodings = lora.encodings
                if adapter_encodings:
                    adapter_encodings_file = out_dir / f"{lora.name}.encodings"
                    with open(adapter_encodings_file, "w") as f:
                        f.write(json.dumps(adapter_encodings))
                    self.logger.debug(
                        f"Encodings of adapter '{lora.name}' saved at {adapter_encodings_file.absolute()}"
                    )
                    entry["quant_overrides"] = str(adapter_encodings_file)
                    exported.encodings = adapter_encodings_file

                config_entries.append(entry)
                result.use_cases.append(exported)

            importer_config_file = out_dir / "lora_importer_config.yaml"
            with open(importer_config_file, "w") as f:
                f.write(yaml.dump({"use_case": config_entries}))
            result.lora_importer_config = importer_config_file

        if self.lora_tensor_names:
            names_file = out_dir / "lora_tensor_names.txt"
            with open(names_file, "w") as f:
                f.writelines(f"{tensor}\n" for tensor in self.lora_tensor_names)
            result.lora_tensor_names = names_file

        return result

    def fold_constants(self):
        """Fold constants in-place and run shape inference

        Folding passes are repeated until a pass leaves the node count (sub-graph
        nodes included) unchanged. If a pass fails, the error is logged, the
        remaining passes are skipped and the result of the last successful pass
        is kept.
        """

        def _const_fold_pass(_model) -> onnx.ModelProto:
            """Run one folding pass followed by shape inference and return the new model"""
            graph = gs.import_onnx(_model)

            graph.fold_constants()

            _model = gs.export_onnx(graph.cleanup().toposort())

            try:
                _model = OnnxModelHelper.symbolic_shape_inference(_model)
            except Exception:
                # If either onnxruntime is not available(ImportError)
                # Or if external data is not loaded for the model (onnx.checker.ValidationError)
                _model = OnnxModelHelper.shape_inference(_model)

            return _model

        def _get_num_graph_nodes(graph) -> int:
            """Count the nodes of a graph, including the nodes of all nested sub-graphs"""
            num_nodes = len(graph.node)
            for node in graph.node:
                for attr in node.attribute:
                    if attr.type == onnx.AttributeProto.GRAPH:
                        num_nodes += _get_num_graph_nodes(attr.g)
                    elif attr.type == onnx.AttributeProto.GRAPHS:
                        for subgraph in attr.graphs:
                            num_nodes += _get_num_graph_nodes(subgraph)
            return num_nodes

        model = gs.export_onnx(self._graph.graph.cleanup())
        init_num_nodes = _get_num_graph_nodes(model.graph)

        # Node count after the last successful pass. Starting from the initial count
        # keeps the summary below meaningful even when the very first pass fails.
        current_num_nodes = init_num_nodes
        pass_num = 0

        while True:
            pass_num += 1
            self.logger.debug(f"Folding Constants | Pass {pass_num}")

            try:
                folded_model = _const_fold_pass(model)
                self._graph = GraphManager(folded_model)
            except Exception as e:
                self.logger.error(
                    f"Constant folding pass failed. Skipping subsequent passes.\nNote: Error was:\n{e}"
                )
                break

            # Only advance once both the pass and the graph rebuild succeeded
            model = folded_model
            postfold_num_nodes = _get_num_graph_nodes(model.graph)
            nodes_folded = current_num_nodes - postfold_num_nodes
            self.logger.debug(
                f"Original: {current_num_nodes} | After folding: {postfold_num_nodes} | {nodes_folded} Node{'s' if nodes_folded != 1 else ''} folded\n"
            )

            converged = postfold_num_nodes == current_num_nodes
            current_num_nodes = postfold_num_nodes
            if converged:
                break

        self.logger.debug(f"Ran {pass_num} constant folding passes")
        self.logger.info(
            f"Original Nodes: {init_num_nodes} | After folding: {current_num_nodes} | Folded {init_num_nodes - current_num_nodes} Nodes\n"
        )

    def mha2sha_v2(
        self,
        extract_lorav2_alpha: bool = False,
        permute_kv_cache_io: bool = False,
        key_cache_name_pattern: str = r"past_key_(\d)+_in|past_key_(\d)+_out",
        value_cache_name_pattern: str = r"past_value_(\d)+_in|past_value_(\d)+_out",
        m2s_head_split_map: Dict[int, int] | None = None,
        enable_validation: bool = False,
        validation_kwargs: Dict[str, Any] | None = None,
        **kwargs,
    ):
        """
        Convert Multi-Head Attention(MHA) layers in the model to Single-Head Attentions(SHA), in-place.
        Optionally, modify the encodings to align with the SHA model.

        If no MHA patterns are found, this function is a no-op.
        Using MHA2SHA v2 implementation

        The model, the encodings, the LoRA adapters and the list of LoRA tensor names
        of this instance are replaced with the SHA results. The tracing information is
        stored in `tracing_info` and the special SHA inputs (including the lora alpha
        values extracted from the model) in `special_sha_inputs`. Nothing is returned.

        Args:
            extract_lorav2_alpha: Whether to extract LoRAv2 alpha values. Defaults to False.
                                  (Only applicable for lora-v2 model)
            permute_kv_cache_io: Whether to permute key-value cache inputs/outputs.
                                                Defaults to False.
            key_cache_name_pattern: Pattern for key cache tensor names.
                                    None selects the default pattern.
            value_cache_name_pattern: Pattern for value cache tensor names.
                                      None selects the default pattern.
            m2s_head_split_map: Mapping for splitting multi-head attention to single-head attention.
                                Defaults to None. Key is the head size of mha, value is the corresponding
                                head size of sha.
                                e.g. "{25:1,128:8}" means split head=25 into head=1 and split head=128
                                into head=8.
                                     "{-1:1}" means split any head size into head=1
            enable_validation: Whether to verify the generated ONNX model with ONNX Runtime. Defaults to False.
            validation_kwargs: extra arguments for validation, right now we support:
                - input_raw_list_path: Path of raw input list for verification (same format for qairt-quantizer).
                                Defaults to None. If not provided, random inputs will be used for verification.
                - input_raw_base_dir: Base directory for raw input files. Defaults to None.
            **kwargs: Accepted but not used by this method
        """
        named_encodings = {}
        named_safetensors = {}

        for lora in self.lora_adapters:
            enc = lora.encodings
            if enc:
                named_encodings[lora.name] = enc
            if lora.weights:
                named_safetensors[lora.name] = lora.weights

        base_enc_name = None
        if self.encodings:
            # Pick a name for the base encodings that no adapter uses
            base_enc_name = "base"
            while base_enc_name in named_encodings:
                base_enc_name += "_"
            named_encodings[base_enc_name] = self.encodings

        # used only for extract-lorav2-alpha
        if key_cache_name_pattern is None:
            key_cache_name_pattern = r"past_key_(\d)+_in|past_key_(\d)+_out"
        if value_cache_name_pattern is None:
            value_cache_name_pattern = r"past_value_(\d)+_in|past_value_(\d)+_out"

        if validation_kwargs is None:
            validation_kwargs = {}

        with tempfile.TemporaryDirectory(dir=os.environ.get("QAIRT_TMP_DIR", None)) as tmpdirname:
            ws = os.path.join(tmpdirname, "ws_mha2sha_v2")
            os.makedirs(ws, exist_ok=True)
            sha_out_info = apply_mha2sha_optimization_in_memory(
                onnx_proto=self.model,
                named_encodings=named_encodings,
                named_safetensors=named_safetensors,
                updatable_tensors=self.lora_tensor_names,
                extract_lorav2_alpha=extract_lorav2_alpha,
                permute_kv_cache_io=permute_kv_cache_io,
                key_cache_name_pattern=key_cache_name_pattern,
                value_cache_name_pattern=value_cache_name_pattern,
                m2s_head_split_map=m2s_head_split_map,
                base_dir=self.model_base_dir,
                ws=ws,
                enable_validation=enable_validation,
                input_raw_list_path=validation_kwargs.get("input_raw_list_path", None),
                input_raw_base_dir=validation_kwargs.get("input_raw_base_dir", None),
            )

        self._graph = GraphManager(sha_out_info["onnx_proto"])

        if base_enc_name is not None:
            self._aimet_encodings = AimetEncodingsFactory.from_dict(
                sha_out_info["named_encodings"][base_enc_name]
            )

        # An adapter without encodings (or without weights) was never handed to the
        # optimizer, so it may be missing from the results: fall back to no encodings
        # (or no weights) instead of failing on the lookup.
        sha_named_safetensors = sha_out_info["named_safetensors"]
        sha_named_encodings = sha_out_info["named_encodings"]
        self.lora_adapters = [
            LoraAdapter(
                name=lora.name,
                weights=sha_named_safetensors.get(lora.name, {}),
                encodings=sha_named_encodings.get(lora.name),
            )
            for lora in self.lora_adapters
        ]
        self.lora_tensor_names = sha_out_info["updatable_tensors"]

        self.tracing_info = {
            # the complete and detailed tracing information
            "tracing_info": sha_out_info["tracing_info"],
            # the consolidated tracing information
            "merged_tracing_info": sha_out_info["merged_tracing_info"],
        }

        # set special lora input, including lora alpha values that extracted from the model
        self.special_sha_inputs = sha_out_info["special_inputs"]

    def mha2sha(self, **kwargs):
        """
        Convert Multi-Head Attention(MHA) layers in the model to Single-Head Attentions(SHA), in-place.
        Optionally, modify the encodings to align with the SHA model.

        If no MHA patterns are found, this function is a no-op

        This is the default entry point and runs the v1 implementation. The keyword
        arguments are forwarded to it unchanged.
        """
        self.mha2sha_v1(**kwargs)

    def mha2sha_v1(self, **kwargs):
        """
        Convert Multi-Head Attention(MHA) layers in the model to Single-Head Attentions(SHA), in-place.
        Using MHA2SHA v1 implementation

        The MHA patterns are captured first and replaced afterwards. The base model
        encodings, the LoRA adapters and the list of updatable LoRA tensors are then
        mapped from every original tensor to its slices. An original tensor that is
        no longer part of the model is dropped from them.

        Args:
            **kwargs: Accepted but not used by this method
        """
        node_idx = {node.name: idx for idx, node in enumerate(self._graph.nodes)}

        i = 0
        mha_patterns = []
        while i < len(self._graph.nodes):
            node = self._graph.nodes[i]

            mha = MHAPattern(self._graph)

            if mha.capture(node):
                mha_patterns.append(mha)

                # Continue the search right after the QKV MatMul of the captured pattern
                i = node_idx[mha.mha.qkv.name]

            i += 1

        for idx, mha in enumerate(mha_patterns):
            self.logger.debug(f"Layer {idx} - QK MatMul: {mha.mha.qk.name}, QKV MatMul: {mha.mha.qkv.name}")

        self.logger.info(f"Found {len(mha_patterns)} Multi-Head Attention patterns")

        if mha_patterns:
            for mha in tqdm(
                mha_patterns,
                desc="MHA Pattern",
                total=len(mha_patterns),
                bar_format="{desc}: {n_fmt}/{total_fmt} |{bar} [ETA: {remaining}s,{rate_inv_fmt}{postfix}]",
            ):
                mha.replace(None)

        # Names of all tensors still used by the nodes of the transformed model
        model_tensors = set()

        for node in self.model.graph.node:
            model_tensors.update(node.input)
            model_tensors.update(node.output)

        if self._aimet_encodings:
            if self.lora_adapters:
                self.logger.debug("Mapping base model encodings to SHA")
            else:
                self.logger.debug("Mappping encodings to SHA encodings")

            for original, slices in self._graph.tensor_mapping.items():
                # NOTE(review): per-iteration check kept from the original code; it is
                # redundant unless the truthiness of the encodings object can change
                if self._aimet_encodings:
                    self._aimet_encodings._map_slices(original, slices)
                    if original not in model_tensors:
                        self._aimet_encodings.delete(original)

            # HACK: Begin - Fill in missing encodings for Clip
            # Reference - https://github.qualcomm.com/chuaqin/EasyCompile/blob/main/compile_utils.py#150
            for node in self._graph.nodes:
                if node.op == "Clip" and self._aimet_encodings.get(node.outputs[0].name) is None:
                    if parent_enc := self._aimet_encodings.get(node.inputs[0].name):
                        # The bound values are not used below: the lookups only make
                        # sure that both bounds exist and are constants.
                        try:
                            self._graph.get_tensor_value(node.inputs[1])
                            self._graph.get_tensor_value(node.inputs[2])
                        except (ValueError, IndexError):
                            # ValueError: Unable to find const value
                            # IndexError: Clip without min and/or max input
                            continue

                        # NOTE: Copied from above link
                        tensor_enc = copy.deepcopy(parent_enc)
                        max_bound = (2 ** (tensor_enc["bw"] - 1) + np.array(tensor_enc["offset"])) * (
                            np.array(tensor_enc["scale"])
                        )
                        min_bound = (np.array(tensor_enc["offset"])) * (np.array(tensor_enc["scale"]))

                        new_scale = (max_bound - min_bound) / (2 ** (tensor_enc["bw"] - 1))
                        new_offset = (min_bound / new_scale).round().astype(np.int64)
                        tensor_enc["offset"] = new_offset.tolist()
                        tensor_enc["scale"] = new_scale.tolist()

                        self._aimet_encodings._set_activation_encodings(node.outputs[0].name, tensor_enc)

            # HACK: End

        if self.lora_adapters:
            self.logger.debug("Mapping lora adapters to SHA model")
            for original, slices in self._graph.tensor_mapping.items():
                # Map LoRA adapters from MHA to SHA
                for adapter in self.lora_adapters:
                    adapter.map(original, slices, model_tensors)

                # Update the list of updatable tensors
                if original in self.lora_tensor_names:
                    self.lora_tensor_names.extend(slices)

                    if original not in model_tensors:
                        self.lora_tensor_names.remove(original)

    def split_v1(
        self, *, num_splits: int, split_embedding: bool = False, split_lm_head: bool = False, **kwargs
    ) -> list["OnnxModel"]:
        """Break the ONNX model into several smaller models (v1 implementation)

        Constants are folded first. Every split gets the subset of the encodings,
        LoRA adapters and LoRA tensor names that belongs to its tensors. A tensor on
        the boundary of two consecutive splits that has no encodings receives a copy
        of the encodings of the closest upstream tensor that has them.

        Args:
            num_splits: The number of splits to be made. With a value of 1 or less
                the model is not split and this instance is returned
            split_embedding: If True, splits the embeddings. Default is False
            split_lm_head: If True, splits the language model head. Default is False
            **kwargs: Forwarded to split_onnx

        Returns:
            A list of split OnnxModel encapsulating objects
        """
        if num_splits <= 1:
            return [self]

        def _subset_encodings(source_encodings, tensor_names):
            """Build an encodings dict restricted to 'tensor_names'"""
            subset = AimetEncodingsFactory.from_version(source_encodings.version)
            for tensor_name in tensor_names:
                tensor_enc = source_encodings.get(tensor_name)
                if tensor_enc is None:
                    continue

                if tensor_name in source_encodings.param_enc:
                    subset._set_param_encodings(tensor_name, tensor_enc)
                else:
                    subset._set_activation_encodings(tensor_name, tensor_enc)

            return subset.encodings

        self.fold_constants()
        splits = split_onnx(
            self.model,
            num_splits=num_splits,
            split_embedding=split_embedding,
            split_lm_head=split_lm_head,
            **kwargs,
        )

        # Fill missing encodings of boundary tensors
        if self._aimet_encodings:
            for producer, consumer in zip(splits, splits[1:]):
                produced = set(out.name for out in producer.graph.output)
                consumed = set(inp.name for inp in consumer.graph.input)

                for boundary_tensor in produced & consumed:
                    enc = self._aimet_encodings.get(boundary_tensor)
                    if enc:
                        continue

                    self.logger.info(f"Encodings missing for boundary tensor {boundary_tensor}")

                    try:
                        cursor = self._graph.graph.tensors()[boundary_tensor]
                    except KeyError:
                        self.logger.warning(f"Unable to fill missing encodings for {boundary_tensor}")
                        continue

                    # Walk upstream until a tensor with encodings is found
                    while not enc:
                        try:
                            cursor = cursor.i()
                        except IndexError:
                            self.logger.warning(f"Unable to fill missing encodings for {boundary_tensor}")
                            break
                        enc = self._aimet_encodings.get(cursor.name)

                    if enc:
                        self.logger.info(f"Copied the encodings of {cursor.name} to {boundary_tensor}")
                        self._aimet_encodings._set_activation_encodings(boundary_tensor, enc.copy())

        onnx_models: List[OnnxModel] = []

        for split_proto in splits:
            # Every tensor name consumed or produced by a node of this split
            tensors_in_split = set()
            for node in split_proto.graph.node:
                tensors_in_split.update(node.input)
                tensors_in_split.update(node.output)

            split_encodings = None
            if self._aimet_encodings:
                split_encodings = _subset_encodings(self._aimet_encodings, tensors_in_split)

            split_lora_adapters = []
            for adapter in self.lora_adapters:
                weights_in_split = {}
                for name, weight in adapter.weights.items():
                    if name in tensors_in_split:
                        weights_in_split[name] = weight

                # An adapter without any weight in this split is left out of it
                if not weights_in_split:
                    continue

                # Encodings are optional
                adapter_encodings = None
                if adapter._aimet_encodings:
                    adapter_encodings = _subset_encodings(adapter._aimet_encodings, tensors_in_split)

                split_lora_adapters.append(LoraAdapter(adapter.name, weights_in_split, adapter_encodings))

            split_lora_tensor_names = []
            for tensor_name in self.lora_tensor_names:
                if tensor_name in tensors_in_split:
                    split_lora_tensor_names.append(tensor_name)

            onnx_models.append(
                OnnxModel(
                    model=split_proto,
                    encodings=split_encodings,
                    lora_adapters=split_lora_adapters,
                    lora_tensor_names=split_lora_tensor_names,
                )
            )

        return onnx_models

    def split(
        self, *, num_splits: int, split_embedding: bool = False, split_lm_head: bool = False, **kwargs
    ) -> list["OnnxModel"]:
        """Splits the given ONNX model into multiple smaller models

        The split itself is done by GraphOptimizer, which also splits the encodings,
        the LoRA adapters and the list of updatable LoRA tensors.

        Args:
            num_splits: The number of splits to be made
            split_embedding: If True, splits the embeddings. Default is False
            split_lm_head: If True, splits the language model head. Default is False
            **kwargs: Forwarded to GraphOptimizer.split

        Returns:
            A list of split OnnxModel encapsulating objects
        """

        named_encodings = {}
        named_safetensors = {}

        for lora in self.lora_adapters:
            enc = lora.encodings
            if enc:
                named_encodings[lora.name] = enc
            if lora.weights:
                named_safetensors[lora.name] = lora.weights

        base_enc_name = None
        if self.encodings:
            # Pick a name for the base encodings that no adapter uses
            base_enc_name = "base"
            while base_enc_name in named_encodings:
                base_enc_name += "_"
            named_encodings[base_enc_name] = self.encodings

        optimizer = GraphOptimizer(
            self.model,
            named_encodings=named_encodings,
            named_safetensors=named_safetensors,
            updatable_tensors=self.lora_tensor_names,
        )

        splits = optimizer.split(
            num_splits=num_splits, split_embedding=split_embedding, split_lm_head=split_lm_head, **kwargs
        )

        split_onnx_models = []

        for split_info in splits:
            split_model = split_info["onnx_proto"]

            split_encodings = {}
            if base_enc_name is not None:
                split_encodings = split_info["named_encodings"][base_enc_name]

            split_lora_tensor_names = split_info["updatable_tensors"]

            split_lora_adapters = []

            # Only if this split has lora tensors, construct LoRA adapters for it
            if split_lora_tensor_names:
                # An adapter without encodings (or without weights) was never handed to
                # the optimizer, so it may be missing from the results: fall back to no
                # encodings (or no weights) instead of failing on the lookup.
                split_named_safetensors = split_info["named_safetensors"]
                split_named_encodings = split_info["named_encodings"]
                for lora in self.lora_adapters:
                    split_lora_adapters.append(
                        LoraAdapter(
                            name=lora.name,
                            weights=split_named_safetensors.get(lora.name, {}),
                            encodings=split_named_encodings.get(lora.name),
                        )
                    )

            split_onnx_models.append(
                OnnxModel(split_model, split_encodings, split_lora_adapters, split_lora_tensor_names)
            )

        return split_onnx_models
