#!/usr/bin/env python3
# =============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================
import argparse
from dataclasses import dataclass
import logging
import os
from pathlib import Path
import platform
import random
import shlex
import shutil
import subprocess as sp
import sys
import tarfile
from typing import Dict, List, Optional

# Seed the module-level RNG with a fixed value; the temporary build-directory
# suffix drawn in ModelLibGenerator.__enter__ therefore follows the same sequence on every run.
random.seed(23)
# Module-wide logger; stays None until _configure_logger() assigns it.
logger = None
# Default destination for generated libraries: "lib" under the working directory at import time.
DEFAULT_OUTPUT_DIR = Path.cwd() / 'lib'
# Prefer clang++-9 when it is found on PATH, otherwise fall back to plain clang++.
CLANG = 'clang++-9' if shutil.which('clang++-9') is not None else 'clang++'
# The SDK root is taken to be the third ancestor directory of this (resolved) script path.
QNN_SDK_ROOT = Path(__file__).resolve().parent.parent.parent.resolve()
# Directory from which the targets copy converter sources and build files.
CONVERTER_SHARE_DIR = QNN_SDK_ROOT / 'share' / 'QNN' / 'converter'


def is_windows(system_name: Optional[str] = None) -> bool:
    """Return True when the given (or current) platform is Windows-like.

    Args:
        system_name: Platform name as reported by ``platform.system()``.
            When omitted, the running interpreter is inspected instead.

    Returns:
        ``True`` for native Windows and for POSIX layers on top of it whose
        system name carries an ``_NT`` marker (e.g. ``CYGWIN_NT-10.0``,
        ``MSYS_NT-10.0``); ``False`` otherwise. Always a real ``bool``.
    """
    if system_name is None:
        # sys.winver only exists on Windows builds of CPython.
        if hasattr(sys, 'winver'):
            return True
        system_name = platform.system()
    system = system_name.lower()
    # A plain substring test for 'win' would also match 'darwin' (macOS),
    # so anchor the match at the start and look for the '_nt' marker separately.
    return system.startswith('win') or '_nt' in system


def exec_subprocess_cmd(cmd: List[str]) -> sp.CompletedProcess:
    """Run ``cmd`` through the shell and log its outcome.

    QNN_SDK_ROOT is exported into this process's environment (an existing
    value is kept) so the child inherits it. The command, return code, stdout
    and stderr are logged at a level that depends on the outcome:
    ERROR for a non-zero return code, INFO when the command succeeded but
    wrote to stderr, DEBUG otherwise.

    NOTE(review): ``shell=True`` is combined with a list argument; per the
    subprocess documentation, on POSIX only the first element is then used as
    the command line and the rest become arguments to the shell itself.

    Args:
        cmd: Command passed unchanged to ``subprocess.run``.

    Returns:
        The ``CompletedProcess`` with stdout/stderr captured as UTF-8 text.
        A failing command is only logged; no exception is raised here.
    """
    os.environ.setdefault('QNN_SDK_ROOT', str(QNN_SDK_ROOT))
    result = sp.run(cmd, shell=True, stdout=sp.PIPE, stderr=sp.PIPE, encoding='utf-8')

    if result.returncode != 0:
        # Genuine failure of the subprocess command.
        emit = logger.error
    elif result.stderr:
        # Success, but something (typically warnings) was written to stderr.
        emit = logger.info
    else:
        # Clean run: no failure and nothing on stderr.
        emit = logger.debug

    emit(f'command : {result.args}')
    emit(f'rc : {result.returncode}')
    emit(f'stdout : {result.stdout}')
    emit(f'stderr : {result.stderr}')
    return result


@dataclass
class ModelLibGenerationConfig:
    """Settings for one model-library generation run.

    Every field has a default, so an instance can be created empty and filled
    in afterwards, which is how get_configuration_from_command_line uses it.
    """
    # Path to the model .cpp source; ModelLibGenerator._normalize_config rejects a missing value.
    model_cpp: Optional[Path] = None
    # Optional path to the model .bin file that accompanies the source.
    model_bin: Optional[Path] = None
    # Target aliases to build. The command-line path supplies a list of lists
    # (one inner list per -t occurrence); ModelLibGenerator._normalize_config
    # flattens and de-duplicates it into a list of aliases.
    targets: Optional[List[str]] = None
    # Output library name; the targets append platform prefixes/suffixes when building.
    lib_name: str = ''
    # Directory receiving the built libraries; _normalize_config falls back to DEFAULT_OUTPUT_DIR.
    output_dir: Optional[Path] = None
    # Whether the temporary build directory is removed when the generator exits.
    clean_up: bool = True


