#!/usr/bin/env python3
# -*- mode: python -*-
#==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
#==============================================================================

import sys
import traceback
import os
import yaml
import argparse

# import numpy before qti.aisw.converters.xxxx modules
import numpy as np

# Common Imports
from qti.aisw.converters.common import ir_graph
from qti.aisw.converters.common.utils import validation_utils
from qti.aisw.converters.common.utils.converter_utils import log_error, log_info, log_warning, log_debug
from qti.aisw.converters.common.utils.argparser_util import ArgParserWrapper, CustomHelpFormatter
from qti.aisw.converters.common.converter_ir.op_graph_optimizations import IROptimizations
from qti.aisw.converters.qnn_backend.ir_to_dlc import DLCBackend
from qti.aisw.converters.qnn_backend.custom_ops.op_factory import QnnCustomOpFactory
from qti.aisw.converters.common import encodings_json_serializer, modeltools
from qti.aisw.converters import onnx as onnx_converter
from safetensors.numpy import load_file, save_file
from qti.aisw.dlc_utils.snpe_dlc_utils import ModelInfo
from qti.aisw.converters.common.dlc_quantizer import DLCQuantizer
from qti.aisw.converters.common.backend_awareness import BackendInfo


def get_quantizer_args_from_dlc(model_info):
    """
    Recover the quantizer arguments recorded in the DLC.

    :param model_info: object exposing read_quantizer_command() and parse_quantizer_command()
    :return: the namespace built from the parsed quantizer command, or an empty dict
             when the recorded quantizer command is 'N/A'
    :raises RuntimeError: if a command is recorded but does not contain ';'
    """
    recorded_cmd = model_info.read_quantizer_command()

    # 'N/A' means no quantizer command is recorded; the falsy empty dict is
    # what callers test against to decide whether to quantize
    if recorded_cmd == 'N/A':
        return {}

    if ';' not in recorded_cmd:
        raise RuntimeError("Commandline parsing is not supported on older DLC."
                           " This support has been added since release v2.25."
                           " Please use SDK v2.25 or later SDK for this functionality")

    parsed_args = model_info.parse_quantizer_command()
    return convert_dict_to_namespace(parsed_args)

def get_converter_args_from_dlc(model_info):
    """
    Recover the converter arguments recorded in the DLC.

    :param model_info: object exposing read_converter_command() and parse_converter_command()
    :return: the namespace built from the parsed converter command
    :raises RuntimeError: if the recorded converter command does not contain ';'
    """
    recorded_cmd = model_info.read_converter_command()
    if ';' not in recorded_cmd:
        raise RuntimeError("Commandline parsing is not supported on older DLC."
                           " This support has been added since release v2.25."
                           " Please use SDK v2.25 or later SDK for this functionality")

    parsed_args = model_info.parse_converter_command()
    return convert_dict_to_namespace(parsed_args)

