# =============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All Rights Reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================

import json
import logging
from pathlib import Path

from qti.aisw.core.model_level_api.backend.backend import Backend
from qti.aisw.core.model_level_api.workflow.workflow import WorkflowMode
from qti.aisw.core.model_level_api.target.android import AndroidTarget
from qti.aisw.core.model_level_api.target.x86 import X86Target
from qti.aisw.tools.core.utilities.devices.api.device_definitions import DevicePlatformType

logger = logging.getLogger(__name__)


class LpaiBackend(Backend):
    """Backend wrapper for the QNN LPAI backend.

    Names the LPAI backend and net-run extension shared libraries. Picks a
    default target from the workflow mode. Lists the extra stub/skel
    libraries needed when the target platform is Android.
    """

    def __init__(self, target=None, config_file=None, config_dict=None, **kwargs):
        """Create an LPAI backend.

        Args:
            target: Target to run on; forwarded to ``Backend.__init__``.
            config_file: Optional path to a JSON backend config. When given,
                the parsed JSON replaces ``self._config``.
            config_dict: Optional mapping merged into ``self._config`` after
                the file (if any) is loaded, so its keys take precedence.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        super().__init__(target)
        if config_file:
            # JSON text is UTF-8 (RFC 8259); do not depend on the platform's
            # locale encoding when reading the config file.
            with open(config_file, 'r', encoding='utf-8') as f:
                self._config = json.load(f)

        if config_dict:
            # Without a config_file nothing in this class creates _config, so
            # fall back to an empty mapping if the base class left it
            # unset/None (NOTE(review): confirm what Backend initialises).
            if getattr(self, '_config', None) is None:
                self._config = {}
            self._config.update(config_dict)

    def _workflow_mode_setter_hook(self, mode):
        """Choose a default target for ``mode``.

        Only acts when ``self._default_target`` is truthy. This is presumably
        set by ``Backend`` when no explicit target was supplied; confirm
        against the base class. Inference uses an ``AndroidTarget``.
        Context-binary generation uses an ``X86Target``.

        Args:
            mode: The ``WorkflowMode`` being set.

        Raises:
            ValueError: If a default target is in use and ``mode`` is neither
                ``WorkflowMode.INFERENCE`` nor
                ``WorkflowMode.CONTEXT_BINARY_GENERATION``.
        """
        if self._default_target:
            if mode == WorkflowMode.INFERENCE:
                self.target = AndroidTarget()
            elif mode == WorkflowMode.CONTEXT_BINARY_GENERATION:
                self.target = X86Target()
            else:
                raise ValueError(f'Invalid workflow_mode: {mode}')

    @property
    def backend_library(self):
        """File name of the LPAI backend shared library."""
        return 'libQnnLpai.so'

    @property
    def backend_extensions_library(self):
        """File name of the LPAI net-run extensions shared library."""
        return "libQnnLpaiNetRunExtensions.so"

    def get_required_artifacts(self, sdk_root):
        """Return paths of extra libraries required by the current target.

        Args:
            sdk_root: Root directory of the SDK containing the ``lib`` tree.

        Returns:
            A list of path strings. When the target platform is Android, the
            list holds the aarch64-android LPAI V79 stub and the unsigned
            hexagon-v79 LPAI skel. Otherwise it is empty.
        """
        if self.target.target_platform_type != DevicePlatformType.ANDROID:
            return []

        lpai_stub = Path(sdk_root, 'lib', 'aarch64-android', 'libQnnLpaiV79Stub.so')
        lpai_skel = Path(sdk_root, 'lib', 'hexagon-v79', 'unsigned', 'libQnnLpaiV79Skel_v5.so')
        return [str(lpai_stub), str(lpai_skel)]
