# ==============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All Rights Reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
import dataclasses
import functools
import hashlib
import os
import pathlib
import shutil
import tempfile
from typing import Any, List, Optional, Union

import onnx
import yaml  # type: ignore

from qairt.api.common.backends.htp.config import HtpDeviceConfig, HtpGraphConfig
from qairt.api.compiled_model import CompiledModel
from qairt.api.compiler._compile import compile
from qairt.api.compiler.config import CompileConfig, CompilerModes
from qairt.api.configs.common import BackendType, PerfProfile
from qairt.api.configs.device import SocDetails
from qairt.api.converter._convert import convert
from qairt.api.converter.converter_config import CalibrationConfig, ConverterConfig
from qairt.api.model import Model
from qairt.api.transforms._transform import transform
from qairt.api.transforms.model_transformer_config import (
    ARn_ContextLengthConfig,
    ModelTransformerConfig,
    QuantizationStage,
)
from qairt.gen_ai_api.builders.gen_ai_utils import get_tensor_values
from qairt.gen_ai_api.configs.builder_transformer_config import BuilderTransformerConfig
from qairt.modules.lora.lora_config import (
    LoraBuilderInputConfig,
    LoraBuilderOutputConfig,
    UseCaseOutputConfig,
    load_use_case_config,
    serialize_lora_adapter_weight_config,
    serialize_lora_importer_config,
)
from qairt.modules.lora.lora_module import build_lora_graph
from qairt.utils import loggers
from qairt.utils.exceptions import InvalidCacheError
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.onnx_model import ExportedFiles, ExportedUseCase
from qti.aisw.tools.core.utilities.framework.frameworks.onnx.transform.arx2arn.ar_cl_updater import (
    set_ar,
    set_context_length,
)

_logger = loggers.get_logger(name=__name__)


def _hash(source_artifact_path: str | os.PathLike, serialized_configuration: dict[str, Any]):
    return hashlib.sha256(f"{source_artifact_path}{serialized_configuration}".encode()).hexdigest()


def _move(source: pathlib.Path, dest: pathlib.Path) -> pathlib.Path:
    destination = dest / source.name
    if destination.exists():
        os.remove(destination)
    path = shutil.move(source, destination)
    assert isinstance(path, pathlib.Path)
    return path


