# =============================================================================
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# All rights reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================

import copy
import logging
import numpy as np
import os
import json
from itertools import chain, product

from qti.aisw.accuracy_evaluator.qacc import *
from qti.aisw.accuracy_evaluator.qacc import qacc_file_logger, qacc_logger
from qti.aisw.accuracy_evaluator.common.utilities import *
from qti.aisw.accuracy_evaluator.qacc.constants import Constants as qcc
from qti.aisw.accuracy_evaluator.common.infer_engines.QAIRTInferenceEngine import QAIRTInferenceEngine
from qti.aisw.accuracy_evaluator.qacc.config_definitions import *


class InferenceManager:

    def __init__(self, inference_schema_config, infer_config, binary_path):
        """Keep the configuration needed to run one inference schema.

        Args:
            inference_schema_config: Schema configuration; its ``name`` is
                what ``execute`` dispatches on.
            infer_config: Inference configuration; its ``inputs_info`` and
                ``outputs_info`` are forwarded to the inference engines.
            binary_path: Stored as given; none of the methods visible in
                this class read it.
        """
        self.inference_schema_config, self.infer_config, self.binary_path = (
            inference_schema_config, infer_config, binary_path)
        # Execution time slots, in order:
        # quantization time, compilation time, inference time.
        self.execution_time = [0] * 3

    def execute(self, model_path, output_dir, input_file, output_file, calibration, device_id,
                precompiled_path, console_tag, compile_only, qnn_sdk_dir=""):
        """Run inference on the engine selected by the inference schema name.

        Dispatches on ``self.inference_schema_config.name``. Only the QNN
        path receives ``calibration``, ``device_id``, ``precompiled_path``,
        ``console_tag``, ``compile_only`` and ``qnn_sdk_dir``; the other
        runtimes are given just the model, the input list and the output
        locations.

        Returns:
            The ``(failed, fail_stage, execution_time)`` tuple produced by the
            selected ``execute_*`` method.

        Raises:
            RuntimeError: If the schema name is none of the engine types
                handled here.
        """
        engine_type = self.inference_schema_config.name

        if engine_type == InferenceEngineType.QNN:
            return self.execute_qairt(model_path, output_dir, input_file, output_file, device_id,
                                      precompiled_path, console_tag, calibration=calibration,
                                      compile_only=compile_only, qnn_sdk_dir=qnn_sdk_dir)
        elif engine_type == InferenceEngineType.ONNXRT:
            return self.execute_onnxrt(model_path, output_dir, input_file, output_file)
        elif engine_type == InferenceEngineType.TFRT:
            return self.execute_tfrt(model_path, output_dir, input_file, output_file)
        elif engine_type == InferenceEngineType.TORCHSCRIPTRT:
            return self.execute_torchscriptrt(model_path, output_dir, input_file, output_file)
        elif engine_type == InferenceEngineType.TFRT_SESSION:
            return self.execute_tfrt_session(model_path, output_dir, input_file, output_file)

        # An ``assert`` on a non-empty string can never fail, so an unknown
        # engine used to fall through and return None. Fail loudly instead.
        raise RuntimeError('Invalid inference schema name {}'.format(
            getattr(engine_type, 'value', engine_type)))

    def execute_qairt(self, model_path, output_dir, input_file, output_file, device_id,
                      precompiled_path, console_tag, calibration=None, compile_only=False,
                      qnn_sdk_dir=""):
        """Run the model through ``QAIRTInferenceEngine``.

        A calibration file is prepared only when the quantizer params do not
        request ``float_fallback``, ``calibration`` is given and the schema
        precision is ``PrecisionType.QUANT``. The explicitly set converter
        and quantizer params are written to ``output_dir`` before the engine
        is executed.

        ``precompiled_path``, ``console_tag``, ``compile_only`` and
        ``qnn_sdk_dir`` are accepted but not used by this method.

        Returns:
            ``(failed, fail_stage, [0, 0, 0])`` where ``failed`` is ``True``
            when ``engine.execute()`` raised and ``fail_stage`` is the first
            stage reported as failed in ``engine.stage_status`` (``""`` if
            none).
        """
        schema = self.inference_schema_config
        quantizer_params = schema.quantizer_params
        converter_params = schema.converter_params
        contextbin_params = schema.contextbin_params
        netrun_params = schema.netrun_params
        backend_extensions = schema.backend_extensions
        backend = schema.backend
        precision = schema.precision
        target_arch = schema.target_arch

        # With float_fallback enabled no calibration file must be passed on.
        float_fallback = quantizer_params and quantizer_params.float_fallback
        calibration_file = None
        if not float_fallback and calibration and precision == PrecisionType.QUANT:
            calibration_file = self.parse_generate_calibration(calibration, input_file,
                                                               os.path.dirname(input_file))

        os.makedirs(output_dir, exist_ok=True)
        engine = QAIRTInferenceEngine(
            model_path=model_path,
            inputlistfile=input_file,
            calibration_file=calibration_file,
            output_path=output_dir,
            inputs_info=self.infer_config.inputs_info,
            outputs_info=self.infer_config.outputs_info,
            gen_out_file=output_file,
            backend_extensions=backend_extensions,
            netrun_params=netrun_params,
            quantizer_params=quantizer_params,
            converter_params=converter_params,
            contextbin_params=contextbin_params,
            backend=backend,
            precision=precision,
            target_arch=target_arch,
            device_id=device_id)

        # Dump the explicitly set converter / quantizer params before execution.
        param_dumps = (
            ('converter_params', converter_params, 'converter_params_list.json'),
            ('quantizer_params', quantizer_params, 'quantizer_params_list.json'),
        )
        for label, params, file_name in param_dumps:
            if not params:
                continue
            dump_path = os.path.join(output_dir, file_name)
            dump_text = f'{label}: {params.model_dump_json(exclude_unset=True)}'
            with open(dump_path, 'w', encoding='utf-8') as dump_file:
                dump_file.write(dump_text)

        failed = False
        try:
            engine.execute()
            qacc_file_logger.info('Inference success on QNN in execution stage.')
        except Exception as err:
            failed = True
            qacc_logger.info(err)
            qacc_file_logger.error('Inference failed on QNN in execution stage.')

        # Stage status is read whether or not the execution succeeded.
        first_failed_stage = self._get_first_fail_stage(engine.stage_status)
        return failed, first_failed_stage, [0, 0, 0]

    def execute_onnxrt(self, model_path, output_dir, input_file, output_file):
        """Run the model with ``OnnxInferenceEngine``.

        The engine module is imported inside the method, so it is only
        loaded when this runtime is actually used.

        Returns:
            ``(failed, fail_stage, execution_time)``. On success ``failed`` is
            ``not status`` for the status returned by the engine,
            ``fail_stage`` is ``None`` and the third slot of
            ``self.execution_time`` holds the value the engine reported. If
            the engine raises, ``failed`` is ``True`` and ``fail_stage`` is
            ``'onnx-inference'``.
        """
        from qti.aisw.accuracy_evaluator.common.infer_engines.OnnxRTEngine import OnnxInferenceEngine

        schema = self.inference_schema_config
        onnx_engine = OnnxInferenceEngine(
            model=model_path,
            inputlistfile=input_file,
            multithread=schema.multithreaded,
            output_path=output_dir,
            input_info=self.infer_config.inputs_info,
            output_info=self.infer_config.outputs_info,
            gen_out_file=output_file,
            convert_nchw=schema.convert_nchw)

        try:
            status, _, self.execution_time[2], _ = onnx_engine.execute()
        except Exception as err:
            qacc_logger.error('(onnxrt) Inference failed. See qacc.log for more details.')
            qacc_file_logger.error('Exception - {}'.format(err))
            return True, 'onnx-inference', self.execution_time

        return not status, None, self.execution_time

    def execute_tfrt(self, model_path, output_dir, input_file, output_file):
        """Run the model with ``TensorflowInferenceEngine``.

        The engine module is imported inside the method, so it is only
        loaded when this runtime is actually used.

        Returns:
            ``(failed, fail_stage, execution_time)``. On success ``failed`` is
            ``not status`` for the status returned by the engine,
            ``fail_stage`` is ``None`` and the third slot of
            ``self.execution_time`` holds the value the engine reported. If
            the engine raises, ``failed`` is ``True`` and ``fail_stage`` is
            ``'tensorflow-inference'``.
        """
        from qti.aisw.accuracy_evaluator.common.infer_engines.TensorflowRTEngine import TensorflowInferenceEngine

        tf_engine = TensorflowInferenceEngine(
            model=model_path,
            inputlistfile=input_file,
            multithread=self.inference_schema_config.multithreaded,
            output_path=output_dir,
            input_info=self.infer_config.inputs_info,
            output_info=self.infer_config.outputs_info,
            gen_out_file=output_file)

        try:
            status, _, self.execution_time[2], _ = tf_engine.execute()
        except Exception as err:
            qacc_logger.error('tensorflow runtime inference failed. See qacc.log for more details.')
            qacc_file_logger.error('Exception - {}'.format(err))
            return True, 'tensorflow-inference', self.execution_time

        return not status, None, self.execution_time

    def execute_torchscriptrt(self, model_path, output_dir, input_file, output_file):
        """Run the model with ``TorchScriptInferenceEngine``.

        The engine module is imported inside the method, so it is only
        loaded when this runtime is actually used.

        Returns:
            ``(failed, fail_stage, execution_time)``. On success ``failed`` is
            ``not status`` for the status returned by the engine,
            ``fail_stage`` is ``None`` and the third slot of
            ``self.execution_time`` holds the value the engine reported. If
            the engine raises, ``failed`` is ``True`` and ``fail_stage`` is
            ``'torchscript-inference'``.
        """
        from qti.aisw.accuracy_evaluator.common.infer_engines.TorchScriptRTEngine import TorchScriptInferenceEngine

        torch_engine = TorchScriptInferenceEngine(
            model=model_path,
            inputlistfile=input_file,
            multithread=self.inference_schema_config.multithreaded,
            output_path=output_dir,
            input_info=self.infer_config.inputs_info,
            output_info=self.infer_config.outputs_info,
            gen_out_file=output_file)

        try:
            status, _, self.execution_time[2], _ = torch_engine.execute()
        except Exception as err:
            qacc_logger.error(
                'torchscript runtime inference failed. See qacc.log for more details.')
            qacc_file_logger.error('Exception - {}'.format(err))
            return True, 'torchscript-inference', self.execution_time

        return not status, None, self.execution_time

    def execute_tfrt_session(self, model_path, output_dir, input_file, output_file):
        """Run the model with ``TensorflowSessionInferenceEngine``.

        The engine module is imported inside the method, so it is only
        loaded when this runtime is actually used.

        Returns:
            ``(failed, fail_stage, execution_time)``. On success ``failed`` is
            ``not status`` for the status returned by the engine,
            ``fail_stage`` is ``None`` and the third slot of
            ``self.execution_time`` holds the value the engine reported. If
            the engine raises, ``failed`` is ``True`` and ``fail_stage`` is
            ``'tensorflow-session-inference'``.
        """
        from qti.aisw.accuracy_evaluator.common.infer_engines.TensorflowSessionRTEngine import TensorflowSessionInferenceEngine

        session_engine = TensorflowSessionInferenceEngine(
            model=model_path,
            inputlistfile=input_file,
            multithread=self.inference_schema_config.multithreaded,
            output_path=output_dir,
            input_info=self.infer_config.inputs_info,
            output_info=self.infer_config.outputs_info,
            gen_out_file=output_file)

        try:
            status, _, self.execution_time[2], _ = session_engine.execute()
        except Exception as err:
            qacc_logger.error('tensorflow runtime inference failed. See qacc.log for more details.')
            qacc_file_logger.error('Exception - {}'.format(err))
            return True, 'tensorflow-session-inference', self.execution_time

        return not status, None, self.execution_time

    def _parse_range(self, index_str):
        """Expand one calibration index token into the indices it denotes.

        A token is either a single index (``"7"``) or an inclusive range
        (``"3-5"``). Surrounding whitespace is ignored, so blank tokens such
        as the ``"\\r"`` left over from an empty CRLF line are treated like
        an empty token instead of crashing in ``int()``.

        Args:
            index_str: Text of a single token.

        Returns:
            An empty list for a blank token, otherwise a ``range`` covering
            the start and end index inclusively.

        Raises:
            ValueError: If the token contains more than one ``-`` or a part
                of it is not an integer.
        """
        token = index_str.strip()
        if not token:
            return []
        bounds = token.split("-")
        # Raise explicitly: an ``assert`` would be stripped under ``python -O``.
        if len(bounds) > 2:
            raise ValueError('Invalid range in calibration file: {!r}'.format(index_str))
        start = int(bounds[0])
        end = int(bounds[-1]) + 1
        return range(start, end)

    def parse_generate_calibration(self, calibration, input_file, output_dir):
        """Resolve the calibration input list for a quantized run.

        Args:
            calibration: ``(calib_type, calib_file)`` pair, or ``None``.
            input_file: Path of the input list file, or ``None``.
            output_dir: Directory in which ``calibration.txt`` is created for
                index-based calibration.

        Returns:
            ``None`` when ``calibration`` or ``input_file`` is ``None``;
            ``calib_file`` unchanged for ``CalibrationType.RAW``; for
            ``CalibrationType.INDEX`` the path of a generated
            ``calibration.txt`` holding the lines of ``input_file`` whose
            zero-based position is listed in ``calib_file`` (indices and
            inclusive ``a-b`` ranges separated by commas or newlines).

        Raises:
            RuntimeError: If the calibration type is neither RAW nor INDEX.
        """
        if calibration is None or input_file is None:
            return None
        (calib_type, calib_file) = calibration

        if calib_type == CalibrationType.RAW:
            return calib_file
        elif calib_type == CalibrationType.INDEX:
            # Context manager so the file is closed even if parsing fails.
            with open(calib_file, 'r') as cf:
                indexes_str = cf.read().replace('\n', ',').strip()
            # Lines are emitted in input-file order, so only membership
            # matters; a set keeps the per-line lookup O(1).
            selected = set(
                chain.from_iterable(map(self._parse_range, indexes_str.split(","))))
            _path = os.path.join(output_dir, 'calibration.txt')
            qacc_file_logger.info('Generating calibration file')
            with open(input_file) as f, open(_path, 'w') as f2:
                f2.writelines(line for index, line in enumerate(f) if index in selected)
            return _path
        else:
            raise RuntimeError('Invalid calibration type {}'.format(calib_type))

    def _get_first_fail_stage(self, stage_status):
        """Return the first stage whose status equals ``False``.

        Stages are visited in the iteration order of ``stage_status``. The
        comparison is deliberately ``== False`` so that a status of ``None``
        is not reported as a failure.

        Returns:
            The key of the first failed stage, or ``""`` when none failed.
        """
        failed_stages = (
            name for name in stage_status
            if stage_status[name] == False  # noqa: E712
        )
        return next(failed_stages, "")


