# =============================================================================
#
#  Copyright (c) 2024 Qualcomm Technologies, Inc.
#  All Rights Reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================

from __future__ import annotations
from os import PathLike
from pathlib import Path
from typing import List, Union, Dict, Any, Optional, TYPE_CHECKING

import qti.aisw.core.model_level_api as mlapi
from qti.aisw.tools.core.modules.api.definitions.common import BackendType, Target, Model
from qti.aisw.tools.core.utilities.devices.api.device_definitions import DevicePlatformType

if TYPE_CHECKING:
    from qti.aisw.tools.core.modules.net_runner.net_runner_module import InferenceConfig
    from qti.aisw.tools.core.modules.context_bin_gen.context_bin_gen_module import GenerateConfig

# Maps each module-API BackendType to the model-level API backend class that
# implements it. Only backends listed here are accepted by create_mlapi_backend
# and reported by get_supported_backends (in this insertion order).
_backend_type_to_mlapi_backend: Dict[BackendType, type] = {
    BackendType.CPU: mlapi.CpuBackend,
    BackendType.GPU: mlapi.GpuBackend,
    BackendType.HTP: mlapi.HtpBackend,
    BackendType.HTP_MCP: mlapi.HtpMcpBackend,
    BackendType.AIC: mlapi.AicBackend,
    BackendType.LPAI: mlapi.LpaiBackend,
}


def get_supported_backends() -> List[str]:
    """Return the backend types that have a model-level API backend mapping.

    The result is a fresh list of the keys of ``_backend_type_to_mlapi_backend``,
    in the mapping's insertion order.

    NOTE(review): the return is annotated ``List[str]`` but the elements are
    ``BackendType`` members; presumably ``BackendType`` is a str-based enum — confirm.
    """
    supported = _backend_type_to_mlapi_backend.keys()
    return list(supported)


def create_mlapi_target(target: Target) -> mlapi.Target:
    """Translate a module-API ``Target`` into the equivalent model-level API target.

    Args:
        target: Target description. ``target.type`` selects the mlapi target class.
            For device-backed types (Android and embedded Linux), the optional
            ``target.identifier`` supplies the device serial id, hostname and port.
            Any of these is passed as ``None`` when no identifier is given.

    Returns:
        An ``mlapi.AndroidTarget``, ``mlapi.X86Target`` or ``mlapi.OELinuxTarget``.

    Raises:
        ValueError: If ``target.type`` is not a supported platform type.
    """
    if target.type == DevicePlatformType.X86_64_LINUX:
        # Host target: no device connection parameters are needed.
        return mlapi.X86Target()

    if target.type == DevicePlatformType.ANDROID:
        device_target_cls = mlapi.AndroidTarget
    elif target.type == DevicePlatformType.LINUX_EMBEDDED:
        device_target_cls = mlapi.OELinuxTarget
    else:
        raise ValueError(f'Unknown target type: {target.type}')

    # Both device-backed targets take the same connection parameters, so they
    # are extracted once here instead of being duplicated per branch.
    identifier = target.identifier
    return device_target_cls(
        device_id=identifier.serial_id if identifier else None,
        hostname=identifier.hostname if identifier else None,
        adb_server_port=identifier.port if identifier else None,
    )


def create_mlapi_backend(backend: BackendType,
                         target: Optional[Target],
                         config_file: Optional[Union[str, PathLike]] = None,
                         config_dict: Optional[Dict[str, Any]] = None) -> \
        mlapi.Backend:
    """Instantiate the model-level API backend that corresponds to ``backend``.

    Args:
        backend: Backend type to instantiate. It must be a key of
            ``_backend_type_to_mlapi_backend``.
        target: Optional target. When truthy, it is converted with
            ``create_mlapi_target``. Otherwise the backend receives ``target=None``.
        config_file: Backend config file forwarded to the backend constructor.
            It is not passed to the CPU backend.
        config_dict: Backend config dict, forwarded like ``config_file``.
            It is not passed to the CPU backend.

    Returns:
        An instance of the mapped mlapi backend class.

    Raises:
        ValueError: If ``backend`` has no mapping. Errors from
            ``create_mlapi_target`` propagate and are raised before this check.
    """
    mlapi_target = None
    if target:
        mlapi_target = create_mlapi_target(target)

    backend_cls = _backend_type_to_mlapi_backend.get(backend)
    if not backend_cls:
        raise ValueError(f'Unknown backend type: {backend}')

    constructor_kwargs = {'target': mlapi_target}
    # CPU does not support backend specific configs, so the config file/dict
    # are only forwarded to the other backends.
    if backend_cls is not mlapi.CpuBackend:
        constructor_kwargs['config_file'] = config_file
        constructor_kwargs['config_dict'] = config_dict
    return backend_cls(**constructor_kwargs)


def create_mlapi_model(model: Model) -> mlapi.Model:
    """Wrap the first populated path of ``model`` in the matching mlapi model type.

    Precedence is: QNN model library, then context binary, then DLC. The
    resulting model is named after the file's stem.

    Raises:
        ValueError: If none of the three paths is set.
    """
    def _build(model_cls, raw_path):
        # Normalise through str() so any path-like/field type becomes a Path.
        resolved = Path(str(raw_path))
        return model_cls(name=resolved.stem, path=str(resolved))

    if model.qnn_model_library_path:
        return _build(mlapi.QnnModelLibrary, model.qnn_model_library_path)
    if model.context_binary_path:
        return _build(mlapi.QnnContextBinary, model.context_binary_path)
    if model.dlc_path:
        return _build(mlapi.DLC, model.dlc_path)
    raise ValueError(f'Unknown model type {model}')


def create_mlapi_run_config(config: 'InferenceConfig') -> mlapi.QNNRunConfig:
    """Build an ``mlapi.QNNRunConfig`` from a net-runner ``InferenceConfig``.

    Every field returned by ``config.dict()`` is forwarded as a keyword
    argument, except ``op_packages``, which is dropped.
    """
    # NOTE(review): op_packages is presumably handled outside QNNRunConfig — confirm.
    run_kwargs = {field: value for field, value in config.dict().items()
                  if field != "op_packages"}
    return mlapi.QNNRunConfig(**run_kwargs)


def create_mlapi_generate_config(config: 'GenerateConfig') -> mlapi.QNNGenerateConfig:
    """Build an ``mlapi.QNNGenerateConfig`` from a context-binary ``GenerateConfig``.

    Every field returned by ``config.dict()`` is forwarded as a keyword
    argument, except ``op_packages``, which is removed when present.
    """
    field_values = config.dict()
    # NOTE(review): op_packages is presumably handled outside QNNGenerateConfig — confirm.
    if "op_packages" in field_values:
        del field_values["op_packages"]
    return mlapi.QNNGenerateConfig(**field_values)