def cache_build_lora_graph(func):
    """Decorator adding on-disk caching to ``build_lora_graph``-style methods.

    When ``self.cache_dir`` is falsy the wrapped method is called unchanged. Otherwise the
    LoRA builder configuration is hashed to select ``<cache_dir>/<hash>/lora_model_creator_output``.
    The hash input is the parsed YAML file when ``lora_config_path`` is set, otherwise
    ``lora_config_obj.model_dump()``.

    * Cache hit: every expected artifact pattern is present. A :class:`LoraBuilderOutputConfig`
      is rebuilt from ``lora_importer_config.yaml`` and the cached files, and the wrapped
      method is not called.
    * Cache miss: the wrapped method runs, ``<output_dir>/lora_model_creator_output`` is moved
      into the (emptied) cache entry, and every path in the returned config is re-rooted there.
      ``lora_importer_config.yaml`` is then re-serialized.

    Raises:
        ValueError: If neither ``lora_config_path`` nor ``lora_config_obj`` is provided.
    """

    @functools.wraps(func)
    def wrapper(
        self,
        lora_config: LoraBuilderInputConfig,
        output_dir: str | os.PathLike,
    ) -> LoraBuilderOutputConfig:
        def _all_patterns_exist(directory: pathlib.Path, patterns: List[str]) -> bool:
            """Return True if every glob pattern matches at least one entry in ``directory``."""
            return all(any(directory.glob(pattern)) for pattern in patterns)

        cache_dir = getattr(self, "cache_dir")
        if not cache_dir:
            return func(self, lora_config, output_dir)

        if lora_config.lora_config_path:
            # Load YAML file. An empty file parses to ``None``; treat it as an empty mapping
            # rather than crashing with ``AttributeError`` on ``None.get``.
            with open(lora_config.lora_config_path, "r") as f:
                data = yaml.safe_load(f) or {}
            concurrency_names = [uc["name"] for uc in data.get("use-case", [])]
            config_hash = _hash(lora_config.lora_config_path, data)
        else:
            if lora_config.lora_config_obj is None:
                raise ValueError("lora_config_obj is None")
            lora_config_hash_obj = lora_config.lora_config_obj.model_dump()
            concurrency_names = [uc["name"] for uc in lora_config_hash_obj.get("use_cases", [])]
            config_hash = _hash("", lora_config_hash_obj)

        # Build expected file patterns
        # NOTE: The "base" case does not have a safetensors file,
        # hence is excluded from the expected pattern
        expected_patterns = (
            [
                "*.onnx",
                "*.data",
                "*.yaml",
                "*.txt",
            ]
            + [f"{name}.safetensors" for name in concurrency_names if name != "base"]
            + [f"{name}_encodings.json" for name in concurrency_names]
        )

        working_directory = pathlib.Path(cache_dir) / str(config_hash)
        output_directory = working_directory / "lora_model_creator_output"

        # Cache hit: all expected artifacts are already present.
        if output_directory.exists() and _all_patterns_exist(output_directory, expected_patterns):
            _logger.debug(f"Cache Hit: Using cached files from {output_directory}")

            use_case_outputs: List[UseCaseOutputConfig] = load_use_case_config(
                output_directory / "lora_importer_config.yaml"
            )

            base_model_artifacts = {
                "onnx": str(next(output_directory.glob("*.onnx"))),
                "data": str(next(output_directory.glob("*.data"))),
                # NOTE(review): base_encodings.json is only checked by the pattern test above
                # when "base" is one of the configured use cases.
                "encodings": str(output_directory / "base_encodings.json"),
            }

            return LoraBuilderOutputConfig(
                use_case=use_case_outputs,
                lora_tensor_names=str(output_directory / "lora_tensor_names.txt"),
                base_model_artifacts=base_model_artifacts,
            )

        _logger.debug("Cache Miss: Running build_lora_graph")
        lora_output_config = func(self, lora_config, output_dir)

        # Start from an empty cache entry (an earlier, incomplete run may have left files).
        os.makedirs(working_directory, exist_ok=True)
        for existing_entry in working_directory.glob("*"):
            if existing_entry.is_file():
                existing_entry.unlink()
            elif existing_entry.is_dir():
                shutil.rmtree(existing_entry)

        shutil.move(
            pathlib.Path(output_dir) / "lora_model_creator_output",
            working_directory / "lora_model_creator_output",
        )

        # Re-root every path in the returned config at the cached copy.
        # NOTE(review): this assumes ``func`` returns paths relative to the
        # lora_model_creator_output directory. An absolute path is left unchanged by ``/``
        # and would keep pointing at the (now moved) original location -- confirm.
        lora_output_config.lora_tensor_names = str(output_directory / lora_output_config.lora_tensor_names)
        lora_output_config.base_model_artifacts = {
            key: str(output_directory / value)
            for key, value in lora_output_config.base_model_artifacts.items()
        }
        for use_case in lora_output_config.use_case:
            if use_case.model is not None:
                use_case.model = str(output_directory / use_case.model)
            use_case.lora_weights = str(output_directory / use_case.lora_weights)
            if use_case.encodings is not None:
                use_case.encodings = str(output_directory / use_case.encodings)
            if use_case.output_path is not None:
                use_case.output_path = str(output_directory / use_case.output_path)

        # Serialize updated config to YAML so a later cache hit reads the re-rooted paths.
        yaml_path = output_directory / "lora_importer_config.yaml"
        serialize_lora_importer_config(
            lora_uc_output_config=lora_output_config.use_case,
            yaml_path=str(yaml_path),
            base_dir=str(output_directory),
        )

        return lora_output_config

    return wrapper


