#!/usr/bin/env python3
# -*- mode: python -*-
# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
import sys
import traceback
import os

# import numpy before qti.aisw.converters.xxxx modules
import numpy
import pathlib
import shutil

# Common Imports
from qti.aisw.converters.common.utils import code_to_message
from qti.aisw.converters.common.utils.converter_utils import log_error, log_warning, log_info, log_debug
from qti.aisw.converters.common.utils.argparser_util import ArgParserWrapper, CustomHelpFormatter
from qti.aisw.converters.common.converter_ir.op_graph_optimizations import IROptimizations
from qti.aisw.converters.qnn_backend.ir_to_dlc import DLCBackend
from qti.aisw.converters.qnn_backend.custom_ops.op_factory import QnnCustomOpFactory
from qti.aisw.converters.common.qairt_converter_arguments import QairtConverterFrontendArgParser, convert_args_v2_to_v1
from qti.aisw.converters.common.utils.multi_graph import IrStaticTensorSet
from qti.aisw.converters.common.model_validator import Validator
from qti.aisw.converters.common.backend_awareness import BackendInfo
from qti.aisw.converters.common.input_shape import InputShapeArgParser, InputShapeInfo


class FrameworktoQNNArgParser(ArgParserWrapper):
    """Argument parser for the framework-to-DLC converter command line.

    The parser is assembled from four parent parsers (frontend,
    IR optimizations, DLC backend and backend awareness) and passes
    ``conflict_handler='resolve'`` to the base class.
    """

    def __init__(self):
        super().__init__(formatter_class=CustomHelpFormatter,
                         conflict_handler='resolve',
                         parents=[QairtConverterFrontendArgParser(),
                                  IROptimizations.ArgParserv2(),
                                  DLCBackend.ArgParserv2(),
                                  BackendInfo.ArgParser(),
                                  ])

    def validate_args(self, args):
        """Reject argument combinations that this tool cannot convert.

        :param args: parsed argument namespace.
        :raises Exception: if a LoRA weight list is supplied together with
            input dimensions that contain dynamic shapes.
        """
        # getattr with a default covers both "attribute missing" and
        # "attribute set to None" with a single identity check.
        if getattr(args, "lora_weight_list", None) is not None:
            input_shape_info = InputShapeArgParser(args.input_dim).input_shape_info
            if input_shape_info.has_dynamic_shapes:
                raise Exception("Dynamic tensors are not supported with LoRA with this tool.")

def set_optimization_args(args, framework):
    """Write the framework-specific optimization switches onto ``args``.

    :param args: argument namespace that receives the flags (mutated in place).
    :param framework: one of 'onnx', 'tensorflow', 'pytorch' or 'tflite'.
    """
    # TODO: Align optimizations for all frameworks
    def _assign(**flags):
        # Store every keyword as an attribute of the same name on args.
        for flag_name, flag_value in flags.items():
            setattr(args, flag_name, flag_value)

    is_onnx = framework == 'onnx'
    is_tensorflow = framework == 'tensorflow'
    is_pytorch = framework == 'pytorch'

    if is_onnx:
        _assign(expand_gru_op_structure=True,
                unroll_gru_time_steps=True,
                expand_sparse_op_structure=False)

    # args.use_onnx_relay is only read when the framework is onnx.
    if is_pytorch or (is_onnx and not args.use_onnx_relay):
        _assign(perform_axes_to_spatial_first_order=True,
                preprocess_roi_pool_inputs=True)

    if is_onnx or is_tensorflow:
        _assign(unroll_lstm_time_steps=True,
                align_matmul_ranks=True,
                handle_gather_negative_indices=True)

    if is_tensorflow or is_pytorch:
        _assign(match_caffe_ssd_to_tf=True)

    # Enable/Disable following optimizations for onnx, tf, pytorch
    if framework != 'tflite':
        _assign(squash_box_decoder=True,
                adjust_nms_features_dims=True,
                extract_color_transform=True,
                inject_cast_for_gather=True,
                force_prune_cast_ops=False)

