# =============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc.
#  All Rights Reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================

import numpy as np
from typing import Tuple
import json
from importlib import import_module
import os

from qti.aisw.accuracy_debugger.lib.utils.nd_exceptions import UnsupportedError
from qti.aisw.accuracy_debugger.lib.utils.nd_constants import FrameworkExtension
from qti.aisw.accuracy_debugger.lib.framework_runner.nd_framework_objects import get_available_frameworks

import sys
if sys.version_info < (3, 8):
    # distutils deprecated for Python 3.8 and up
    from distutils.version import StrictVersion as Version
else:
    # packaging requires Python 3.8 and up
    from packaging.version import Version as Version


def load_inputs(data_path, data_type, data_dimension=None):
    # type:  (str, str, Tuple) -> np.ndarray
    """
    Read a raw binary file into a numpy array.

    Args:
        data_path: path of the raw file to read.
        data_type: numpy dtype (or dtype string) used to interpret the bytes.
        data_dimension: optional shape to reshape the flat array into.
    Returns:
        The flat array when data_dimension is None, otherwise the reshaped array.
    """
    flat = np.fromfile(data_path, data_type)
    return flat if data_dimension is None else flat.reshape(data_dimension)


def save_outputs(data, data_path, data_type):
    # type:  (np.ndarray, str, str) -> None
    """
    Write an array to a raw binary file after casting it to data_type.

    Args:
        data: array to write.
        data_path: destination file path.
        data_type: numpy dtype (or dtype string) to cast to before writing.
    """
    converted = data.astype(data_type)
    converted.tofile(data_path)


def read_json(json_path):
    """
    Load and return the JSON document stored at json_path.

    The file is decoded as UTF-8, which matches how dump_json writes files.
    Relying on the platform default encoding would mis-decode non-ASCII
    content on systems whose locale is not UTF-8.

    Args:
        json_path: path of the JSON file to read.
    Returns:
        The deserialized Python object.
    """
    with open(json_path, encoding='utf-8') as f:
        return json.load(f)


def dump_json(data, json_path):
    """
    Serialize data as pretty-printed JSON and write it to json_path.

    The file is written as UTF-8 with ensure_ascii disabled, so non-ASCII
    characters are stored literally rather than as \\u escapes.

    Args:
        data: JSON-serializable object.
        json_path: destination file path (overwritten if it exists).
    """
    serialize_opts = dict(ensure_ascii=False, indent=4)
    with open(json_path, mode='w', encoding='utf-8') as handle:
        json.dump(data, handle, **serialize_opts)


def transpose_to_nhwc(data, data_dimension):
    # type:  (np.ndarray, list) ->np.ndarray
    """
    Reorder a flat NCHW buffer into flat NHWC order.

    Args:
        data: array whose elements are laid out according to data_dimension.
        data_dimension: dimensions of data, interpreted as (N, C, H, W) when
            it has exactly four entries.
    Returns:
        A flattened NHWC copy when data_dimension has four entries; otherwise
        the input object is returned unchanged.
    """
    if len(data_dimension) != 4:
        return data
    n, c, h, w = data_dimension
    nchw = np.reshape(data, (n, c, h, w))
    return np.transpose(nchw, (0, 2, 3, 1)).flatten()


class ModelHelper:

    # Maps an ONNX TensorProto.DataType enum value (keyed by its string form)
    # to (numpy dtype, element size in bytes).
    _ONNX_TO_NUMPY = {
        '1': (np.float32, 4),    # FLOAT
        '2': (np.uint8, 1),      # UINT8
        '3': (np.int8, 1),       # INT8
        '4': (np.uint16, 2),     # UINT16
        '5': (np.int16, 2),      # INT16
        '6': (np.int32, 4),      # INT32
        '7': (np.int64, 8),      # INT64
        '9': (np.bool_, 1),      # BOOL
        '10': (np.float16, 2),   # FLOAT16
        '11': (np.float64, 8),   # DOUBLE
        '12': (np.uint32, 4),    # UINT32
        '13': (np.uint64, 8),    # UINT64
    }

    @classmethod
    def onnx_type_to_numpy(cls, type):
        """
        Give the numpy datatype and element byte size for an ONNX tensor element type.

        Args:
            type : ONNX tensor element type, either as an int enum value
                (e.g. 1) or its string form (e.g. '1').
        Returns:
            tuple (numpy dtype, element size in bytes)
        Raises:
            UnsupportedError: if the element type has no mapping.
        """
        # Normalize to the string form so both int enum values and strings work.
        key = str(type)
        if key in cls._ONNX_TO_NUMPY:
            return cls._ONNX_TO_NUMPY[key]
        raise UnsupportedError('Unsupported type : {}'.format(str(type)))