def cache_transform(func):
    """Decorator adding on-disk caching to ``transform``-style methods.

    When ``self.cache_dir`` is falsy the wrapped method is called unchanged. Otherwise the
    cache key combines:

    * the model path,
    * ``config.model_dump()`` and ``encodings_path``,
    * when given, the LoRA builder output and any extra keyword arguments.

    Each split is stored under ``<cache_dir>/<hash>/ar{N}_cl{CL}_{i}_of_{n}/``.

    A hit is detected by constructing the expected :class:`ExportedFiles` for every AR and
    split. A ``ValueError`` raised while doing so is treated as a miss (presumably
    ``ExportedFiles`` rejects missing files -- TODO confirm).

    On a miss the wrapped method runs and every produced file is moved into the cache. Each
    split's LoRA importer config, when present, is rewritten to reference the moved files.
    The emptied export directory is then removed.
    """

    @functools.wraps(func)
    def wrapper(
        self,
        model_path: str | os.PathLike,
        config: BuilderTransformerConfig,
        encodings_path: str | os.PathLike,
        lora_output_config: Optional[LoraBuilderOutputConfig] = None,
        **kwargs,
    ) -> list[list[ExportedFiles]]:
        cache_dir = getattr(self, "cache_dir")
        if not cache_dir:
            return func(
                self,
                model_path,
                config,
                encodings_path,
                lora_output_config,
                **kwargs,
            )

        combined = {**config.model_dump(), "encodings_path": encodings_path}
        # LoRA inputs and extra keyword arguments change the transformed artifacts, so they
        # must be part of the cache key. Previously a LoRA and a non-LoRA transform of the same
        # model shared one cache entry. They are only added when present, so the keys of
        # plain transforms stay unchanged.
        if lora_output_config is not None:
            combined["lora_output_config"] = {
                "lora_tensor_names": str(lora_output_config.lora_tensor_names),
                "use_case": [
                    {
                        "name": use_case.name,
                        "model": str(use_case.model),
                        "lora_weights": str(use_case.lora_weights),
                        "encodings": str(use_case.encodings),
                    }
                    for use_case in (lora_output_config.use_case or [])
                ],
            }
        if kwargs:
            combined["extra_kwargs"] = dict(kwargs)
        config_hash = _hash(model_path, combined)

        transformer_config = config.model_transformer_config
        arn_cl_config = transformer_config.arn_cl_options
        split_config = transformer_config.split_model

        cl = arn_cl_config.context_length
        arns = arn_cl_config.auto_regression_number
        n = split_config.num_splits

        working_directory = pathlib.Path(cache_dir) / str(config_hash)

        # Cache lookup: describe the expected artifacts of every split for every AR.
        cached_arn_split_files: list[list[ExportedFiles]] = []
        try:
            for arn in arns:
                exported_files_list: list[ExportedFiles] = []
                for split_idx in range(1, n + 1):
                    pattern = f"ar{arn}_cl{cl}_{split_idx}_of_{n}"
                    split_dir = working_directory / pattern
                    exported_files = ExportedFiles(
                        onnx_path=split_dir / f"{pattern}.onnx",
                        data_path=split_dir / f"{pattern}.data",
                    )
                    if encodings_path:
                        exported_files.encodings_path = split_dir / f"{pattern}.encodings"
                    if lora_output_config:
                        # A split-out embedding (first) or LM-head (last) split carries no
                        # LoRA artifacts; every other split does.
                        is_embedding_split = split_idx == 1 and split_config.split_embedding
                        is_lm_head_split = split_idx == n and split_config.split_lm_head
                        if not (is_embedding_split or is_lm_head_split):
                            exported_files.lora_tensor_names = split_dir / "lora_tensor_names.txt"
                            exported_files.lora_importer_config = split_dir / "lora_importer_config.yaml"
                            exported_files.use_cases = [
                                ExportedUseCase(
                                    name=use_case.name,
                                    # Only the file name: the miss path stores each split's files
                                    # flat in ``split_dir``. Joining a (possibly absolute)
                                    # ``lora_weights`` path would escape the cache entry.
                                    safetensors=split_dir / pathlib.Path(use_case.lora_weights).name,
                                    # we store as .encodings not _encodings.json
                                    encodings=split_dir / f"{use_case.name}.encodings",
                                )
                                for use_case in lora_output_config.use_case
                            ]

                    exported_files_list.append(exported_files)
                cached_arn_split_files.append(exported_files_list)
            _logger.debug(f"Cache Hit: continuing with cached artifacts from {working_directory}")
            return cached_arn_split_files
        except ValueError:
            _logger.debug(f"Cache Miss: Files not matching patterns for all ARs")

        split_transformed_files: list[list[ExportedFiles]] = func(
            self,
            model_path,
            config,
            encodings_path,
            lora_output_config,
            **kwargs,
        )

        os.makedirs(working_directory, exist_ok=True)

        # Move the transformed files to the cache, one sub-directory per split.
        for splits in split_transformed_files:
            for exported_files in splits:
                original_dir = exported_files.onnx_path.parent
                target_dir = working_directory / original_dir.name
                os.makedirs(target_dir, exist_ok=True)

                exported_files.onnx_path = _move(exported_files.onnx_path, target_dir)
                exported_files.data_path = _move(exported_files.data_path, target_dir)
                if exported_files.encodings_path:
                    exported_files.encodings_path = _move(exported_files.encodings_path, target_dir)
                if exported_files.lora_tensor_names:
                    exported_files.lora_tensor_names = _move(exported_files.lora_tensor_names, target_dir)

                use_cases = []
                for uc in exported_files.use_cases:
                    uc.safetensors = _move(uc.safetensors, target_dir)
                    if uc.encodings:
                        uc.encodings = _move(uc.encodings, target_dir)
                    use_case = {
                        "name": uc.name,
                        "model_name": str(exported_files.onnx_path),
                        "lora_weights": str(uc.safetensors),
                    }
                    if uc.encodings:
                        use_case["quant_overrides"] = str(uc.encodings)
                    use_cases.append(use_case)

                # Only splits that carry a LoRA importer config get one rewritten. The moved file
                # still holds absolute paths into the temporary export directory, so it is
                # regenerated from the moved locations. Previously this ran unconditionally and
                # wrote a file literally named "None" when there was no importer config.
                if exported_files.lora_importer_config:
                    exported_files.lora_importer_config = _move(
                        exported_files.lora_importer_config, target_dir
                    )
                    with open(str(exported_files.lora_importer_config), "w") as f:
                        f.write(yaml.dump({"use_case": use_cases}))

                os.rmdir(original_dir)

        return split_transformed_files

    return wrapper