class LoraArgParser(ArgParserWrapper):
    """Command-line parser for qairt-lora-importer."""

    def __init__(self):
        super().__init__(formatter_class=CustomHelpFormatter,
                         conflict_handler='resolve',
                         parents=[])

        self.add_required_argument('--lora_config',
                                   metavar="LORA_CONFIG_YAML",
                                   type=str,
                                   action=validation_utils.validate_filename_arg(must_exist=True),
                                   help='Path to the YAML config file for LoRA.')

        self.add_required_argument('--input_dlc', type=str,
                                   action=validation_utils.validate_filename_arg(must_exist=True),
                                   help='Path to the Float or Quantized DLC.')

        self.add_required_argument("--input_network", "-i", type=str,
                                   action=validation_utils.validate_pathname_arg(must_exist=True),
                                   help="Path to the source ONNX model.")

        # --input_list and --float_fallback cannot be given together
        group = self.parser.add_mutually_exclusive_group()
        group.add_argument('--input_list', type=str,
                           action=validation_utils.validate_filename_arg(must_exist=True),
                           help='Path to a file specifying the input data. This file should be a plain text '
                                'file, containing one or more absolute file paths per line. Each path is '
                                'expected to point to a binary file containing one input in the "raw" format, '
                                'ready to be consumed by the lora-importer without any further preprocessing. '
                                'See documentation for more details.')

        group.add_argument('--float_fallback', action='store_true', default=False,
                           help='Use this option to enable fallback to floating point (FP) instead of fixed point. '
                                'If this option is enabled, then ``--input_list`` must not be provided. '
                                'The external quantization encodings (encoding file/FakeQuant encodings) '
                                'might be missing quantization parameters for some interim tensors. '
                                'First it will try to fill the gaps by propagating across math-invariant '
                                'functions. If the quantization params are still missing, then it will '
                                'apply fallback to nodes to floating point.')

        self.add_optional_argument("--debug", type=int, nargs='?', default=-1,
                                   help="Run the qairt-lora-importer in debug mode.")

    @classmethod
    def validate_quantizer_args(cls, lora_args, quantizer_args):
        """
        Check the calibration options passed to qairt-lora-importer against the options
        the input DLC was quantized with.

        Intended compatibility matrix:

                          |       lora_args
                          | --------------------------------------
                          | input_list| float_fallback | None
        quantizer_args    |-----------|----------------|----------
        ------------------|           |                |
        N/A(FP conversion)| Not Valid | Not Valid      | Valid
        input_list        | Valid     | Not Valid      | Not Valid
        float_fallback    | Not Valid | Valid          | Not Valid

        This method raises for the "input_list" row when lora_args has no input_list, and
        for the "float_fallback" row when lora_args has neither option. It does not reject
        the (float_fallback, input_list) cell, and the N/A row is not handled here.

        :param lora_args: parsed qairt-lora-importer arguments (input_list, float_fallback)
        :param quantizer_args: quantizer arguments recovered from the DLC (input_list, float_fallback)
        :raises Exception: if a required calibration option is missing from lora_args
        """
        if quantizer_args.input_list and not lora_args.input_list:
            raise Exception("No --input_list has been passed to qairt-lora-importer, but input DLC was quantized using "
                            "--input_list. Please retry after providing calibration data using --input_list")
        elif quantizer_args.float_fallback and (not lora_args.input_list and not lora_args.float_fallback):
            # The DLC was produced with --float_fallback, so say so in the message
            raise Exception("No --input_list or --float_fallback has been passed to qairt-lora-importer, "
                            "but input DLC was quantized using --float_fallback. "
                            "Please retry after providing --float_fallback")

        # A different calibration list is allowed; the one given to the importer wins
        if (quantizer_args.input_list and lora_args.input_list) and (quantizer_args.input_list != lora_args.input_list):
            log_warning("Input list path used to produce the DLC is different than the provided input list path. "
                        "Using the provided input list path to produce lora artifacts.")


def convert_dict_to_namespace(dictionary):
    """
    Build an argparse.Namespace whose attributes mirror the entries of a dictionary.

    :param dictionary: mapping of attribute name to value
    :return: argparse.Namespace with one attribute per dictionary entry
    """
    result = argparse.Namespace()
    for attr_name in dictionary:
        setattr(result, attr_name, dictionary[attr_name])
    return result

def get_lora_use_cases(args):
    """
    Load the LoRA YAML config and return its 'use_case' entry.

    :param args: parsed arguments; args.lora_config is the path to the YAML file
    :return: the value stored under the 'use_case' key of the parsed YAML
    """
    with open(args.lora_config) as config_file:
        parsed_config = yaml.safe_load(config_file)
    return parsed_config['use_case']

def get_safetensor_tensor_names(args):
    """
    Return the set of tensor names stored in the first use case's safetensors file.

    :param args: parsed arguments; args.lora_config is the path to the LoRA YAML config
    :return: set of tensor names
    """
    # all use cases should have the same set of updatable tensors,
    # so only the first one is inspected
    first_use_case = get_lora_use_cases(args)[0]
    loaded_weights = load_file(first_use_case['lora_weights'])
    return set(loaded_weights.keys())