def get_frontend_converter(framework, args, validator):
    """Create the frontend converter object for ``framework``.

    The frontend modules are imported lazily inside each branch, so only the
    package of the selected framework has to be importable.

    :param framework: 'onnx', 'tensorflow', 'tflite' or 'pytorch'.
    :param args: converter argument namespace handed to the frontend.
    :param validator: validator object forwarded to the onnx (non-relay) and
        tensorflow frontends; the other frontends do not receive it.
    :return: the frontend converter instance.
    :raises Exception: for an unrecognized framework, for missing required
        arguments (tensorflow / pytorch), or when the onnx relay flow cannot
        be set up.
    """
    if framework == 'onnx':
        if not args.use_onnx_relay:
            from qti.aisw.converters.onnx.onnx_to_ir import OnnxConverterFrontend
            return OnnxConverterFrontend(args, custom_op_factory=QnnCustomOpFactory(), validator=validator)
        try:
            # use onnx-relay-converter flow
            from qti.aisw.converters.onnx.onnx_to_ir_relay import OnnxRelayConverterFrontend
            return OnnxRelayConverterFrontend(args, custom_op_factory=QnnCustomOpFactory())
        except Exception as e:
            # Chain the underlying failure so the traceback still shows why
            # the relay flow could not be used (import or construction error).
            raise Exception("--use_onnx_relay is not available. Please remove --use_onnx_relay in converter command.") from e
    elif framework == "tensorflow":
        from qti.aisw.converters.tensorflow.tf_to_ir import TFConverterFrontend
        # NOTE(review): ConverterError is not used in this function; the import
        # is left in place in case it is relied on for side effects - confirm
        # before removing.
        from qti.aisw.converters.tensorflow.util import ConverterError
        if not args.input_dim or not args.out_names:
            raise Exception("--desired_input_shape and --out_tensor_node are required for TensorFlow conversion")
        return TFConverterFrontend(args, custom_op_factory=QnnCustomOpFactory(), validator=validator)
    elif framework == "tflite":
        from qti.aisw.converters.tflite.tflite_to_ir import TFLiteConverterFrontend
        return TFLiteConverterFrontend(args, custom_op_factory=QnnCustomOpFactory())
    elif framework == "pytorch":
        from qti.aisw.converters.pytorch.pytorch_to_ir import PyTorchConverterFrontend
        from qti.aisw.converters.relay.custom_ops.utils.pytorch_helpers import PytorchCustomOpFactory
        if not args.input_dim:
            raise Exception("--desired_input_shape is required for PyTorch conversion")
        return PyTorchConverterFrontend(args, custom_op_factory=PytorchCustomOpFactory())
    else:
        raise Exception(f"unrecognized framework {framework}")

def get_num_tensor_configs(tensor_configs):
    """Return how many shape configurations ``tensor_configs`` describes.

    :param tensor_configs: a string, a flat sequence of ints such as
        ``(1, 2, 3)``, or a sequence of configurations such as
        ``((1, 2, 3), (4, 5, 6))``.
    :return: 1 for a string, a flat int sequence or an empty value;
        otherwise the number of entries in ``tensor_configs``.
    """
    # for when input_dims is passed individually via CLI
    # Checked before indexing so that an empty string cannot raise IndexError.
    if isinstance(tensor_configs, str):
        return 1

    # An empty sequence offers no alternatives to choose from, so it is
    # counted as a single configuration instead of failing on [0].
    if len(tensor_configs) == 0:
        return 1

    # for when there is only one tensor config e.g. tensor_configs == (1,2,3)
    if isinstance(tensor_configs[0], int):
        return 1

    # for when there is multiple tensor configs e.g. tensor_configs == ((1,2,3), (4,5,6))
    return len(tensor_configs)

def get_num_graph_configs(args):
    """Return the number of graph configurations implied by ``args.input_dim``.

    Each entry of ``args.input_dim`` is a ``(tensor_name, tensor_configs)``
    pair. The result is the largest per-tensor configuration count, or 1 when
    no input dimensions were given.

    :param args: argument namespace exposing ``input_dim``.
    :return: number of graph configurations (at least 1).
    """
    def validate_num_configs_is_1_or_n(num_tensor_configs_seen):
        # Apart from the shared value 1, at most one other count (N) may occur.
        # NOTE(review): a violation is only logged here, it is not raised.
        if len(num_tensor_configs_seen - {1}) > 1:
            # Built from adjacent literals so that source indentation does not
            # leak into the logged text.
            error_message = ("Error: Number of tensor configurations can either be 1 or N. "
                             "You specified the following number of tensor configurations: {}"
                             ).format(num_tensor_configs_seen)
            log_error(error_message)

    # None and an empty list both mean "no explicit input dimensions"; the
    # latter would otherwise make max() fail on an empty set.
    if not args.input_dim:
        return 1

    num_tensor_configs_seen = {
        get_num_tensor_configs(tensor_configs) for _, tensor_configs in args.input_dim
    }
    validate_num_configs_is_1_or_n(num_tensor_configs_seen)
    return max(num_tensor_configs_seen)