def cache_compile(func):
    """Decorator adding on-disk caching to ``compile``-style methods.

    When ``self.cache_dir`` is falsy the wrapped method is called directly. Otherwise the
    cache key is built from each input model's ``module.path`` and ``name`` plus
    ``config.model_dump()``. The compiled binary is stored at
    ``<cache_dir>/<hash>/<name of the (first) model>.bin``.

    A cached binary that fails to load raises :class:`InvalidCacheError`. On a miss the model
    is compiled and saved to that path. The saved copy is then reloaded and returned, so hits
    and misses return objects loaded the same way.
    """

    @functools.wraps(func)
    def wrapper(self, model: Union[Model, List[Model]], config: CompileConfig) -> CompiledModel:
        cache_dir = getattr(self, "cache_dir")
        if not cache_dir:
            return func(self, model, config)

        is_batch = isinstance(model, list)
        if is_batch:
            model_identity = "".join(f"{m.module.path}{m.name}" for m in model)
        else:
            model_identity = f"{model.module.path}{model.name}"

        working_directory = os.path.join(cache_dir, _hash(model_identity, config.model_dump()))
        lead_name = model[0].name if is_batch else model.name
        cached_binary = os.path.join(working_directory, lead_name + ".bin")

        if not os.path.exists(cached_binary):
            _logger.debug(f"Cache Miss: {cached_binary}")
        else:
            try:
                _logger.debug(f"Cache Hit: {cached_binary} ")
                return CompiledModel.load(cached_binary, compile_config=config)
            except Exception as e:
                raise InvalidCacheError(f"Invalid cache compiled: {cached_binary}") from e

        os.makedirs(working_directory, exist_ok=True)
        freshly_compiled = func(self, model, config)
        freshly_compiled.save(cached_binary)
        # Reload from disk so the caller always receives the persisted artifact.
        return CompiledModel.load(cached_binary, compile_config=config)

    return wrapper