@dataclass
class WindowsModelLibGenerationConfig(ModelLibGenerationConfig):
    """Generation settings extended with the options read only by WindowsTarget."""
    # Executable used to convert extracted .raw files into object files.
    obj_generator_path: Optional[Path] = None
    # Build the 'Debug' CMake configuration instead of 'Release'.
    debug_build: bool = False


class Target:
    """Base class for a platform a model library can be built for.

    A target carries a ``name`` and an ``alias`` and exposes three hooks:
    ``is_available``, ``_prepare_build_directory`` and ``build_target``.
    The two build hooks are empty here and are meant to be overridden.
    """

    # Converter sources (relative to CONVERTER_SHARE_DIR) checked by is_available.
    converter_sources: List[Path] = [
        Path('jni', source_name)
        for source_name in (
            'QnnModel.cpp',
            'QnnModel.hpp',
            'QnnModelPal.hpp',
            'QnnTypeMacros.hpp',
            'QnnWrapperUtils.cpp',
            'QnnWrapperUtils.hpp',
        )
    ]

    def __init__(self, name: str = '', alias: str = ''):
        """Store the target's name and its alias."""
        self._alias = alias
        self._name = name

    @property
    def name(self) -> str:
        """Name given at construction (read-only)."""
        return self._name

    @property
    def alias(self) -> str:
        """Alias given at construction (read-only)."""
        return self._alias

    def is_available(self) -> bool:
        """Return True when every common converter source exists in CONVERTER_SHARE_DIR."""
        for source in Target.converter_sources:
            if not (CONVERTER_SHARE_DIR / source).is_file():
                return False
        return True

    def _prepare_build_directory(self, build_dir: Path, config: ModelLibGenerationConfig):
        """Hook for populating ``build_dir``; does nothing in the base class."""
        ...

    def build_target(self, build_dir: Path, config: ModelLibGenerationConfig) -> Path:
        """Hook for building the library; the base implementation returns None."""
        ...


class NixTarget(Target):
    """Common behaviour for targets that are built from a Linux host through a POSIX shell.

    Subclasses provide the actual build command (``_build_cmd``) and the
    directory, relative to the build directory, where the build drops the
    library (``_build_lib_dir``).
    """

    def __init__(self, name: str, alias: str):
        super().__init__(name, alias)

    def is_available(self) -> bool:
        """Available only on a Linux host that also has the common converter sources."""
        return platform.system() == 'Linux' and super().is_available()

    def _prepare_build_directory(self, build_dir: Path, config: ModelLibGenerationConfig):
        """Copy converter sources, the Linux PAL source and the model files into ``build_dir/jni``."""
        (build_dir / 'jni').mkdir(parents=True, exist_ok=True)
        for file in Target.converter_sources:
            shutil.copy(CONVERTER_SHARE_DIR / file, build_dir / file)
        shutil.copy(CONVERTER_SHARE_DIR / 'jni' / 'linux' / 'QnnModelPal.cpp', build_dir / 'jni' / 'QnnModelPal.cpp')

        shutil.copy(config.model_cpp, build_dir / 'jni')
        if config.model_bin:
            shutil.copy(config.model_bin, build_dir / 'jni')

    def _build_cmd(self) -> str:
        """Shell command that performs the build; empty here, overridden by subclasses."""
        return ''

    def _build_lib_dir(self) -> Path:
        """Directory (relative to the build directory) holding the built library; set by subclasses."""
        ...

    def build_target(self, build_dir: Path, config: ModelLibGenerationConfig) -> Path:
        """Build the shared library and move it to ``<output_dir>/<alias>/``.

        ``config.lib_name`` is normalised in place to the ``lib<name>.so`` form.

        Args:
            build_dir: Scratch directory in which the build runs.
            config: Generation settings (model files, library name, output directory).

        Returns:
            Path of the library inside the output directory.

        Raises:
            RuntimeError: If the expected library file is missing after the build.
        """
        self._prepare_build_directory(build_dir, config)

        if not config.lib_name.endswith('.so'):
            config.lib_name += '.so'
        if not config.lib_name.startswith('lib'):
            config.lib_name = 'lib' + config.lib_name

        # Quote the user-influenced pieces so a build directory or library name
        # containing spaces or shell metacharacters cannot break the command line.
        shell_cmd = (
            f'cd {shlex.quote(str(build_dir))}'
            f' && export QNN_MODEL_LIB_NAME={shlex.quote(config.lib_name)}'
            f' && {self._build_cmd()}'
        )
        exec_subprocess_cmd([shell_cmd])

        # Success is judged by the presence of the library, not by the return code.
        target_lib_path = build_dir / self._build_lib_dir() / config.lib_name
        if not target_lib_path.is_file():
            raise RuntimeError(f'Failed to compile model library for target: {self.alias}')

        output_dir = config.output_dir / self.alias
        output_dir.mkdir(parents=True, exist_ok=True)
        output_lib_path = output_dir / config.lib_name
        # Replace a library left over from a previous run.
        if output_lib_path.is_file():
            output_lib_path.unlink()
        shutil.move(target_lib_path, output_lib_path)
        return output_lib_path