class InferenceSchemaManager:

    def __init__(self, inference_schemas, config):
        self.inference_schemas = inference_schemas
        self.device_ids = config.inference_config.device_ids
        self.schedule = None

    def scan_and_add_inference_schema_permutations(self):
        """Expand each QNN inference schema into one schema per combination
        of its quantizer parameter values.

        Only quantizer params that were explicitly set are considered. A
        param holding a list contributes every element of that list, any
        other value contributes just itself, and ``input_list`` is never
        permuted but copied unchanged into every generated schema. Non QNN
        schemas and QNN schemas without quantizer params are kept as they
        are.

        example:
        Given an inference schema whose quantizer params resolve to
            inference_schema:
                name: qnn
                precision: <value>
                quantizer_params:
                    param1: [input1, input2]
                    param2: [input3, input4]
                    param3: [2.0, 3.0, 4.0]

        2 * 2 * 3 = 12 inference schemas are created, one per combination,
        with the last param varying fastest:
            (param1, param2, param3) =
                (input1, input3, 2.0), (input1, input3, 3.0), (input1, input3, 4.0),
                (input1, input4, 2.0), (input1, input4, 3.0), (input1, input4, 4.0),
                (input2, input3, 2.0), (input2, input3, 3.0), (input2, input3, 4.0),
                (input2, input4, 2.0), (input2, input4, 3.0), (input2, input4, 4.0)

        The expanded list replaces ``self.inference_schemas``.

        Returns:
            ``(updated_inference_schemas, is_calib_req)`` where
            ``is_calib_req`` is True when at least one QNN schema has
            precision ``PrecisionType.QUANT`` and no precompiled path. The
            flag is used for estimating disk space and for preprocessing
            the calibration inputs.
        """
        # Original schemas plus the newly generated permutations.
        expanded_schemas = []
        # Becomes True as soon as one QNN schema needs calibration.
        calibration_needed = False

        for schema in self.inference_schemas:
            if schema.name != InferenceEngineType.QNN:
                qacc_file_logger.debug('scan_and_add: Non QNN inference schema {} added'.format(
                    schema.name.value))
                expanded_schemas.append(schema)
                continue

            quant_params = schema.quantizer_params
            if not quant_params:
                expanded_schemas.append(schema)
            else:
                # Map every explicitly set param to its list of candidate values:
                # {param1: [val1, val2], param2: [val3, val4], param3: [val5, val6, val7]}
                explicitly_set = quant_params.model_fields_set
                candidates = {}
                input_list_value = None
                for field, value in quant_params:
                    if field not in explicitly_set:
                        continue
                    if str(field) == 'input_list':
                        input_list_value = value
                    else:
                        # wrap scalars so that every param can take part in the product
                        candidates[field] = value if isinstance(value, list) else [value]

                # The product yields one tuple per permutation, holding one
                # value for each param:
                # [(val1, val3, val5), (val1, val3, val6), ..., (val2, val4, val7)]
                field_names = candidates.keys()
                combinations = list(product(*candidates.values()))
                qacc_file_logger.debug('scan_and_add: Options for keys-{} values-{} added'.format(
                    field_names, combinations))

                for combination in combinations:
                    # Pair the param names with this permutation's values and
                    # build a fresh QuantizerParams for a copy of the schema.
                    permuted_schema = copy.deepcopy(schema)
                    quantizer_kwargs = dict(zip(field_names, combination),
                                            input_list=input_list_value)
                    permuted_schema.quantizer_params = QuantizerParams(**quantizer_kwargs)
                    expanded_schemas.append(permuted_schema)
                    qacc_file_logger.debug(expanded_schemas)

            # Once one schema requires calibration the flag stays True, so the
            # check is only evaluated while it is still False.
            calibration_needed = calibration_needed or (
                schema.precision == PrecisionType.QUANT and schema.precompiled_path is None)

        for expanded_schema in expanded_schemas:
            qacc_file_logger.info('Inference schema: {} - params: {}'.format(
                expanded_schema.name.value, expanded_schema.quantizer_params))

        # updating inference schema list
        self.inference_schemas = expanded_schemas

        return expanded_schemas, calibration_needed


    def create_schedule(self):
        """Build the distributed inference schedule and store it in
        ``self.schedule``.

        A schedule has the following format:
            [parallel_chunk_1, parallel_chunk_2, ... , parallel_chunk_n]

        Each parallel chunk has the following format:
            [(inference_schema_idx, device_id), ... , (inference_schema_idx, device_id)]

        AIC inference schemas take the device ids from ``self.device_ids``
        one after another (converted with ``int``). When all ids are in use
        the current chunk is closed and a new one is started. Every other
        inference schema is paired with ``self.device_ids[0]`` as is and
        does not occupy a slot, so it joins whichever chunk is open.

        Note: ``self.device_ids[0]`` is read for every non AIC schema, so
        ``device_ids`` must not be empty when there is anything to schedule.

        example:
            device_ids = [0,1]
            inference_schemas = [onnx, aic, aic, aic, aic]
            schedule = [[(0,0), (1,0), (2,1)], [(3,0), (4,1)]]
        """
        self.schedule = []
        slots = len(self.device_ids)
        current_chunk = []
        used_slots = 0

        for idx, inference_schema in enumerate(self.inference_schemas):
            if inference_schema.name != InferenceEngineType.AIC:
                # non AIC inference schemas always use the first device id
                current_chunk.append((idx, self.device_ids[0]))
                continue

            # All slots filled: close this chunk and start a new one. The
            # chunk is rebound rather than mutated afterwards, so no copy
            # is needed.
            if used_slots == slots:
                self.schedule.append(current_chunk)
                current_chunk = []
                used_slots = 0

            current_chunk.append((idx, int(self.device_ids[used_slots])))
            used_slots += 1

        # the last (possibly partially filled) chunk
        self.schedule.append(current_chunk)
        qacc_file_logger.info('Distributed schedule: {}'.format(self.schedule))

    def get_schedule(self):
        """Return the schedule built by ``create_schedule``.

        The value is ``None`` until ``create_schedule`` has been called.
        """
        return self.schedule