def cache_convert(func):
    """Decorator adding on-disk caching to ``convert``-style methods.

    When ``self.cache_dir`` is falsy the wrapped method is called directly, with all of its
    arguments including ``extra_args``. Otherwise the cache key combines:

    * the ONNX path,
    * the converter config and the calibration config (if any),
    * ``extra_args``,
    * the split's LoRA importer config and LoRA tensor-names paths.

    The DLC is stored at ``<cache_dir>/<hash>/<onnx stem>.dlc``. A cached DLC that fails to
    load raises :class:`InvalidCacheError`.

    On a miss the model is converted and saved. Its LoRA use cases (if any) are written to
    ``lora_use_cases.yaml`` next to the DLC. The saved DLC is then reloaded and returned.
    """

    @functools.wraps(func)
    def wrapper(
        self,
        exported_files: ExportedFiles,
        config: ConverterConfig,
        calibration_config: Optional[CalibrationConfig] = None,
        **extra_args,
    ) -> Model:
        cache_dir = getattr(self, "cache_dir")
        if not cache_dir:
            # Forward extra_args here as well. Dropping them meant options such as
            # ``quant_updatable_mode`` were silently ignored whenever caching was disabled.
            return func(self, exported_files, config, calibration_config, **extra_args)

        if calibration_config:
            serialized_configuration = {**config.model_dump(), **calibration_config.model_dump()}
        else:
            serialized_configuration = {**config.model_dump()}

        serialized_configuration.update(extra_args)

        if exported_files.lora_importer_config:
            serialized_configuration["lora_importer_config"] = exported_files.lora_importer_config
        if exported_files.lora_tensor_names:
            serialized_configuration["lora_tensor_names"] = exported_files.lora_tensor_names

        config_hash = _hash(exported_files.onnx_path, serialized_configuration)

        working_directory = pathlib.Path(cache_dir) / config_hash
        model_base_name = pathlib.Path(exported_files.onnx_path).stem
        _dlc = os.path.join(working_directory, f"{model_base_name}.dlc")
        if os.path.exists(_dlc):
            try:
                _logger.debug(f"Cache Hit: {_dlc}")
                return Model.load(_dlc)
            except Exception as e:
                raise InvalidCacheError(f"invalid cached convert: {_dlc}") from e
        else:
            _logger.debug(f"Cache Miss: {_dlc}")

        os.makedirs(pathlib.Path(_dlc).parent, exist_ok=True)
        model = func(self, exported_files, config, calibration_config, **extra_args)
        model.save(_dlc)

        # Serialize LoRA use cases to YAML, with weights/encodings paths resolved against the
        # cache directory.
        if model.lora_use_cases and isinstance(model.lora_use_cases, list):
            yaml_path = working_directory / "lora_use_cases.yaml"
            for use_case in model.lora_use_cases:
                if use_case.lora_weights:
                    use_case.lora_weights = str((working_directory / use_case.lora_weights).resolve())
                if use_case.encodings:
                    use_case.encodings = str((working_directory / use_case.encodings).resolve())

            serialize_lora_adapter_weight_config(model.lora_use_cases, str(yaml_path), str(working_directory))
            _logger.debug(f"Serialized LoRA use cases to {yaml_path}")

        # Reload from disk so hits and misses return equivalently-loaded models.
        return Model.load(_dlc)

    return wrapper


def cache_arcl(func):
    """Decorator adding on-disk caching to AR-n / context-length conversion methods.

    When ``self.cache_dir`` is falsy the wrapped method is called directly. Otherwise the
    cache key is the model path plus ``dataclasses.asdict(config)``. Each AR variant lives at
    ``<cache_dir>/<hash>/ar{N}_cl{CL}.onnx`` / ``.data``.

    A hit is detected by constructing an :class:`ExportedFiles` for every requested AR. A
    ``ValueError`` while doing so is treated as a miss (presumably raised for missing files
    -- TODO confirm).

    On a miss the wrapped method runs, its outputs are moved into the cache, and each emptied
    source directory is removed.
    """

    @functools.wraps(func)
    def wrapper(self, model_path: str | os.PathLike, config: ARn_ContextLengthConfig) -> list[ExportedFiles]:
        cache_dir = getattr(self, "cache_dir")
        if not cache_dir:
            return func(self, model_path, config)

        cache_root = pathlib.Path(cache_dir) / _hash(f"{model_path}", dataclasses.asdict(config))

        try:
            cached = [
                ExportedFiles(
                    onnx_path=cache_root / f"ar{arn}_cl{config.context_length}.onnx",
                    data_path=cache_root / f"ar{arn}_cl{config.context_length}.data",
                )
                for arn in config.auto_regression_number
            ]
        except ValueError:
            _logger.debug(f"Cache Miss: ARn/CL conversion not found")
        else:
            _logger.debug(
                f"Cache Hit: ARn: {config.auto_regression_number} context length: {config.context_length}"
            )
            return cached

        produced: list[ExportedFiles] = func(self, model_path, config)
        os.makedirs(cache_root, exist_ok=True)
        for item in produced:
            scratch_dir = item.onnx_path.parent
            item.onnx_path = _move(item.onnx_path, cache_root)
            item.data_path = _move(item.data_path, cache_root)
            # The scratch directory is expected to be empty once both files are moved.
            os.rmdir(scratch_dir)

        return produced

    return wrapper