def get_graph_from_dlc(dlc_path):
    """
    Open a DLC and return the first IR graph it lists.

    :param dlc_path: path to the DLC file
    :return: the IR graph for the first name reported by the DLC reader
    """
    reader = modeltools.IrDlcReader()
    reader.open(dlc_path)
    first_graph_name = list(reader.get_ir_graph_names())[0]
    return reader.get_ir_graph(first_graph_name)

def get_updateable_static_tensor_names_in_graph(cpp_graph):
    """
    List the names of tensors that are both updateable and static.

    :param cpp_graph: graph exposing get_tensor_map(), a mapping of tensor name to tensor
    :return: list of matching tensor names, in the iteration order of the tensor map
    """
    return [
        name
        for name, candidate in cpp_graph.get_tensor_map().items()
        if candidate.is_updateable() and candidate.is_static_tensor()
    ]

def load_static_tensors_into_graph(tensor_names_to_data_map, py_graph):
    """
    Overwrite tensor data of graph nodes with the supplied values.

    :param tensor_names_to_data_map: mapping of node name to the new tensor data
    :param py_graph: graph exposing get_node_by_name(); each node's op must
                     provide set_tensor_data()
    """
    for name, new_data in tensor_names_to_data_map.items():
        py_graph.get_node_by_name(name).op.set_tensor_data(new_data)

def get_static_tensor_values(safetensor_tensor_names, cpp_graph):
    """
    Collect the data of every named tensor that is static in the graph.

    :param safetensor_tensor_names: iterable of tensor names to look up
    :param cpp_graph: graph exposing get_tensor(name)
    :return: dict of tensor name to tensor.get_data() for the static tensors only
    """
    # Lazy pairing keeps the per-name call order: get_tensor, is_static_tensor, get_data
    looked_up = ((name, cpp_graph.get_tensor(name)) for name in safetensor_tensor_names)
    return {
        name: found.get_data()
        for name, found in looked_up
        if found.is_static_tensor()
    }

def get_converter_args_for_lora_conversion(args, converter_args_namespace):
    """
    Overlay the importer's settings onto the converter arguments recovered from the DLC.

    :param args: parsed qairt-lora-importer arguments (input_network, debug)
    :param converter_args_namespace: converter arguments to update in place
    :return: the same converter_args_namespace object
    """
    converter_args_namespace.input_network = args.input_network
    converter_args_namespace.debug = args.debug
    # dummy output path needed for dlc backend
    converter_args_namespace.output_path = None
    return converter_args_namespace

def get_quantizer_args_for_lora_quantization(args, quantizer_args_namespace):
    """
    Overlay the importer's settings onto the quantizer arguments recovered from the DLC.

    :param args: parsed qairt-lora-importer arguments (input_dlc, debug, input_list, float_fallback)
    :param quantizer_args_namespace: quantizer arguments to update in place
    :return: the same quantizer_args_namespace object
    :raises Exception: if --input_list was passed to the importer while float fallback is
                       requested, either by the importer or by the quantizer command
                       recorded in the DLC
    """
    quantizer_args_namespace.input_dlc = args.input_dlc
    quantizer_args_namespace.output_dlc = None
    quantizer_args_namespace.debug = args.debug

    if args.input_list:
        # The importer's own two options are mutually exclusive on the command line,
        # so the mismatch that can occur in practice is a DLC quantized with
        # --float_fallback being re-calibrated with --input_list.
        if args.float_fallback or getattr(quantizer_args_namespace, "float_fallback", False):
            raise Exception("Quantizer invoked with --float_fallback but lora-importer invoked with --input_list")
        quantizer_args_namespace.input_list = args.input_list
    elif args.float_fallback:
        quantizer_args_namespace.float_fallback = args.float_fallback

    return quantizer_args_namespace

def get_py_graph(converter_args):
    """
    Convert the source ONNX model into a python IR graph.

    :param converter_args: converter arguments handed to the ONNX converter frontend
    :return: the graph returned by the frontend's convert()
    """
    frontend = onnx_converter.OnnxConverterFrontend(converter_args,
                                                    custom_op_factory=QnnCustomOpFactory())
    return frontend.convert()