def get_graph_configs(args):
    """Expand ``args.input_dim`` into one input-dimension list per graph.

    For every graph configuration index ``i`` a list of
    ``[tensor_name, dims_string]`` pairs is produced. A tensor that has several
    configurations contributes its ``i``-th one; a tensor with a single
    configuration contributes that same configuration to every graph.

    :param args: argument namespace exposing ``input_dim`` as an iterable of
        ``(tensor_name, tensor_configs)`` pairs.
    :return: list (one entry per graph configuration) of lists of
        ``[tensor_name, dims_string]``.
    """
    def convert_dimensions_to_string(dims):
        # A string is passed through unchanged: joining it would iterate over
        # its characters and turn "1,2,3" into "1,,,2,,,3".
        if isinstance(dims, str):
            return dims
        return ",".join(str(dim) for dim in dims)

    num_configurations = get_num_graph_configs(args)
    configurations = []

    for i in range(num_configurations):
        configuration = []
        for tensor_name, tensor_configs in args.input_dim:
            if get_num_tensor_configs(tensor_configs) > 1:
                selected_dims = tensor_configs[i]
            else:
                selected_dims = tensor_configs
            configuration.append([tensor_name, convert_dimensions_to_string(selected_dims)])
        configurations.append(configuration)
    return configurations

def set_graph_configs(args, config):
    """Make ``config`` the input-dimension setting stored on ``args``."""
    setattr(args, "input_dim", config)

def infer_framework(args):
    """Determine the source framework from ``args.input_network``.

    The framework is chosen from the file extension. If the extension is not a
    supported one and the path is a directory, the directory is treated as a
    tensorflow model when it contains at least one ``.pb`` file.

    :param args: argument namespace exposing ``input_network`` (a path).
    :return: 'onnx', 'tensorflow', 'pytorch', 'tflite' or 'gguf'.
    :raises Exception: if no supported model format can be determined.
    """
    input_model_to_framework = {'.onnx': 'onnx', '.pb': 'tensorflow', '.pt': 'pytorch', '.tflite': 'tflite', '.gguf': 'gguf'}
    model_ext = os.path.splitext(args.input_network)[1]

    # tensorflow2 takes as input a folder which would have the ".pb" file
    # The directory itself is listed (not the extension-stripped path), and
    # only if it really is a directory, so that a dotted folder name is still
    # detected and an unsupported file reaches the error below instead of
    # failing inside os.listdir.
    if model_ext not in input_model_to_framework and os.path.isdir(args.input_network):
        if any(os.path.splitext(entry)[1] == '.pb' for entry in os.listdir(args.input_network)):
            model_ext = '.pb'

    if model_ext not in input_model_to_framework:
        raise Exception("Invalid model format specified. Supported types are .onnx/.pb/.tflite/.pt")
    return input_model_to_framework[model_ext]

def get_validator(framework, args):
    """Return a ``Validator`` when model validation applies, else ``None``.

    Validation applies only to onnx and tensorflow models with
    ``args.validate_models`` set. If a converter op package library is given,
    a warning is logged, ``args.validate_models`` is switched off and ``None``
    is returned.
    """
    validation_requested = framework in ('onnx', 'tensorflow') and args.validate_models
    if not validation_requested:
        return None

    if not args.converter_op_package_lib:
        return Validator()

    log_warning("Model is having custom ops skipping validation.")
    args.validate_models = False
    return None

def get_tensor_names_from_file(filepath):
    """Read tensor names, one per line, from ``filepath``.

    Surrounding whitespace is removed from every line and lines that are empty
    afterwards are skipped, so blank lines in the file do not turn into empty
    tensor names.

    :param filepath: path of the text file, or a falsy value when no file was
        supplied.
    :return: list of tensor names, or ``None`` when ``filepath`` is falsy.
    """
    if not filepath:
        # --lora_weight_list can be passed without any input file,
        # when input Model does not have LoRA branches but activation
        # encodings can change based on selected LoRA adapter during run time
        return None

    with open(filepath, 'r') as file:
        stripped_lines = (line.strip() for line in file.read().splitlines())
        return [tensor_name for tensor_name in stripped_lines if tensor_name]

