# =============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All Rights Reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================

from dataclasses import dataclass
from typing import Optional, List
from qti.aisw.core.model_level_api.config.sdk_config import Config, RunConfig, GenerateConfig


@dataclass
class QNNCommonConfig(Config):
    """Options shared by the QNN run and generate configurations.

    Every field defaults to ``None``, meaning "not specified".
    """

    # Rendered as ``--log_level <value>`` when set.
    log_level: Optional[str] = None
    # Stored on the config only; as_command_line_args() below does not emit it.
    set_output_tensors: Optional[List[str]] = None
    # Rendered as ``--profiling_level <value>`` when set.
    profiling_level: Optional[str] = None
    # Rendered as ``--profiling_option <value>`` when set.
    profiling_option: Optional[str] = None

    def as_command_line_args(self):
        """Return the common options as one space-separated string.

        An option whose value is falsy (``None`` or an empty string)
        contributes an empty placeholder, so the separating spaces are always
        present even when nothing is set.
        """
        flag_values = (
            ('--log_level', self.log_level),
            ('--profiling_level', self.profiling_level),
            ('--profiling_option', self.profiling_option),
        )
        fragments = []
        for flag, value in flag_values:
            # Keep an empty slot for unset options so spacing stays stable.
            fragments.append(f'{flag} {value}' if value else '')
        return ' '.join(fragments)


@dataclass
class QNNRunConfig(QNNCommonConfig, RunConfig):
    """Run-time options, layered on top of the common QNN options.

    Every field defaults to ``None``, meaning "not specified".
    """

    # Rendered as ``--batch_multiplier <value>`` when truthy (0 is omitted).
    batch_multiplier: Optional[int] = None
    # Rendered as the bare switch ``--use_native_output_files`` when truthy.
    use_native_output_data: Optional[bool] = None
    # Rendered as the bare switch ``--use_native_input_files`` when truthy.
    use_native_input_data: Optional[bool] = None
    # Rendered as ``--native_input_tensor_names a,b,c`` when non-empty.
    native_input_tensor_names: Optional[List[str]] = None
    # Rendered as the bare switch ``--synchronous`` when truthy.
    synchronous: Optional[bool] = None
    # Rendered as the bare switch ``--debug`` when truthy.
    debug: Optional[bool] = None
    # Rendered as ``--perf_profile <value>`` when set.
    perf_profile: Optional[str] = None

    def as_command_line_args(self):
        """Return the common options followed by the run options as one string.

        Fragments are joined with single spaces; an unset option contributes
        an empty fragment, so consecutive spaces can appear in the result.
        """

        def switch(flag, enabled):
            # Boolean options are emitted as a bare flag or not at all.
            return flag if enabled else ''

        batch = ''
        if self.batch_multiplier:
            batch = f'--batch_multiplier {self.batch_multiplier}'

        tensor_names = ''
        if self.native_input_tensor_names:
            # The tensor names are passed as a single comma-separated value.
            tensor_names = '--native_input_tensor_names ' + ','.join(self.native_input_tensor_names)

        perf_profile = ''
        if self.perf_profile:
            perf_profile = f'--perf_profile {self.perf_profile}'

        return ' '.join([
            super().as_command_line_args(),
            batch,
            switch('--use_native_output_files', self.use_native_output_data),
            switch('--use_native_input_files', self.use_native_input_data),
            tensor_names,
            switch('--synchronous', self.synchronous),
            switch('--debug', self.debug),
            perf_profile,
        ])


@dataclass
class QNNGenerateConfig(QNNCommonConfig, GenerateConfig):
    """Generation options, layered on top of the common QNN options."""

    # Rendered as the switch ``--enable_intermediate_outputs`` when truthy.
    enable_intermediate_outputs: Optional[bool] = None
    # Rendered as ``--input_output_tensor_mem_type <value>``; defaults to "raw".
    input_output_tensor_mem_type: Optional[str] = "raw"

    def as_command_line_args(self):
        """Return the common options followed by the generate options.

        Fragments are joined with single spaces; an unset option contributes
        an empty fragment, so consecutive spaces can appear in the result.
        """
        qnn_common_config_args = super().as_command_line_args()
        intermediate_outputs = '--enable_intermediate_outputs ' if self.enable_intermediate_outputs else ''
        # The field is Optional: skip the flag when it is unset instead of
        # emitting the literal text "None" (or an empty value) as its argument.
        input_output_tensor_mem_type_arg = (
            f'--input_output_tensor_mem_type {self.input_output_tensor_mem_type}'
            if self.input_output_tensor_mem_type else ''
        )
        return f'{qnn_common_config_args} {intermediate_outputs} {input_output_tensor_mem_type_arg}'