def optimize_py_graph(converter_args, py_graph):
    """
    Run the IR optimizations on a python IR graph.

    :param converter_args: converter arguments used to configure IROptimizations
    :param py_graph: graph to optimize
    :return: the graph returned by the optimizer
    """
    return IROptimizations(converter_args).optimize(py_graph)

def get_cpp_graph(py_graph, backend):
    """
    Lower a python IR graph to its cpp IR graph through the given backend.

    :param py_graph: python IR graph to lower
    :param backend: backend exposing prepare_py_graph(), get_ir_graph() and prepare_cpp_graph()
    :return: the cpp graph produced by backend.get_ir_graph()
    """
    # The python graph must be prepared before the IR graph is requested
    backend.prepare_py_graph(py_graph)
    lowered_graph = backend.get_ir_graph(py_graph)
    backend.prepare_cpp_graph(py_graph, lowered_graph)
    return lowered_graph

def get_quantizer(quantizer_args):
    """
    Build a DLCQuantizer from quantizer arguments.

    The arguments are validated and converted once, and the backend-awareness
    object is looked up once, before the quantizer is constructed.

    :param quantizer_args: quantizer arguments; backend and soc_model are read directly,
                           everything else goes through DLCQuantizer.ArgParser
    :return: the constructed DLCQuantizer
    """
    args_dict = DLCQuantizer.ArgParser.validate_and_convert_args(quantizer_args)

    # Backend Awareness
    backend_info_obj = BackendInfo.get_instance(quantizer_args.backend, quantizer_args.soc_model)

    dlc_quantizer = DLCQuantizer(input_dlc=args_dict['input_dlc'],
                                 output_dlc=args_dict['output_dlc'],
                                 input_list=args_dict['input_list'],
                                 float_fallback=args_dict['float_fallback'],
                                 param_quantizer=args_dict['param_quantizer'],
                                 act_quantizer=args_dict['act_quantizer'],
                                 algorithms=args_dict['algorithms'],
                                 bias_bitwidth=args_dict['bias_bitwidth'],
                                 act_bitwidth=args_dict['act_bitwidth'],
                                 weights_bitwidth=args_dict['weights_bitwidth'],
                                 float_bitwidth=args_dict['float_bitwidth'],
                                 float_bias_bitwidth=args_dict['float_bias_bitwidth'],
                                 ignore_encodings=args_dict['ignore_encodings'],
                                 use_per_channel_quantization=args_dict['use_per_channel_quantization'],
                                 use_per_row_quantization=args_dict['use_per_row_quantization'],
                                 use_native_input_files=args_dict['use_native_input_files'],
                                 use_native_output_files=args_dict['use_native_output_files'],
                                 restrict_quantization_steps=args_dict['restrict_quantization_steps'],
                                 use_dynamic_16_bit_weights=args_dict['use_dynamic_16_bit_weights'],
                                 pack_4_bit_weights=args_dict['pack_4_bit_weights'],
                                 keep_weights_quantized=args_dict["keep_weights_quantized"],
                                 adjust_bias_encoding=args_dict["adjust_bias_encoding"],
                                 act_quantizer_calibration=args_dict['act_quantizer_calibration'],
                                 param_quantizer_calibration=args_dict['param_quantizer_calibration'],
                                 act_quantizer_schema=args_dict['act_quantizer_schema'],
                                 param_quantizer_schema=args_dict['param_quantizer_schema'],
                                 percentile_calibration_value=args_dict['percentile_calibration_value'],
                                 use_aimet_quantizer=args_dict['use_aimet_quantizer'],
                                 op_package_lib=args_dict['op_package_lib'],
                                 disable_legacy_quantizer=args_dict['disable_legacy_quantizer'],
                                 dump_encoding_json=args_dict['dump_encoding_json'],
                                 include_data_invariant_ops=args_dict['include_data_invariant_ops'],
                                 # the converted args expose the AIMET config under 'config_file'
                                 aimet_config=args_dict['config_file'],
                                 backend_info_obj=backend_info_obj)
    return dlc_quantizer