def make_tensors_updateable(cpp_graph, tensor_names):
    """Mark every tensor named in ``tensor_names`` as updateable.

    :param cpp_graph: graph object providing ``has_tensor`` and ``get_tensor``.
    :param tensor_names: iterable of tensor names to mark.
    :raises ValueError: if a name is not present in the graph. The error is
        logged first; processing stops instead of going on to fetch a tensor
        that was just reported as missing.
    """
    for tensor_name in tensor_names:
        if not cpp_graph.has_tensor(tensor_name):
            error_message = "Error: Tensor name, {}, not found in the graph.".format(tensor_name)
            log_error(error_message)
            raise ValueError(error_message)
        log_debug("Marking tensor {} as updatable in the graph ".format(tensor_name))
        cpp_tensor = cpp_graph.get_tensor(tensor_name)
        cpp_tensor.set_updateable(True)



def convert_with_multiple_shapes(args, framework: str):
    """Convert the model once per input-shape configuration into one backend.

    For each configuration from ``get_graph_configs`` the input dimensions on
    ``args`` are replaced, a fresh frontend converter is built, and the
    converted and optimized graph is serialized through a single shared
    ``DLCBackend`` that is finished after the last configuration.

    :param args: converter argument namespace (``input_dim`` is overwritten
        for every configuration).
    :param framework: framework name as accepted by ``get_frontend_converter``.
    """
    validator = get_validator(framework, args)
    set_optimization_args(args, framework)
    ir_optimizer = IROptimizations(args)
    dlc_backend = DLCBackend(args)
    dlc_backend.initialize()
    # Backend Awareness
    backend_info = BackendInfo.get_instance(args.backend, args.soc_model)
    initializer_names = set()
    shape_configs = get_graph_configs(args)

    # Static tensor validation stays on unless no_static_tensor_validation is set.
    validate_static_tensors = not getattr(args, 'no_static_tensor_validation', False)

    # In network specialization flow, we will avoid check if the shared context
    # static tensors is already present in DLC.
    # The enable_tensor_deduplication flag will enable serializer to look for shared context
    # static tensors data in DLC
    enable_tensor_deduplication = bool(getattr(args, 'enable_tensor_deduplication', False))

    for shape_config in shape_configs:
        set_graph_configs(args, shape_config)
        frontend = get_frontend_converter(framework, args, validator)
        if not initializer_names:
            # NOTE(review): loader.model.graph.initializer looks like the ONNX
            # model layout - confirm behaviour for other frameworks.
            initializer_names.update(
                tensor.name for tensor in frontend.loader.model.graph.initializer)
        ir_graph = frontend.convert()
        optimized = ir_optimizer.optimize(ir_graph, backend_info)
        dlc_backend.serialize(optimized, initializer_names,
                              validate_static_tensors, enable_tensor_deduplication)
        # Drop the per-configuration objects before the next iteration.
        del optimized, ir_graph, frontend

    dlc_backend.finish()

def build_onnx_graph_from_gguf(args):
    """Run the LLM builder on a GGUF input and update ``args`` with its output.

    ``args.input_network`` and ``args.quantization_overrides`` are replaced by
    the first two values returned by ``build_from_gguf``, the third value is
    appended to ``args.input_layout``, and ``args.preserve_io_datatype`` is
    switched on. If the builder module cannot be imported, the error is logged
    and the process exits with status 1.
    """
    try:
        from qti.aisw.converters.llm_builder import llm_builder
    except Exception as import_error:
        log_error("Encountered Error: {}".format(str(import_error)))
        traceback.print_exc()
        sys.exit(1)

    builder = llm_builder.LLMBuilder(
        args.input_network,
        args.gguf_config,
        args.output_path,
        args.batch,
    )
    built_network, overrides, extra_layouts = builder.build_from_gguf()
    args.input_network = built_network
    args.quantization_overrides = overrides
    args.input_layout.extend(extra_layouts)
    args.preserve_io_datatype = True