class LinuxTarget(NixTarget):
    """Target built with ``make`` using ``Makefile.<name>`` and a configurable C++ compiler.

    ``compiler`` and ``sys_root`` may contain environment-variable references;
    they are expanded with ``os.path.expandvars`` when the build command is assembled.
    """

    def __init__(self, name: str, alias: str, compiler: str, sys_root: str = None):
        super().__init__(name, alias)
        self._compiler = compiler
        self._sys_root = sys_root
        self._makefile = Path('Makefile.' + self._name)

    def is_available(self) -> bool:
        """Require the base checks plus this target's makefile in CONVERTER_SHARE_DIR."""
        if not super().is_available():
            return False
        return (CONVERTER_SHARE_DIR / self._makefile).is_file()

    def _prepare_build_directory(self, build_dir: Path, config: ModelLibGenerationConfig):
        """Populate the build directory and add this target's makefile to it."""
        super()._prepare_build_directory(build_dir, config)
        shutil.copy(CONVERTER_SHARE_DIR / self._makefile, build_dir / self._makefile)

    def _sys_root_arg(self) -> str:
        """Return the `` --sysroot=...`` compiler argument, or '' when no sys root is set."""
        if not self._sys_root:
            return ''
        return f' --sysroot={os.path.expandvars(self._sys_root)}'

    def _build_cmd(self) -> str:
        """Return the ``make`` invocation after verifying the build tools exist.

        Raises:
            RuntimeError: If the compiler cannot be found (neither as a path nor
                on PATH) or a configured sys root is not a directory.
        """
        compiler = Path(os.path.expandvars(self._compiler))

        # Collect every problem first so a single error reports all of them.
        problems = []
        if not compiler.exists() and not shutil.which(str(compiler)):
            problems.append(f"\nCould not find compiler: {self._compiler}")
        if self._sys_root and not Path(os.path.expandvars(self._sys_root)).is_dir():
            problems.append(f"\nCould not find sys root directory: {self._sys_root}")
        if problems:
            raise RuntimeError(f'Could not build target: {self.alias}' + ''.join(problems))

        return f'make CXX="{compiler}{self._sys_root_arg()}" -f {self._makefile}'

    def _build_lib_dir(self) -> Path:
        """The library is looked up under ``libs/<alias>``."""
        return Path('libs', self.alias)


class Aarch64AndroidTarget(NixTarget):
    """Android arm64-v8a target (alias ``aarch64-android``) built with ``ndk-build``."""

    # ndk-build project files required in addition to the common converter sources.
    converter_sources: List[Path] = [Path('jni', mk_name) for mk_name in ('Android.mk', 'Application.mk')]

    def __init__(self):
        super().__init__(name='arm64-v8a', alias='aarch64-android')

    def is_available(self) -> bool:
        """Require the base checks plus both ndk-build project files."""
        if not super().is_available():
            return False
        return all((CONVERTER_SHARE_DIR / mk_file).is_file()
                   for mk_file in Aarch64AndroidTarget.converter_sources)

    def _prepare_build_directory(self, build_dir: Path, config: ModelLibGenerationConfig):
        """Populate the build directory and copy the ndk-build project files into it."""
        super()._prepare_build_directory(build_dir, config)
        for mk_file in Aarch64AndroidTarget.converter_sources:
            shutil.copy(CONVERTER_SHARE_DIR / mk_file, build_dir / mk_file)

    def _build_cmd(self) -> str:
        """Export the ABI for the duration of ``ndk-build``, then unset it again."""
        abi = self._name
        return f'export QNN_ANDROID_APP_ABIS={abi} && ndk-build; unset QNN_ANDROID_APP_ABIS'

    def _build_lib_dir(self) -> Path:
        """The library is looked up under ``libs/<name>``."""
        return Path('libs') / self.name