def get_quantized_graph(quantizer_args, cpp_graph):
    """
    Quantize a cpp IR graph using the given quantizer arguments.

    :param quantizer_args: arguments used to construct the quantizer
    :param cpp_graph: graph handed to the quantizer via set_ir_graph()
    :return: the quantizer's ir_graph after quantize() has run
    """
    dlc_quantizer = get_quantizer(quantizer_args)
    dlc_quantizer.set_ir_graph(cpp_graph)
    dlc_quantizer.quantize()
    return dlc_quantizer.ir_graph

def get_safetensor_dict(config):
    """
    Load the safetensors file referenced by a use-case config, if any.

    :param config: use-case dictionary; the optional 'lora_weights' entry is the
                   path to a safetensors file
    :return: mapping of tensor name to data loaded from the file, or None when
             'lora_weights' is absent or empty
    """
    # A use case without a 'lora_weights' key is treated like one with an empty
    # value, so the caller's "no safetensors passed" handling applies to both.
    safetensor_path = config.get('lora_weights')
    if not safetensor_path:
        return None
    return load_file(safetensor_path)

def create_lora_updates(args, converter_args, quantizer_args):
    """
    Produce the LoRA artifacts for every use case listed in the LoRA config.

    For each use case the source ONNX model is converted again, the use case's safetensors
    (if any) are loaded into the graph, the graph is optimized and lowered to a cpp graph,
    and the updateable static tensors of the input DLC are saved as
    "<name>_lora.safetensors". When quantization applies, the graph is also quantized and
    its encodings are saved as "<name>_lora.encodings". Finally "lora_output_files.yaml",
    listing the generated files per use case, is written to the save directory of the
    last use case.

    :param args: parsed qairt-lora-importer arguments
    :param converter_args: converter arguments; quantization_overrides is set per use case
    :param quantizer_args: quantizer arguments, or a falsy value when the DLC is not quantized
    :raises ValueError: if the config lists no use cases, or if a use case has no safetensors
                        while converter_args.lora_weight_list is not empty
    """
    def get_save_directory(config):
        # Use the configured output_path (creating it if needed), else the current directory
        save_dir = config.get('output_path')
        if not save_dir:
            return os.getcwd()
        os.makedirs(save_dir, exist_ok=True)
        return save_dir

    def save_encoding(cpp_graph, config):
        # Serialize the graph's encodings to "<name>_lora.encodings" and return the path
        include_data_invariant_ops = True
        use_case_name = config['name']
        filename = use_case_name + "_lora.encodings"
        save_path = os.path.join(get_save_directory(config), filename)
        serializer = encodings_json_serializer.IrEncodingsJsonSerializer(
                save_path,
                include_data_invariant_ops)
        serializer.serialize(cpp_graph)
        encoding_json = serializer.get_graph_json()
        with open(save_path, "w") as f:
            f.write(encoding_json)
        log_info(use_case_name + " encodings JSON saved at: " + save_path)
        return save_path

    def save_updatable_static_tensors(config, cpp_graph, updateable_static_tensor_names):
        # Save the named tensors of cpp_graph to "<name>_lora.safetensors" and return the path
        use_case_name = config['name']
        filename = use_case_name + "_lora.safetensors"
        save_path = os.path.join(get_save_directory(config), filename)

        tensor_names_to_data = dict()
        for tensor_name in updateable_static_tensor_names:
            # The names come from the input DLC; the freshly converted graph may lack some
            if not cpp_graph.has_tensor(tensor_name):
                log_warning("Warning: Tensor name, {}, is found in the DLC graph but not in the "
                            "converted onnx graph.".format(tensor_name))
                continue

            cpp_tensor = cpp_graph.get_tensor(tensor_name)
            cpp_tensor_dtype = cpp_tensor.data_type()
            tensor_data = cpp_tensor.get_data()

            # get_data() returns uint16 buffer for fp16 datatype
            # Read the uint16 buffer in fp16, then convert to fp32
            if cpp_tensor_dtype == ir_graph.QNN_DATATYPE_FLOAT_16:
                tensor_data_fp16 = np.frombuffer(tensor_data.tobytes(), dtype=np.float16)
                tensor_data_fp16 = np.reshape(tensor_data_fp16, tensor_data.shape)
                tensor_data = tensor_data_fp16.astype(np.float32)

            tensor_names_to_data[tensor_name] = tensor_data

        save_file(tensor_names_to_data, save_path)
        log_info(use_case_name + " safetensors saved at: " + save_path)

        return save_path

    def set_quant_overrides(converter_args, usecase_config):
        # Per-use-case overrides replace whatever the previous use case set (None if absent)
        converter_args.quantization_overrides = usecase_config.get("quant_overrides")

    dlc_graph = get_graph_from_dlc(args.input_dlc)
    updateable_static_tensor_names = get_updateable_static_tensor_names_in_graph(dlc_graph)
    lora_configs = get_lora_use_cases(args)
    if not lora_configs:
        # Without this check the summary step below would fail on an unbound loop variable
        raise ValueError("No use cases found in the LoRA config: {}".format(args.lora_config))

    output_config_data = dict()

    # Quantize only when the DLC was quantized and calibration options were supplied
    should_quantize = bool((args.input_list or args.float_fallback) and quantizer_args)

    for config in lora_configs:
        set_quant_overrides(converter_args, config)

        py_graph = get_py_graph(converter_args)

        tensor_names_to_data_map = get_safetensor_dict(config)
        if tensor_names_to_data_map:
            # Safetensors are passed in lora_config
            load_static_tensors_into_graph(tensor_names_to_data_map, py_graph)
        elif converter_args.lora_weight_list != "":
            # error out as qairt-converter was passed with valid lora_weight_list and
            # qairt-lora-importer should have safetensors files as well
            raise ValueError("No Safetensors passed in lora_config but the Model has lora adapter weights")
        else:
            log_debug("Model does not have any lora adapter weights and no safetensors are passed in lora_config")

        optimized_py_graph = optimize_py_graph(converter_args, py_graph)

        backend = DLCBackend(converter_args)
        cpp_graph = get_cpp_graph(optimized_py_graph, backend)

        safetensor_save_path = save_updatable_static_tensors(config, cpp_graph, updateable_static_tensor_names)

        use_case_name = config['name']
        output_config_data[use_case_name] = dict()

        if should_quantize:
            cpp_graph = get_quantized_graph(quantizer_args, cpp_graph)
            encoding_save_path = save_encoding(cpp_graph, config)
            output_config_data[use_case_name]['encodings'] = encoding_save_path

        output_config_data[use_case_name]['safetensor'] = safetensor_save_path
        output_config_data[use_case_name]['graph_name'] = cpp_graph.name

    # config is still bound to the last use case: the summary goes to its save directory
    save_dir = get_save_directory(config)
    save_path = os.path.join(save_dir, "lora_output_files.yaml")

    with open(save_path, "w") as f:
        yaml.dump(output_config_data, f)