def get_framework_info(model_path):
    """
    Infer the framework name of a model from its file extension.

    The text after the last '.' in model_path, prefixed with '.', is looked up
    in the reverse of FrameworkExtension.framework_extension_mapping.

    Args:
        model_path: path of the model file, or None.
    Returns:
        The matching framework name, or None if model_path is None or the
        extension is not recognized.
    """
    if model_path is None:
        return None
    mapping = FrameworkExtension.framework_extension_mapping
    by_extension = dict((ext, name) for name, ext in mapping.items())
    suffix = '.{}'.format(model_path.rpartition('.')[2])
    return by_extension.get(suffix)


def extract_input_information(input_tensor):
    """
    Build a mapping from input name to input dimensions.

    Args:
        input_tensor: iterable of per-input entries, each either
            (name, dims, data_path) or (name, dims, data_path, dtype), where
            dims is a comma-separated string such as "1,3,224,224".
    Returns:
        dict mapping each input name to its dimensions as a list of ints,
        in input order (a repeated name keeps its last dimensions).
    Raises:
        FrameworkError: if the entries are not 3- or 4-element tuples.
    """
    columns = list(zip(*input_tensor))
    if len(columns) not in (3, 4):
        # These names are not imported at module level; without importing them
        # here this branch would raise NameError instead of FrameworkError.
        from qti.aisw.accuracy_debugger.lib.utils.nd_exceptions import FrameworkError
        from qti.aisw.accuracy_debugger.lib.utils.nd_errors import get_message
        raise FrameworkError(get_message('ERROR_FRAMEWORK_RUNNER_INPUT_TENSOR_LENGHT_ERROR'))

    # Only names and dims are needed; data paths and optional dtypes are ignored.
    in_names, in_dims = columns[0], columns[1]
    return {
        name: [int(x) for x in dims.split(',')]
        for name, dims in zip(in_names, in_dims)
    }


def max_version(framework, available_frameworks):
    """
    Return the highest version registered for a framework.

    Args:
        framework: framework name used as a key into available_frameworks.
        available_frameworks: mapping of framework name -> mapping keyed by
            version string.
    Returns:
        The version key that compares greatest under Version ordering.
    Raises:
        ValueError: if no versions are registered for framework.
    """
    versions = available_frameworks.get(framework, {})
    if not versions:
        # Without this check max() fails with an opaque "empty sequence" error.
        raise ValueError("No versions available for framework '{}'".format(framework))
    return max(versions, key=Version)


def simplify_onnx_model(logger, model_path=None, input_tensor=None, output_dir=None, custom_op_lib=None):
    """
    Optimize an ONNX model with the newest available ONNX framework runner.

    The highest registered 'onnx' version is selected, its runner class is
    imported and instantiated, and its optimize() is called. The output is
    requested at "<output_dir>/optimized_model<onnx extension>".

    Args:
        logger: logger passed to the framework runner.
        model_path: path of the input ONNX model.
        input_tensor: per-input entries consumed by extract_input_information.
        output_dir: directory for the optimized model. Despite the None
            default it is required, because os.path.join(None, ...) raises
            TypeError.
        custom_op_lib: optional custom-op library passed to the runner.
    Returns:
        The second element of the runner's optimize() result (the optimized
        model path).
    """
    framework = 'onnx'
    frameworks = get_available_frameworks()
    latest = max_version(framework, frameworks)
    module_name, class_name = frameworks[framework][latest]
    runner_cls = getattr(import_module(module_name), class_name)
    runner = runner_cls(logger, custom_op_lib=custom_op_lib)

    extension = FrameworkExtension.framework_extension_mapping[framework]
    target_path = os.path.join(output_dir, 'optimized_model' + extension)
    input_information = extract_input_information(input_tensor)
    _, optimized_path = runner.optimize(model_path, target_path, input_information)
    return optimized_path