class ArmAndroidTarget(NixTarget):
    """Android armeabi-v7a target (alias ``arm-android``) built with ``ndk-build``."""

    # ndk-build project files required in addition to the common converter sources.
    converter_sources: List[Path] = [Path('jni', mk_name) for mk_name in ('Android.mk', 'Application.mk')]

    def __init__(self):
        super().__init__(name='armeabi-v7a', alias='arm-android')

    def is_available(self) -> bool:
        """Require the base checks plus both ndk-build project files."""
        if not super().is_available():
            return False
        return all((CONVERTER_SHARE_DIR / mk_file).is_file()
                   for mk_file in ArmAndroidTarget.converter_sources)

    def _prepare_build_directory(self, build_dir: Path, config: ModelLibGenerationConfig):
        """Populate the build directory and copy the ndk-build project files into it."""
        super()._prepare_build_directory(build_dir, config)
        for mk_file in ArmAndroidTarget.converter_sources:
            shutil.copy(CONVERTER_SHARE_DIR / mk_file, build_dir / mk_file)

    def _build_cmd(self) -> str:
        """Export the ABI for the duration of ``ndk-build``, then unset it again."""
        abi = self._name
        return f'export QNN_ANDROID_APP_ABIS={abi} && ndk-build; unset QNN_ANDROID_APP_ABIS'

    def _build_lib_dir(self) -> Path:
        """The library is looked up under ``libs/<name>``."""
        return Path('libs') / self.name