class HTPMixin:
    """Mixin providing the HTP-backend build steps used by the GenAIBuilderHTP class.

    The stage methods are wrapped by this module's caching decorators, which persist their
    artifacts under ``cache_dir`` when it is set:

    * :meth:`build_lora_graph`
    * :meth:`ar_cl_conversion`
    * :meth:`transform`
    * :meth:`convert`
    * :meth:`compile`

    The ``set_*`` methods record the options for each stage on the instance.
    """

    def __init__(self, cache_dir: Optional[str | os.PathLike] = None, enable_weight_sharing: bool = True):
        """
        Initializes the HTP Mixin class. This class is utilized by the GenAIBuilderHTP class.

        Args:
            cache_dir (str | PathLike) Optional: A user-provided path that indicates the root
                directory to store artifacts for subsequent invocations. Caching is disabled
                when this is ``None``.
            enable_weight_sharing (bool): Flag to enable weight sharing during model compilation.

        """
        self.cache_dir = cache_dir
        # Root for intermediate exports; ``QAIRT_TMP_DIR`` overrides the system temp dir.
        self.path_root = os.getenv("QAIRT_TMP_DIR", default=tempfile.gettempdir())
        self._calibration_config: Optional[CalibrationConfig] = None
        self._compilation_config: Optional[CompileConfig] = None
        self.enable_weight_sharing = enable_weight_sharing

    @cache_build_lora_graph
    def build_lora_graph(
        self,
        lora_config: LoraBuilderInputConfig,
        output_dir: str | os.PathLike,
    ) -> LoraBuilderOutputConfig:
        """
        Builds a max rank, concatenated LoRA configuration graph by processing the input
        configuration through a mapper and model creator pipeline.

        Args:
            lora_config (LoraBuilderInputConfig): The input LoRA configuration, either as a
                configuration object or a path to a configuration file.
            output_dir (str | os.PathLike): Directory to store final outputs.

        Returns:
            LoraBuilderOutputConfig: The finalized LoRA configuration object.
        """

        return build_lora_graph(lora_config, output_dir=output_dir)

    @cache_arcl
    def ar_cl_conversion(
        self, model_path: str | os.PathLike, arn_cl_config: ARn_ContextLengthConfig
    ) -> list[ExportedFiles]:
        """
        Produce one ONNX model per requested auto-regression number (ARn) at the requested
        context length.

        Args:
            model_path (str | PathLike): Source ONNX model.
            arn_cl_config (ARn_ContextLengthConfig): Target context length and the list of
                auto-regression numbers to generate.

        Returns:
            list[ExportedFiles]: One entry per ARn, in ``auto_regression_number`` order. Each
            entry points at ``ar{N}_cl{CL}.onnx`` / ``.data`` in its own new temporary directory.
        """
        # Load the ONNX model and adjust the context length once; every AR variant copies it.
        onnxmodel = onnx.load(model_path)
        cl, arn = get_tensor_values(onnxmodel)
        if cl != arn_cl_config.context_length:
            set_context_length(onnxmodel, arn_cl_config.context_length, in_place=True)
            _logger.debug(f"Resizing context length Old: {cl} new: {arn_cl_config.context_length}")

        # Loop-invariant scratch root for the per-AR temporary directories.
        tmp_root_dir = os.getenv("QAIRT_TMP_DIR", default=tempfile.gettempdir())
        converted: list[ExportedFiles] = []
        for ar in arn_cl_config.auto_regression_number:
            # Create a separate ONNX model for each AR
            onnxmodel_ar = onnx.ModelProto()
            onnxmodel_ar.CopyFrom(onnxmodel)
            set_ar(onnxmodel_ar, ar, in_place=True)
            _logger.debug(f"Resizing Auto regression number for AR {ar} Old: {arn} new: {ar}")

            # mkdtemp always creates a new, empty directory, so the model is saved unconditionally.
            temp_working_dir = pathlib.Path(tempfile.mkdtemp(prefix="temp_working_dir_", dir=tmp_root_dir))
            proposed_name = f"ar{ar}_cl{arn_cl_config.context_length}"
            arn_model_path = temp_working_dir / f"{proposed_name}.onnx"
            onnx.save(
                onnxmodel_ar,
                arn_model_path,
                save_as_external_data=True,
                location=f"{proposed_name}.data",
            )
            converted.append(
                ExportedFiles(onnx_path=arn_model_path, data_path=temp_working_dir / f"{proposed_name}.data")
            )
        return converted

    @cache_transform
    def transform(
        self,
        model_path: str | os.PathLike,
        config: BuilderTransformerConfig,
        encodings_path: Optional[str | os.PathLike] = None,
        lora_output_config: Optional[LoraBuilderOutputConfig] = None,
    ) -> list[list[ExportedFiles]]:
        """
        Performs quantization transformations to the model.

        Args:
            model_path (str | PathLike): The path to the model file to transform.
            config (BuilderTransformerConfig): The configuration for the transformation process.
            encodings_path (str | PathLike): The path to the encodings file.
            lora_output_config (Optional[LoraBuilderOutputConfig]): Output of
                :meth:`build_lora_graph`. When it has use cases, its LoRA tensor names and
                adapter descriptions are passed to the transformation.

        Returns:
            list[list[ExportedFiles]]: One inner list per auto-regression number, in
            ``arn_cl_options.auto_regression_number`` order. Each inner list holds the
            exported files of every split of that AR variant.

        .. note::

            if config.model_transformer_config.split_model.num_splits > 1, this will return
            a corresponding number of subdivided models
        """
        arn_cl_config = config.model_transformer_config.arn_cl_options
        transformed_splits: list[list[ExportedFiles]] = []
        arn_cl_converted_files = self.ar_cl_conversion(model_path, arn_cl_config)
        for exported_files in arn_cl_converted_files:
            # Built fresh for every AR so the transformation never sees dicts it may have
            # mutated on a previous iteration.
            kwargs: dict[str, Any] = {
                "split_model": config.model_transformer_config.split_model,
                "mha_config": config.model_transformer_config.mha_config,
            }

            if lora_output_config and lora_output_config.use_case:
                kwargs["lora_tensor_names_path"] = lora_output_config.lora_tensor_names
                lap = {
                    "use_case": [
                        {
                            "name": use_case.name,
                            "model_name": use_case.model,
                            "lora_weights": use_case.lora_weights,
                            "quant_overrides": use_case.encodings if use_case.encodings else None,
                        }
                        for use_case in lora_output_config.use_case
                    ]
                }
                kwargs["lora_adapters_path"] = lap

            splits = transform(
                model=exported_files.onnx_path,
                backend=config.backend,
                quantization_stage=QuantizationStage.POST_QUANT,
                encodings=encodings_path,
                **kwargs,
            )
            proposed_name = exported_files.onnx_path.stem
            exported_splits: list[ExportedFiles] = [
                split.export(
                    pathlib.Path(self.path_root) / f"{proposed_name}_{idx + 1}_of_{len(splits)}",
                    f"{proposed_name}_{idx + 1}_of_{len(splits)}",
                )
                for idx, split in enumerate(splits)
            ]
            transformed_splits.append(exported_splits)

            assert len(exported_splits) == config.model_transformer_config.split_model.num_splits
        assert len(transformed_splits) == len(
            config.model_transformer_config.arn_cl_options.auto_regression_number
        )

        return transformed_splits

    def set_transformation_options(
        self,
        config: ModelTransformerConfig,
    ):
        """
        Provide instructions for applying transformations to the source framework model. This method updates
        the transformation configuration for the model, which will be used when transforming the model
        for execution on the target device.

        Args:
            config (ModelTransformerConfig): The transformation configuration options.

        """
        self._transformation_config = BuilderTransformerConfig(
            model_transformer_config=config, backend=BackendType.HTP
        )

    @cache_convert
    def convert(
        self,
        exported_files: ExportedFiles,
        config: ConverterConfig,
        calibration_config: Optional[CalibrationConfig] = None,
        **extra_args,
    ) -> Model:
        """
        Converts an ONNX model to a QAIRT model.

        If ``exported_files.encodings_path`` is set but missing on disk, a warning is logged.
        If exactly one ``*.encodings`` file sits next to the ONNX model, it is used instead;
        otherwise conversion proceeds without encodings.

        Args:
            exported_files (ExportedFiles): ONNX/encodings paths. For LoRA splits it also holds
                the LoRA importer config and LoRA tensor-names file passed to the converter.
            config (ConverterConfig): Configuration for the conversion process.
            calibration_config (Optional[CalibrationConfig]): Configuration for the calibration process.
                Defaults to None.
            **extra_args: Extra keyword arguments for conversion options. Only
                ``quant_updatable_mode`` (mode for quant-updatable tensors) is read here and
                forwarded to the converter. Any other key is ignored by this method.

        Returns:
            Model: The converted QAIRT model.
        """
        onnx_file_path = exported_files.onnx_path
        split_encodings = exported_files.encodings_path
        lora_importer_config = exported_files.lora_importer_config
        lora_tensor_names = exported_files.lora_tensor_names

        if split_encodings and not split_encodings.exists():
            _logger.warning(f"Could not find encodings file {split_encodings}")
            matches = list(pathlib.Path(onnx_file_path).parent.glob("*.encodings"))
            if len(matches) == 1:
                _logger.warning(f"Expected {split_encodings} but proceeding with {matches[0]}")
                split_encodings = matches[0]
            else:
                split_encodings = None

        convert_args = config.model_dump()
        convert_args["lora_tensor_names"] = lora_tensor_names
        convert_args["lora_importer_config"] = lora_importer_config

        # Extract quant_updatable_mode from extra_args if present
        quant_updatable_mode = extra_args.get("quant_updatable_mode")
        if quant_updatable_mode is not None:
            convert_args["quant_updatable_mode"] = quant_updatable_mode

        model: Model = convert(
            onnx_file_path,
            split_encodings,
            calibration_config,
            **convert_args,
        )

        return model

    def set_conversion_options(
        self, config: ConverterConfig, calibration_config: Optional[CalibrationConfig] = None
    ):
        """
        Provide instructions for converting the framework model into the internal
        representation needed to run on Qualcomm devices.

        This method updates the conversion configuration for the model, which will be
        used when converting the model for execution on the target device. Note that
        ``config`` is stored by reference and its empty input/output tensor configs are
        replaced with ``[]`` in place.

        Args:
            config (ConverterConfig): A configuration object containing conversion options.
            calibration_config (Optional[CalibrationConfig]): Configuration for the calibration process.

        """
        self._conversion_config = config
        self._calibration_config = calibration_config
        if not self._conversion_config.input_tensor_config:
            self._conversion_config.input_tensor_config = []
        if not self._conversion_config.output_tensor_config:
            self._conversion_config.output_tensor_config = []

    @cache_compile
    def compile(self, model: Union[Model, List[Model]], config: CompileConfig) -> CompiledModel:
        """
        Compile a converted QAIRT model into a compiled model.

        Args:
            model (Union[Model, List[Model]]): The converted QAIRT model(s) to compile.
            config (CompileConfig): The compilation configuration.

        Returns:
            CompiledModel: The compiled model.
        """
        return compile(model, config=config)

    def set_compilation_options(self, config: CompileConfig):
        """
        Sets the compilation options for the model.

        This method updates the compilation configuration for the model, which will be used when compiling
        the model for execution on the target device. When weight sharing is enabled, the
        given config's mode is switched to ``WEIGHT_SHARING`` in place.

        Args:
            config (CompileConfig): The compilation configuration options.

        """
        if self._compilation_config:
            _logger.warning("Overriding existing compilation config (maybe set via set_targets)")
        self._compilation_config = config

        if self.enable_weight_sharing:
            _logger.debug("Setting compilation mode to WEIGHT_SHARING.")
            self._compilation_config.set_mode(CompilerModes.WEIGHT_SHARING.value)
        else:
            _logger.debug("Weight sharing is disabled.")

    # TODO: Support multiple chipsets, which would result in building a multiple context binaries
    # and let the resulting container build executors for the appropriate target.
    def set_targets(self, soc_details: list[Union[SocDetails, str]]):
        """
        Creates a compile config for the target chipset(s)
        This method sets the compilation configuration for the target chipset with the corresponding
        SoC details, graph custom configurations, and device custom configurations. Every HTP
        core's perf profile is set to ``BURST``.

        Args:
            soc_details (list[SocDetails | str]): Exactly one device specification to use for
                compilation. Each entry may be a ``SocDetails`` or a spec string in the form
                "chipset:value;dsp_arch:value;soc_model:value|...".

        Raises:
            NotImplementedError: If there is not exactly one chipset provided.
            ValueError: if a valid HtpDeviceConfig can not be created from soc_details

        """
        if len(soc_details) != 1:
            raise NotImplementedError("Current version supports targeting only one chipset.")
        config = CompileConfig(backend=BackendType.HTP, soc_details=soc_details[0], log_level="warn")
        if config.soc_details == "chipset:UNKNOWN":
            _logger.warning(
                "The chipset of the specified device was not recognized.  Please visit "
                "https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/overview.html#supported-snapdragon-devices "
                "to identify the device and specify the target chipset."
            )
        config.graph_custom_configs = [HtpGraphConfig(name="placeholder", optimization_type=3)]
        if (
            config.device_custom_configs
            and isinstance(config.device_custom_configs[0], HtpDeviceConfig)
            and config.device_custom_configs[0].cores
        ):
            for core in config.device_custom_configs[0].cores:
                core.perf_profile = PerfProfile.BURST
        else:
            raise ValueError(f"Unable to create valid HtpDeviceConfig for SocDetails: {soc_details[0]}")

        self.set_compilation_options(config)
