# =============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All Rights Reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================

import json
from pathlib import Path
from shutil import copyfile
import glob
import os

from qti.aisw.core.model_level_api.backend.backend import Backend
from qti.aisw.core.model_level_api.target.x86 import X86Target
from qti.aisw.tools.core.utilities.devices.api.device_definitions import DevicePlatformType


class HtpMcpBackend(Backend):
    """Backend wrapper for the QNN HTP MCP backend (``libQnnHtpMcp.so``).

    Defaults to an ``X86Target`` when no target is supplied. Configuration can
    be loaded from a JSON file and/or supplied (and merged) as a dict.
    """

    def __init__(self, target=None, config_file=None, config_dict=None):
        """Create the backend.

        Args:
            target: Target to run on; an ``X86Target`` is used when ``None``.
            config_file: Optional path to a JSON config file. When given, its
                parsed contents replace ``self._config``.
            config_dict: Optional dict whose entries are merged into the
                config after ``config_file`` (if any) has been applied.
        """
        super().__init__(target)
        # Paths (str) of '*schematic.bin' files found by after_generate_hook;
        # stays None until that hook runs on an x86_64 Linux target.
        self._schematic_bins = None
        if target is None:
            self.target = X86Target()

        if config_file:
            # JSON text is UTF-8 by specification; don't depend on the locale.
            with open(config_file, 'r', encoding='utf-8') as f:
                self._config = json.load(f)

        if config_dict:
            # The base class may not have initialized _config (or may have set
            # it to None); start from an empty dict so update() cannot fail.
            if getattr(self, '_config', None) is None:
                self._config = {}
            self._config.update(config_dict)

    @property
    def backend_library(self):
        """Name of the QNN backend shared library."""
        return 'libQnnHtpMcp.so'

    @property
    def backend_extensions_library(self):
        """Name of the net-run extensions shared library for this backend."""
        return 'libQnnHtpMcpNetRunExtensions.so'

    def get_required_artifacts(self, sdk_path):
        """Return extra runtime artifacts needed by this backend (none)."""
        return []

    def get_profiling_artifacts(self):
        """Return absolute paths of collected schematic bins.

        Returns:
            A list of absolute ``Path`` objects, or ``None`` when no schematic
            bins have been collected (never collected, cleared, or none found).
        """
        artifacts = None
        if self._schematic_bins:
            artifacts = [Path(schematic_bin).absolute() for schematic_bin in self._schematic_bins]
        return artifacts

    def clear_profiling_artifacts(self):
        """Forget any schematic bins collected by ``after_generate_hook``."""
        self._schematic_bins = None

    # Backward-compatible alias for the historical misspelled method name.
    clear_profiling_artiacts = clear_profiling_artifacts

    def before_generate_hook(self, temp_directory, sdk_path):
        """Copy the HTP MCP v68 ELF into ``temp_directory`` as ``network.elf``.

        Raises:
            FileNotFoundError: If the ELF is missing from the SDK's
                ``lib/hexagon-v68/unsigned`` directory.
        """
        mcp_elf_path = Path(sdk_path, 'lib', 'hexagon-v68', 'unsigned', 'libQnnHtpMcpV68.elf')
        if not mcp_elf_path.exists():
            raise FileNotFoundError(f"Could not find HTP MCP elf file {mcp_elf_path}")

        copyfile(mcp_elf_path, Path(temp_directory, 'network.elf'))

    def after_generate_hook(self, temp_directory, sdk_path):
        """Collect '*schematic.bin' files produced in ``temp_directory``.

        Only done when the target platform is x86_64 Linux; otherwise the
        previously collected value is left unchanged.
        """
        if self.target.target_platform_type == DevicePlatformType.X86_64_LINUX:
            # Escape the directory so glob metacharacters in it ('[', '*', '?')
            # are matched literally; sort for a deterministic result order.
            pattern = os.path.join(glob.escape(str(temp_directory)), '*schematic.bin')
            self._schematic_bins = sorted(glob.glob(pattern))