class WindowsTarget(Target):
    """Windows DLL target built with CMake (ClangCL toolset) through PowerShell."""

    # Files required in addition to Target.converter_sources.
    converter_sources: List[Path] = [
        Path('CMakeLists.txt'),
        Path('jni', 'windows', 'QnnModelPal.cpp')
    ]

    def __init__(self, name: str, alias: str, architecture_alias: str, obj_generator_target: str):
        """Create a Windows target.

        Args:
            name: Target name.
            alias: Target alias; also names the object-file directory ``obj/<alias>``.
            architecture_alias: Value passed to ``cmake -A``; also names the
                per-target build and output sub-directories.
            obj_generator_target: Target argument handed to the object generator.
        """
        super().__init__(name=name, alias=alias)
        self._arch_alias = architecture_alias
        self._obj_generator_target = obj_generator_target

    def is_available(self) -> bool:
        """Available on a Windows host that has the common and the Windows-specific sources."""
        return (
                is_windows()
                and super().is_available()
                and all([(CONVERTER_SHARE_DIR / file).is_file() for file in WindowsTarget.converter_sources])
        )

    def _prepare_build_directory(self, build_dir: Path, config: ModelLibGenerationConfig):
        """Copy all sources flat into ``build_dir`` and convert the model binary, if any.

        Raises:
            RuntimeError: If a model binary is given but the object generator
                executable is not configured or does not exist.
        """
        build_dir.mkdir(parents=True, exist_ok=True)
        for file in Target.converter_sources + WindowsTarget.converter_sources:
            shutil.copy(CONVERTER_SHARE_DIR / file, build_dir)

        # The model source is copied under a fixed file name.
        shutil.copy(config.model_cpp, build_dir / 'QnnNetworkModel.cpp')
        if config.model_bin and config.model_bin.is_file():
            if not config.obj_generator_path or not config.obj_generator_path.is_file():
                raise RuntimeError('Windows model lib generator requires a object generator executable')
            shutil.copy(config.model_bin, build_dir)
            # extract *.raw from model.bin
            self._extract_obj_from_bin(build_dir, config)

    @staticmethod
    def _exec_subprocess_cmd(cmd: str) -> sp.CompletedProcess:
        """Run ``cmd`` in PowerShell with an unrestricted execution policy."""
        return exec_subprocess_cmd(['powershell.exe', '-ExecutionPolicy', 'Unrestricted', cmd])

    def _extract_obj_from_bin(self, build_dir: Path, config: WindowsModelLibGenerationConfig):
        """Unpack the model binary (a tar archive) and turn each ``.raw`` file into an object file.

        The work runs with ``build_dir`` as the current directory so that the
        object generator receives the relative ``obj/...`` paths; the previous
        working directory is restored on every exit path.

        Exits the process with status -1 if the object generator fails or
        writes anything to stderr.
        """
        logger.info('Converting raw files to object files starts')
        # Resolve caller-supplied paths BEFORE changing directory; a relative
        # path would otherwise be interpreted relative to build_dir.
        model_bin = config.model_bin.absolute()
        obj_generator = config.obj_generator_path.absolute()

        prev = Path.cwd()
        os.chdir(build_dir)
        try:
            raw_files_dir = Path('obj', 'binary')
            raw_files_dir.mkdir(parents=True, exist_ok=True)
            # NOTE(review): extractall() trusts member paths inside the archive;
            # only feed model binaries from a trusted source.
            with tarfile.open(model_bin) as tar:
                tar.extractall(raw_files_dir)

            raw_files = [x for x in raw_files_dir.iterdir() if x.suffix == '.raw']
            if not raw_files:
                logger.error(f'No raw files extracted from {config.model_bin} to {raw_files_dir}')
                return
            logger.info(f"Extracted raw files from {config.model_bin} to {raw_files_dir}")

            output_dir = Path('obj', self.alias)
            output_dir.mkdir(parents=True, exist_ok=True)

            for file in raw_files:
                cmd = f"{obj_generator} {file} {self._obj_generator_target} {output_dir}"
                result = self._exec_subprocess_cmd(cmd)
                # Any stderr output is treated as a failure, not only a non-zero return code.
                if result.returncode or result.stderr:
                    logger.error(f'Failed to convert raw files to object files using command:\n{cmd}\nerror message:\n{result.stderr}')
                    sys.exit(-1)
            logger.info('Converted raw files to object files successfully')
        finally:
            # Restore the working directory on success, early return and errors alike.
            os.chdir(prev)

    def build_target(self, build_dir: Path, config: ModelLibGenerationConfig) -> Path:
        """Configure and build the DLL with CMake, then copy it to ``<output_dir>/<arch alias>/``.

        ``config.lib_name`` gets a ``.dll`` suffix appended in place when missing.

        Returns:
            Path of the DLL inside the output directory.

        Raises:
            RuntimeError: If the expected DLL is missing after the build.
        """
        self._prepare_build_directory(build_dir, config)

        target_build_dir = build_dir / f'build_{self._arch_alias}'
        build_type = 'Debug' if config.debug_build else 'Release'
        build_cmd = (
            f'mkdir {target_build_dir};'
            f' cd {target_build_dir};'
            f' cmake ../ -T ClangCL -A {self._arch_alias};'
            f' cmake --build ./ --config {build_type};'
        )
        self._exec_subprocess_cmd(build_cmd)

        # Success is judged by the presence of the DLL, not by the return code.
        target_lib = target_build_dir / build_type / 'model-lib-windows.dll'
        if not target_lib.is_file():
            raise RuntimeError(f'Failed to compile model library for target: {self.alias}')
        target_output_dir = config.output_dir / self._arch_alias
        target_output_dir.mkdir(parents=True, exist_ok=True)

        if not config.lib_name.endswith('.dll'):
            config.lib_name = config.lib_name + '.dll'

        output_lib_path = target_output_dir / config.lib_name
        # Replace a library left over from a previous run.
        if output_lib_path.is_file():
            output_lib_path.unlink()
        shutil.copy(target_lib, output_lib_path)
        logger.info(f'Generated model library path: {output_lib_path}')
        return output_lib_path