def main():
    """Command-line entry point: parse arguments and convert the model.

    Three conversion paths exist:

    * LoRA conversion when ``lora_weight_list`` is set,
    * a single graph when there is exactly one graph configuration,
    * network specialization (several input-shape configurations) otherwise.

    On success the process exits with status 0. Any exception raised during
    conversion is logged, a temporary onnx file exported from Torchscript is
    removed if present, and the process exits with status 1.
    """
    parser = FrameworktoQNNArgParser()
    argsv2 = parser.parse_args()
    args = convert_args_v2_to_v1(argsv2)
    framework = infer_framework(args)

    if framework == 'gguf':
        build_onnx_graph_from_gguf(args)
        framework = 'onnx'

    # Only the LoRA and single-graph paths obtain a validator in this scope.
    # Starting from None keeps the validation step below from hitting an
    # unbound name after the network specialization path.
    validator = None

    try:

        # do lora conversion
        if getattr(args, "lora_weight_list", None) is not None:
            validator = get_validator(framework, args)
            set_optimization_args(args, framework)
            optimizer = IROptimizations(args)
            backend = DLCBackend(args)
            backend.initialize()
            # The count itself is not needed here; the call is kept because it
            # logs an error for inconsistent tensor-configuration counts.
            get_num_graph_configs(args)
            # Backend Awareness
            backend_info_obj = BackendInfo.get_instance(args.backend, args.soc_model)
            args.enable_framework_trace = True
            converter = get_frontend_converter(framework, args, validator)
            ir_graph = converter.convert()
            optimized_graph = optimizer.optimize(ir_graph, backend_info_obj)
            prepared_optimized_graph = backend.prepare_py_graph(optimized_graph)

            cpp_graph = backend.get_ir_graph(prepared_optimized_graph)
            prepared_cpp_graph = backend.prepare_cpp_graph(prepared_optimized_graph, cpp_graph)

            lora_tensor_names = get_tensor_names_from_file(args.lora_weight_list)
            if lora_tensor_names:
                log_debug("Marking LoRA adapter weights to updateable")
                make_tensors_updateable(prepared_cpp_graph, lora_tensor_names)
            else:
                log_info("Input Model is part of LoRA Use case but the Model has no LoRA Branches")

            backend.dlc_serializer.serialize(prepared_cpp_graph)
            backend.finish()
            log_info(code_to_message.get_progress_message("INFO_CONVERSION_SUCCESS"))

        # serialize one graph to dlc
        elif get_num_graph_configs(args) == 1:
            validator = get_validator(framework, args)
            set_optimization_args(args, framework)
            optimizer = IROptimizations(args)
            backend = DLCBackend(args)
            backend.initialize()
            # Backend Awareness
            backend_info_obj = BackendInfo.get_instance(args.backend, args.soc_model)
            converter = get_frontend_converter(framework, args, validator)
            ir_graph = converter.convert()
            optimized_graph = optimizer.optimize(ir_graph, backend_info_obj)
            backend.save(optimized_graph)

        # Network Specialization Case
        else:
            convert_with_multiple_shapes(args, framework)

        if framework in ('onnx', 'tensorflow') and args.validate_models:
            if validator is None:
                # Reached after the network specialization path, which does
                # not hand a validator back to this function.
                log_warning("Model validation is skipped because no validator is available "
                            "for this conversion path.")
            else:
                try:
                    results = validator.validate()
                    for result in results:
                        log_info(result)
                except Exception as e:
                    log_warning(
                        "Model conversion is completed but error "
                        "encountered during validation : {}".format(str(e))
                    )

    except Exception as e:
        # When using torchscript-to-onnx converter,
        # we create an onnx file in a temporary location.
        # If the conversion fails, we should delete this temporary file
        # to avoid polluting the temporary directory.
        model_name = pathlib.Path(args.input_network).stem
        # Check if the file still exists and if this is an onnx file exported from Torchscript
        if os.path.exists(args.input_network) and model_name.endswith("_exportedOnnx_"):
            os.remove(args.input_network)
        log_error("Encountered Error: {}".format(str(e)))
        traceback.print_exc()
        sys.exit(1)

    sys.exit(0)

# Run the converter only when this file is executed as a script; main()
# always ends the process through sys.exit.
if __name__ == '__main__':
    main()