def main():
    """
    Entry point of qairt-lora-importer.

    Parses the command line, recovers the converter and quantizer arguments recorded in
    the input DLC, validates them against the importer's options and generates the LoRA
    artifacts. Exits with status 0 on success and 1 after logging any exception.
    """
    parser = LoraArgParser()
    args = parser.parse_args()

    try:
        model_info = ModelInfo(args.input_dlc)
        converter_args = get_converter_args_from_dlc(model_info)
        quantizer_args = get_quantizer_args_from_dlc(model_info)

        if quantizer_args:
            parser.validate_quantizer_args(args, quantizer_args)
            quantizer_args = get_quantizer_args_for_lora_quantization(args, quantizer_args)
        elif args.input_list or args.float_fallback:
            # Calibration options make no sense for a DLC that was never quantized
            raise Exception("Input DLC is not quantized. --input_list or --float_fallback is not a valid argument "
                            "in this case. Please retry after removing this argument")

        converter_args = get_converter_args_for_lora_conversion(args, converter_args)
        create_lora_updates(args, converter_args, quantizer_args)

    except Exception as err:
        log_error("Encountered Error: {}".format(str(err)))
        traceback.print_exc()
        sys.exit(1)
    sys.exit(0)

# Run the importer only when this file is executed as a script, not when imported
if __name__ == '__main__':
    main()