class ModelLibGenerator:
    """Builds model libraries for the requested targets inside a temporary build directory.

    Intended to be used as a context manager: entering creates a ``tmp_<nnnn>``
    directory under the current working directory, and leaving removes it again
    when clean-up is enabled.
    """

    # Targets usable on this host; evaluated once, when the class is defined.
    available_targets = [x for x in [
        LinuxTarget(name='linux-x86_64', alias='x86_64-linux-clang', compiler=CLANG),
        Aarch64AndroidTarget(),
        ArmAndroidTarget(),
        WindowsTarget(name='x86_64-windows', alias='windows-x86_64', architecture_alias='x64', obj_generator_target='windows-x86_64'),
        WindowsTarget(name='aarch64-windows', alias='windows-aarch64', architecture_alias='ARM64', obj_generator_target='aarch64-windows'),
        LinuxTarget(name='qnx710-arm64',
                    alias='aarch64-qnx',
                    compiler='aarch64-unknown-nto-qnx7.1.0-g++'),
        LinuxTarget(name='qhs225-arm64',
                    alias='aarch64-qhs225',
                    compiler='aarch64-unknown-nto-qnx7.1.0-g++'),
        LinuxTarget(name='qos226-arm64',
                    alias='aarch64-qos226',
                    compiler='aarch64-unknown-nto-qnx7.1.0-g++'),
        LinuxTarget(name='qos227-arm64',
                    alias='aarch64-qos227',
                    compiler='aarch64-unknown-nto-qnx7.1.0-g++'),
        LinuxTarget(name='qhs227-arm64',
                    alias='aarch64-qhs227',
                    compiler='aarch64-unknown-nto-qnx7.1.0-g++'),
        LinuxTarget(name='qnx800-arm64',
                    alias='aarch64-qnx800',
                    compiler='aarch64-unknown-nto-qnx8.0.0-g++'),
        LinuxTarget(name='oe-linux-aarch64-gcc11.2',
                    alias='aarch64-oe-linux-gcc11.2',
                    compiler='${TARGET_PREFIX}'+'g++',
                    sys_root='${SDKTARGETSYSROOT}'),
        LinuxTarget(name='openwrt-linux-aarch64-gcc13.3',
                    alias='aarch64-openwrt-linux-gcc13.3',
                    compiler='${QNN_AARCH64_LINUX_OPENWRT_GCC_133}/bin/aarch64-openwrt-linux-g++',
                    sys_root=''),
        LinuxTarget(name='ubuntu-aarch64-gcc11.2',
                    alias='aarch64-ubuntu-gcc11.2',
                    compiler='${TARGET_PREFIX}'+'g++',
                    sys_root='${SDKTARGETSYSROOT}'),
        LinuxTarget(name='oe-linux-aarch64-gcc9.3',
                    alias='aarch64-oe-linux-gcc9.3',
                    compiler='${QNN_AARCH64_LINUX_OE_GCC_93}/sysroots/x86_64-oesdk-linux/usr/bin/aarch64-oe-linux/aarch64-oe-linux-g++',
                    sys_root='${QNN_AARCH64_LINUX_OE_GCC_93}/sysroots/aarch64-oe-linux'),
        LinuxTarget(name='oe-linux-aarch64-gcc8.2',
                    alias='aarch64-oe-linux-gcc8.2',
                    compiler='${QNN_AARCH64_LINUX_OE_GCC_82}/sysroots/x86_64-oesdk-linux/usr/bin/aarch64-oe-linux/aarch64-oe-linux-g++',
                    sys_root='${QNN_AARCH64_LINUX_OE_GCC_82}/sysroots/aarch64-oe-linux'),
        LinuxTarget(name='ubuntu-aarch64-gcc7.5',
                    alias='aarch64-ubuntu-gcc7.5',
                    compiler='aarch64-linux-gnu-g++'),
        LinuxTarget(name='ubuntu-aarch64-gcc9.4',
                    alias='aarch64-ubuntu-gcc9.4',
                    compiler='${QNN_AARCH64_UBUNTU_GCC_94}/usr/bin/aarch64-linux-gnu-g++',
                    sys_root='${QNN_AARCH64_UBUNTU_GCC_94}')
    ] if x.is_available()]

    def __init__(self, clean_up: bool = True):
        """Remember the starting directory and index the available targets by alias.

        Args:
            clean_up: Remove the temporary build directory on exit. Overridden
                by ``config.clean_up`` once a config has been normalised.
        """
        self._current_dir = Path.cwd()
        self._clean_up = clean_up
        self._alias_target_map = {t.alias: t for t in ModelLibGenerator.available_targets}

    def __enter__(self):
        """Create a not-yet-existing ``tmp_<nnnn>`` build directory and return self."""
        # The module seeds `random` with a constant, so the suffix sequence is the
        # same on every run; the loop skips directories left by earlier runs.
        suffix = random.randint(1000, 9999)
        while (self._current_dir / f'tmp_{suffix}').is_dir():
            suffix = random.randint(1000, 9999)
        self._build_dir = self._current_dir / f'tmp_{suffix}'
        self._build_dir.mkdir()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Return to the starting directory and delete the build directory if clean-up is on."""
        if self._clean_up:
            os.chdir(self._current_dir)
            if self._build_dir.is_dir():
                shutil.rmtree(self._build_dir, ignore_errors=True)

    def _normalize_config(self, config: ModelLibGenerationConfig):
        """Validate ``config`` and normalise it in place.

        ``config.targets`` may be a flat list of alias strings or a list of
        lists of alias strings (the shape argparse produces for repeated ``-t``);
        each string may itself hold several whitespace-separated aliases. The
        result is a flat, de-duplicated list in first-seen order, so normalising
        an already normalised config is harmless.

        Raises:
            ValueError: Missing model source, wrong file suffix, no targets,
                unsupported targets, or an output directory that is a file.
            FileNotFoundError: The model source or binary does not exist.
        """
        if not config.model_cpp:
            raise ValueError('model_cpp is required but missing from config')
        elif not config.model_cpp.is_file():
            raise FileNotFoundError(f'Unable to find the model source file, invalid path: {config.model_cpp.absolute()}')
        if config.model_cpp.suffix != '.cpp':
            raise ValueError(f'{config.model_cpp} is not a cpp model file')
        if not config.model_bin:
            logger.warning('Runtime will fail if .cpp needs data in .raw files when building model library')
        elif not config.model_bin.is_file():
            raise FileNotFoundError(f'Unable to find the bin file, invalid path: {config.model_bin.absolute()}')
        elif config.model_bin.suffix != '.bin':
            raise ValueError(f'{config.model_bin} is not a bin model file')

        if not config.targets:
            raise ValueError('No library build targets provided')
        # A dict is used as an ordered set: keeps first-seen order, drops duplicates.
        requested: Dict[str, None] = {}
        for entry in config.targets:
            # Treat a bare string as a one-element group instead of iterating
            # over its characters.
            group = [entry] if isinstance(entry, str) else entry
            for arg in group:
                for alias in arg.split():
                    requested.setdefault(alias, None)
        if not requested:
            # e.g. [[]] when no default target is available on this host.
            raise ValueError('No library build targets provided')
        config.targets = list(requested)
        invalid_targets = [t for t in config.targets if t not in self._alias_target_map]

        if invalid_targets:
            raise ValueError(f'Unsupported target(s): {",".join(invalid_targets)}')

        if not config.output_dir:
            config.output_dir = DEFAULT_OUTPUT_DIR

        if config.output_dir.is_file():
            raise ValueError(f'Output directory is a file: {config.output_dir}')

        config.output_dir = config.output_dir.absolute()
        # The config's clean-up preference replaces the constructor argument.
        self._clean_up = config.clean_up

    def build_targets(self, config: ModelLibGenerationConfig) -> Dict[str, Path]:
        """Build every requested target and return a mapping of alias to library path.

        Must be called inside the ``with`` block, because the build directory
        is created on entry.
        """
        self._normalize_config(config)
        logger.info(f'Output library name  : {config.lib_name}')

        config.output_dir.mkdir(parents=True, exist_ok=True)

        target_lib_map = {}
        for t in config.targets:
            target_lib_map[t] = self._alias_target_map[t].build_target(self._build_dir, config)

        return target_lib_map


def _configure_logger(debug):
    """Create (or reconfigure) the module logger and bind it to the global ``logger``.

    Records below ERROR go to stdout, ERROR and above go to stderr. Handlers
    already attached to the logger are removed first, so calling this function
    more than once does not duplicate every log line.

    Args:
        debug: When truthy the logger level is DEBUG, otherwise INFO.
    """
    global logger
    log_level = logging.DEBUG if debug else logging.INFO
    log_prefix = 'qnn-model-lib-generator'
    # Create logger
    logger = logging.getLogger(log_prefix)
    logger.setLevel(log_level)

    # getLogger() returns the same object for the same name; drop the handlers
    # of an earlier call before attaching fresh ones.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)

    # Create stdout handler (everything below ERROR)
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(log_level)
    stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)  # Filters out ERROR and CRITICAL

    # Create stderr handler (ERROR and CRITICAL only)
    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setLevel(logging.ERROR)

    # Set formatter
    formatter = logging.Formatter(fmt='%(asctime)s - %(levelname)7s - %(name)s: %(message)s')
    stdout_handler.setFormatter(formatter)
    stderr_handler.setFormatter(formatter)

    # Add handlers to logger
    logger.addHandler(stdout_handler)
    logger.addHandler(stderr_handler)

def get_configuration_from_command_line(argv: Optional[List[str]] = None) -> ModelLibGenerationConfig:
    """Parse the command line into a generation config and configure logging.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        A ``WindowsModelLibGenerationConfig`` on Windows hosts, otherwise a
        ``ModelLibGenerationConfig``. The library name falls back to the stem
        of the .bin file, or to ``qnn_model`` when no .bin file is given.
    """
    available_targets = {x.alias for x in ModelLibGenerator.available_targets}
    # Only targets usable on this host can be defaults; sorted so the help
    # text and the logged configuration are stable between runs.
    default_targets = sorted(
        {'x86_64-linux-clang', 'aarch64-android', 'windows-x86_64', 'windows-aarch64'} & available_targets
    )

    parser = argparse.ArgumentParser()
    required = parser.add_argument_group('required arguments')
    required.add_argument('-c', '--cpp', required=True, type=str,
                          help="Filepath for the qnn model .cpp file")

    optional = parser.add_argument_group('optional arguments')
    optional.add_argument('-b', '--bin', type=str, required=False,
                          help='Filepath for the qnn model .bin file. (Note: if not passed, '
                               'runtime will fail if .cpp needs any items from a .bin file.)')
    # nargs='+' combined with action='append' yields a list of lists.
    optional.add_argument('-t', '--lib_targets', nargs='+', required=False,
                          action='append', type=str,
                          help='Specifies the targets to build the models for. '
                               f'Supported targets: {", ".join(sorted(available_targets))}. '
                               f'Default: [{", ".join(default_targets)}]')
    optional.add_argument('-l', '--lib_name', type=str, required=False,
                          help="Specifies the name to use for libraries. "
                               "Default: uses name in <model>.bin if provided, "
                               "else generic qnn_model")
    optional.add_argument('-o', '--output_dir', type=str, required=False,
                          default=DEFAULT_OUTPUT_DIR,
                          help='Location for saving output libraries. '
                               f'Default: {DEFAULT_OUTPUT_DIR}')
    optional.add_argument('-d', '--debug', action='store_true',
                          help='Set log level as debug mode')
    if is_windows():
        # The adjacent literals need explicit separating spaces.
        optional.add_argument('--debug_build', action='store_true',
                              help='Set to build for debug version, if this '
                                   'flag is not set, it will build for release '
                                   'version.')
        optional.add_argument('--clean_up', action='store_true',
                              help='Clean-up temporary files in build process')
    args = parser.parse_args(argv)

    _configure_logger(args.debug)
    config = ModelLibGenerationConfig()
    if is_windows():
        config = WindowsModelLibGenerationConfig()
        # The object generator is expected next to this script.
        config.obj_generator_path = Path(__file__).parent.resolve() / 'object-generator.exe'
        config.clean_up = args.clean_up
        config.debug_build = args.debug_build
    config.model_cpp = Path(args.cpp)
    config.model_bin = Path(args.bin) if args.bin else None
    # Keep the list-of-lists shape that repeated -t options produce.
    config.targets = args.lib_targets if args.lib_targets else [default_targets]
    config.lib_name = args.lib_name
    if not config.lib_name:
        if config.model_bin:
            config.lib_name = config.model_bin.stem
        else:
            config.lib_name = 'qnn_model'
    config.output_dir = Path(args.output_dir)

    logger.info(f"Model cpp file path  : {config.model_cpp}")
    logger.info(f"Model bin file path  : {config.model_bin}")
    logger.info(f"Library target       : {config.targets}")
    logger.info(f"Library name         : {config.lib_name}")
    logger.info(f"Output directory     : {config.output_dir}")
    if is_windows():
        if config.debug_build:
            logger.info("Build type           : Debug")
        else:
            logger.info("Build type           : Release")
        logger.info(f"Clean-up             : {config.clean_up}")

    return config


def main():
    """Script entry point: parse the command line, build the targets, report the results."""
    with ModelLibGenerator() as lib_generator:
        # The temporary build directory already exists while arguments are parsed.
        cli_config = get_configuration_from_command_line()
        built_libs = lib_generator.build_targets(cli_config)
        for alias in built_libs:
            logger.info(f'Target: {alias}\tLibrary: {built_libs[alias]}')

# Run the generator only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
