#!/usr/bin/env python3
# -*- mode: python -*-
# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
import sys
import traceback
import os

# import numpy before qti.aisw.converters.xxxx modules
import numpy
import pathlib
import shutil
import json

# modeltools is present in common as well as dlc_utils
# try importing from common first (currently used by QNN) and if not found import from dlc_utils (used by SNPE)
# TODO: remove modeltools from dlc_utils and update all SNPE tools to use modeltools from common
try:
    from qti.aisw.converters.common import modeltools
except ImportError as ie1:
    from qti.aisw.dlc_utils import modeltools

# Common Imports
from qti.aisw.converters.common import ir_graph
from qti.aisw.converters.common.converter_ir.op_graph import QuantUpdatableMode
from qti.aisw.converters.common.utils import code_to_message
from qti.aisw.converters.common.utils.converter_utils import *
from qti.aisw.converters.common.utils.lora_utils import make_tensors_updateable
from qti.aisw.converters.common.utils.argparser_util import ArgParserWrapper, CustomHelpFormatter
from qti.aisw.converters.common.converter_ir.op_graph_optimizations import IROptimizations
from qti.aisw.converters.common.converter_ir import op_adapter
from qti.aisw.converters.qnn_backend.ir_to_dlc import DLCBackend
from qti.aisw.converters.qnn_backend.custom_ops.op_factory import QnnCustomOpFactory
from qti.aisw.converters.common.qairt_converter_arguments import (
    convert_args_v2_to_v1,
    ExperimentalFeature,
    QairtConverterFrontendArgParser
)
from qti.aisw.converters.common.utils.multi_graph import IrStaticTensorSet
from qti.aisw.converters.common.graph_property_setter import GraphPropertySetter
from qti.aisw.converters.common.model_validator import Validator
from qti.aisw.converters.common.backend_awareness import BackendInfo
from qti.aisw.converters.common.input_shape import InputShapeArgParser, InputShapeInfo
from qti.aisw.converters.common.utils.validation_utils import validate_tensor_names_in_graph

from qti.aisw.converters.common.graph_optimizer import GraphOptimizer, OptimizationStage

class FrameworktoQNNArgParser(ArgParserWrapper):
    """Aggregated CLI parser for the framework-to-DLC conversion flow.

    Combines the QAIRT frontend, IR-optimization, DLC-backend,
    backend-awareness and graph-optimizer argument groups into one parser.
    """

    def __init__(self):
        super(FrameworktoQNNArgParser, self).__init__(formatter_class=CustomHelpFormatter,
                                                      conflict_handler='resolve',
                                                      parents=[QairtConverterFrontendArgParser(),
                                                               IROptimizations.ArgParserv2(),
                                                               DLCBackend.ArgParserv2(),
                                                               BackendInfo.ArgParser(),
                                                               GraphOptimizer.ArgParser()
                                                               ])

    def validate_args(self, args):
        """Validate cross-argument constraints for the LoRA-related options.

        Raises:
            Exception: if --lora_weight_list is combined with dynamic input shapes.
            ValueError: if --quant_updatable_mode is given without --lora_weight_list.
        """
        # Use `is not None` identity checks (PEP 8) instead of `!= None`;
        # getattr with a default replaces the hasattr/compare pair.
        has_quant_updatable_mode = getattr(args, "quant_updatable_mode", None) is not None
        has_lora_weight_list = getattr(args, "lora_weight_list", None) is not None

        if has_lora_weight_list:
            input_shape_info = InputShapeArgParser(args.input_dim).input_shape_info
            if input_shape_info.has_dynamic_shapes:
                raise Exception("Dynamic tensors are not supported with LoRA with this tool.")

        if has_quant_updatable_mode and not has_lora_weight_list:
            raise ValueError("--quant_updatable_mode requires --lora_weight_list value.")

def set_optimization_args(args, framework):
    """Populate framework-specific IR optimization flags on ``args``.

    Only the attributes relevant to ``framework`` are assigned; callers rely
    on untouched attributes remaining absent (e.g. for 'tflite').
    """
    # TODO: Align optimizations for all frameworks
    is_onnx = framework == 'onnx'
    is_tensorflow = framework == 'tensorflow'
    is_pytorch = framework == 'pytorch'

    if is_onnx:
        args.expand_gru_op_structure = True
        args.unroll_gru_time_steps = True
        args.expand_sparse_op_structure = False

    if (is_onnx and not args.use_onnx_relay) or is_pytorch:
        # The v1 layout transform and the newer layout transformation are
        # mutually exclusive; exactly one of the two flags is enabled.
        use_v1 = bool(args.enable_Layout_Transform_v1)
        args.perform_layout_transformation = not use_v1
        args.perform_axes_to_spatial_first_order = use_v1
        args.preprocess_roi_pool_inputs = True

    if is_onnx or is_tensorflow:
        args.unroll_lstm_time_steps = True
        args.align_matmul_ranks = True
        args.handle_gather_negative_indices = True

    if is_tensorflow or is_pytorch:
        args.match_caffe_ssd_to_tf = True

    # Enable/Disable the following optimizations for onnx, tf and pytorch.
    if framework != 'tflite':
        args.squash_box_decoder = True
        args.adjust_nms_features_dims = True
        args.extract_color_transform = True
        args.inject_cast_for_gather = True
        args.force_prune_cast_ops = False

def get_frontend_converter(framework, args, validator, backend_info_obj=None):
    """Return the frontend converter instance for ``framework``.

    Frontend modules are imported lazily so only the dependencies of the
    selected framework are required at runtime.

    Args:
        framework: one of 'onnx', 'tensorflow', 'tflite', 'pytorch'.
        args: parsed converter arguments (v1 namespace).
        validator: optional model Validator (used by onnx/tensorflow only).
        backend_info_obj: optional backend-awareness info (onnx only).

    Raises:
        Exception: for missing required arguments, an unavailable relay flow,
            or an unrecognized framework.
    """
    if framework == 'onnx':
        if not args.use_onnx_relay:
            from qti.aisw.converters.onnx.onnx_to_ir import OnnxConverterFrontend
            return OnnxConverterFrontend(args, custom_op_factory=QnnCustomOpFactory(), validator=validator, backend_info_obj=backend_info_obj)
        else:
            try:
                # use onnx-relay-converter flow
                from qti.aisw.converters.onnx.onnx_to_ir_relay import OnnxRelayConverterFrontend
                return OnnxRelayConverterFrontend(args, custom_op_factory=QnnCustomOpFactory())
            except Exception as e:
                # Chain the original failure so the root cause is not lost.
                raise Exception("--use_onnx_relay is not available. Please remove --use_onnx_relay in converter command.") from e
    elif framework == "tensorflow":
        from qti.aisw.converters.tensorflow.tf_to_ir import TFConverterFrontend
        from qti.aisw.converters.tensorflow.util import ConverterError
        if not args.input_dim or not args.out_names:
            raise Exception("--source_model_input_shape and --out_tensor_node are required for TensorFlow conversion")
        return TFConverterFrontend(args, custom_op_factory=QnnCustomOpFactory(), validator=validator)
    elif framework == "tflite":
        from qti.aisw.converters.tflite.tflite_to_ir import TFLiteConverterFrontend
        return TFLiteConverterFrontend(args, custom_op_factory=QnnCustomOpFactory())
    elif framework == "pytorch":
        from qti.aisw.converters.pytorch.pytorch_to_ir import PyTorchConverterFrontend
        from qti.aisw.converters.relay.custom_ops.utils.pytorch_helpers import PytorchCustomOpFactory
        if not args.input_dim:
            raise Exception("--source_model_input_shape is required for PyTorch conversion")
        return PyTorchConverterFrontend(args, custom_op_factory=PytorchCustomOpFactory())
    else:
        raise Exception(f"unrecognized framework {framework}")

def get_num_tensor_configs(tensor_configs):
    """Return the number of shape configurations encoded in ``tensor_configs``.

    Accepts a raw CLI string, a single flat dims tuple/list, or a sequence of
    dims tuples; an empty sequence yields 0 configurations.
    """
    # A raw string (input_dims passed individually via CLI) is one config.
    # Check this first so we never index into a string below — the original
    # ordering indexed tensor_configs[0] before the str test.
    if isinstance(tensor_configs, str):
        return 1

    # A single flat config, e.g. tensor_configs == (1, 2, 3).
    if tensor_configs and isinstance(tensor_configs[0], int):
        return 1

    # Multiple configs, e.g. tensor_configs == ((1, 2, 3), (4, 5, 6)).
    return len(tensor_configs)

def get_num_graph_configs(args):
    """Return how many graph (shape) configurations --input_dim encodes.

    Every tensor must carry either one shape or N shapes; any other mix of
    counts is reported via log_error. Returns 1 when --input_dim is absent.
    """
    def validate_num_configs_is_1_or_n(num_tensor_configs_seen):
        error_message = "Error: Number of tensor configurations can either be 1 or N. \
                       You specified the following number of tensor configurations: {}" \
            .format(num_tensor_configs_seen)
        if len(num_tensor_configs_seen) > 2:
            log_error(error_message)
        elif len(num_tensor_configs_seen) == 2:
            if 1 not in num_tensor_configs_seen:
                log_error(error_message)

    if args.input_dim is None:
        return 1

    # Collect the distinct per-tensor configuration counts in one pass.
    num_tensor_configs_seen = {get_num_tensor_configs(tensor_configs)
                               for _tensor_name, tensor_configs in args.input_dim}
    validate_num_configs_is_1_or_n(num_tensor_configs_seen)
    return max(num_tensor_configs_seen)

def get_graph_configs(args):
    """Expand --input_dim into per-graph configurations.

    Returns a list with one entry per graph configuration; each entry is a
    list of ``[tensor_name, "d0,d1,..."]`` pairs.
    """
    def dims_as_string(dims):
        return ",".join(str(dim) for dim in dims)

    total = get_num_graph_configs(args)
    configurations = []

    for idx in range(total):
        configuration = []
        for tensor_name, tensor_configs in args.input_dim:
            # Tensors with a single config reuse it in every graph
            # configuration; multi-config tensors pick the idx-th shape.
            if get_num_tensor_configs(tensor_configs) > 1:
                dims = dims_as_string(tensor_configs[idx])
            else:
                dims = dims_as_string(tensor_configs)
            configuration.append([tensor_name, dims])
        configurations.append(configuration)
    return configurations

def set_graph_configs(args, config):
    """Install ``config`` as the active --input_dim graph configuration."""
    setattr(args, "input_dim", config)

def infer_framework(args):
    """Infer the source framework from the input model path.

    Recognizes the model by file extension; a directory containing a ``.pb``
    file is treated as a TensorFlow 2 SavedModel.

    Returns:
        str: one of 'onnx', 'tensorflow', 'pytorch', 'tflite', 'gguf'.

    Raises:
        Exception: when the model format cannot be determined.
    """
    input_model_to_framework = {'.onnx': 'onnx', '.pb': 'tensorflow', '.pt': 'pytorch', '.tflite': 'tflite', '.gguf': 'gguf'}
    model_path, model_ext = os.path.splitext(args.input_network)

    # tensorflow2 takes as input a folder which would have the ".pb" file.
    # Guard with isdir so a bad path raises the friendly Exception below
    # instead of a raw FileNotFoundError from os.listdir.
    if model_ext not in input_model_to_framework and os.path.isdir(model_path):
        for file_name in os.listdir(model_path):
            if os.path.splitext(file_name)[1] == '.pb':
                model_ext = '.pb'
                break  # one .pb file is enough to identify TF2

    if model_ext not in input_model_to_framework:
        # Keep this list in sync with input_model_to_framework (includes .gguf).
        raise Exception("Invalid model format specified. Supported types are .onnx/.pb/.tflite/.pt/.gguf")
    framework = input_model_to_framework[model_ext]
    return framework

def get_validator(framework, args):
    """Create a model Validator when validation is requested and supported.

    Validation only applies to onnx/tensorflow models. When a converter op
    package is supplied, validation is skipped with a warning and
    ``args.validate_models`` is cleared so later stages also skip it.
    """
    if framework not in ('onnx', 'tensorflow') or not args.validate_models:
        return None

    if args.converter_op_package_lib:
        # Custom ops cannot be validated; disable validation for later stages.
        log_warning("Model is having custom ops skipping validation.")
        args.validate_models = False
        return None

    return Validator()

def ts_to_onnx(args):
    """Export a TorchScript model to ONNX and repoint ``args`` at the result.

    Side effects: rewrites ``args.input_network`` to the exported ONNX file,
    derives ``args.output_path`` when unset, and enables the
    sequence-construct optimizer.

    Raises:
        RuntimeError: when the TorchScript-to-ONNX export fails (original
            error attached as the cause).
    """
    from qti.aisw.converters.pytorch.torchscript_to_onnx import to_onnx
    try:
        onnx_model_name = to_onnx(args)
    except Exception as e:
        raise RuntimeError("Converter would convert Torchscript into Onnx, but it failed to convert to onnx!") from e

    if args.output_path is None:
        # os.path.splitext is robust for paths without an extension, where
        # rpartition('.')[0] would produce an empty string (-> ".dlc").
        args.output_path = os.path.splitext(args.input_network)[0] + '.dlc'

    args.input_network = onnx_model_name
    args.perform_sequence_construct_optimizer = True

def convert_with_multiple_shapes(args, framework: str):
    """Network-specialization flow: build one DLC containing a specialized
    graph per input-shape configuration derived from --input_dim.

    A single DLCBackend accumulates every optimized graph through repeated
    ``serialize`` calls and is finalized once via ``finish``.
    """
    validator = get_validator(framework, args)
    set_optimization_args(args, framework)
    optimizer = IROptimizations(args)
    backend = DLCBackend(args)
    backend.initialize()
    # Backend Awareness
    backend_info_obj = BackendInfo.get_instance(args.backend, args.soc_model)
    graph_configs = get_graph_configs(args)

    # In network specialization flow, we will avoid check if the shared context
    # static tensors is already present in DLC
    enable_tensor_deduplication = False

    # The enable_tensor_deduplication flag will enable serializer to look for shared context
    # static tensors data in DLC
    if hasattr(args, 'enable_tensor_deduplication') and args.enable_tensor_deduplication:
        enable_tensor_deduplication = True

    # Convert and serialize once per shape configuration; set_graph_configs
    # rewrites args.input_dim to the active configuration for each pass.
    for config in graph_configs:
        set_graph_configs(args, config)
        converter = get_frontend_converter(framework, args, validator, backend_info_obj)
        python_ir_graph = converter.convert()
        optimized_graph = optimizer.optimize(python_ir_graph, backend_info_obj)
        backend.serialize(optimized_graph, network_specialization = True, enable_tensor_deduplication = enable_tensor_deduplication,
                          is_qairt = True, backend_info_obj=backend_info_obj)
        # Drop per-configuration graphs before the next pass to limit peak memory.
        del optimized_graph
        del python_ir_graph
        del converter

    backend.finish()

def build_onnx_graph_from_gguf(args):
    """Build an ONNX model (plus quantization overrides) from a GGUF file.

    Side effects: rewrites ``args.input_network`` to the generated ONNX model
    and ``args.quantization_overrides`` to the encodings returned by the
    builder. Exits the process with status 1 if GGUFBuilder cannot be imported.
    """
    try:
        from qti.aisw.converters.gguf_builder import GGUFBuilder
    except Exception as e:
        # NOTE(review): broad `except Exception` also swallows non-ImportError
        # failures raised during module initialization — presumably intentional.
        log_error("Encountered Error: {}".format(str(e)))
        traceback.print_exc()
        sys.exit(1)

    builder = GGUFBuilder(args.input_network, args.gguf_config,
                          args.output_path, args.batch)
    args.input_network, args.quantization_overrides = builder.build_from_gguf()


def main():
    """CLI entry point: convert a source framework model into a DLC.

    Flow: parse args -> infer framework -> (optionally export TorchScript or
    GGUF models to ONNX) -> run one of three conversion paths (LoRA,
    single graph, network specialization) -> optional model validation.

    Exits with status 0 on success, 1 on failure.
    """
    parser = FrameworktoQNNArgParser()

    # Show help text incase no arguments are provided
    if(len(sys.argv) == 1):
        parser.parser.print_help()
        sys.exit(1)

    argsv2 = parser.parse_args()
    args = convert_args_v2_to_v1(argsv2)
    framework = infer_framework(args)

    if args.dump_exported_onnx and framework != 'pytorch':
        raise NotImplementedError(f"--dump_exported_onnx only can be used in Torchscript models, but got {framework.capitalize()} model!")

    # TorchScript and GGUF models are first exported to ONNX and then go
    # through the ONNX conversion path.
    if framework == 'pytorch':
        ts_to_onnx(args)
        framework = 'onnx'

    if framework == 'gguf':
        build_onnx_graph_from_gguf(args)
        framework = 'onnx'

    # Validator used by the LoRA and single-graph paths. The network
    # specialization path owns its validator internally, so this stays None
    # there and the validation step below is skipped (previously this was an
    # unbound name in that path, raising NameError).
    validator = None

    try:

        # do lora conversion
        if hasattr(args, "lora_weight_list") and args.lora_weight_list is not None:
            validator = get_validator(framework, args)
            set_optimization_args(args, framework)
            optimizer = IROptimizations(args)
            backend = DLCBackend(args)
            backend.initialize()
            # Return value unused, but the call validates the --input_dim
            # configuration counts (logs an error for an invalid mix).
            num_graph_configs = get_num_graph_configs(args)
            # Backend Awareness
            backend_info_obj = BackendInfo.get_instance(args.backend, args.soc_model)
            converter = get_frontend_converter(framework, args, validator, backend_info_obj)
            py_ir_graph = converter.convert()
            optimized_graph = optimizer.optimize(py_ir_graph, backend_info_obj)
            prepared_optimized_graph = backend.prepare_py_graph(optimized_graph)

            # Collect the updatable (LoRA) tensor names and cross-check them
            # against the user-provided --lora_weight_list.
            lora_tensor_names = prepared_optimized_graph.get_all_updatable_tensors(report_error=not args.skip_validation)
            validate_tensor_names_in_graph(lora_tensor_names, prepared_optimized_graph, args.lora_weight_list, args.skip_validation)

            if hasattr(args, "quant_updatable_mode") and args.quant_updatable_mode is not None:
                quant_updatable_mode = QuantUpdatableMode(args.quant_updatable_mode)
            else:
                quant_updatable_mode = None
            graph_property_setter = GraphPropertySetter()
            graph_property_setter.set_graph_properties(prepared_optimized_graph, quant_updatable_mode, lora_tensor_names)

            # Lower to the C++ IR graph, run post-layout optimizations and quantize.
            cpp_graph = backend.get_ir_graph(prepared_optimized_graph)
            prepared_cpp_graph = backend.prepare_cpp_graph(prepared_optimized_graph, cpp_graph)
            args_dict = GraphOptimizer.ArgParser.convert_args(args)
            graph_optimizer = GraphOptimizer(args_dict)
            graph_optimizer.optimize(prepared_cpp_graph, [OptimizationStage.PostLayout])
            prepared_cpp_graph = backend.quantize_cpp_graph(prepared_optimized_graph, prepared_cpp_graph, args, backend_info_obj)

            if lora_tensor_names:
                log_debug("Marking LoRA adapter weights to updatable")
                prepared_cpp_graph = make_tensors_updateable(prepared_cpp_graph, backend.applied_float_fallback, lora_tensor_names)
            else:
                log_info("Input Model is part of LoRA Use case but the Model but has no LoRA Branches")

            backend.dlc_serializer.serialize(prepared_cpp_graph)

            # Embed LoRA converter metadata unless transform tracking is disabled.
            if hasattr(args, "quant_updatable_mode") and args.quant_updatable_mode == "none" and \
                not (hasattr(args, "disable_transform_tracking") and args.disable_transform_tracking):
                lora_metadata_dict = populate_lora_metadata_json_schema(prepared_cpp_graph, prepared_optimized_graph.lora_tensor_names)
                lora_metadata_binary_obj = json.dumps(lora_metadata_dict, indent=2).encode('utf-8')
                backend.dlc_serializer.add_record_from_buffer(lora_metadata_binary_obj, modeltools.DlcRecordType.LORA_CONVERTER_METADATA)

            backend.finish()
            log_info(code_to_message.get_progress_message("INFO_CONVERSION_SUCCESS"))


        # serialize one graph to dlc
        elif get_num_graph_configs(args) == 1:
            validator = get_validator(framework, args)
            set_optimization_args(args, framework)
            optimizer = IROptimizations(args)
            backend = DLCBackend(args)
            # Initialize exactly once; the original called backend.initialize()
            # a second time after optimization, which was redundant.
            backend.initialize()
            # Backend Awareness
            backend_info_obj = BackendInfo.get_instance(args.backend, args.soc_model)
            converter = get_frontend_converter(framework, args, validator, backend_info_obj)
            py_ir_graph = converter.convert()
            optimized_graph = optimizer.optimize(py_ir_graph, backend_info_obj)
            backend.serialize(optimized_graph, is_qairt=True, backend_info_obj=backend_info_obj)
            backend.finish()

        # Network Specialization Case
        else:
            convert_with_multiple_shapes(args, framework)

        model_name = pathlib.Path(args.input_network).stem
        # Check if this is an onnx file exported from Torchscript
        if model_name.endswith("_exportedOnnx_"):
            # If the dump_exported_onnx is on, then move the exported onnx model
            # from the temp folder to the output folder and rename onnx model's name.
            # Else, delete the temporary storage location of the exported onnx model.
            if args.dump_exported_onnx:
                output_folder = os.path.dirname(args.output_path)
                os.makedirs(os.path.join(output_folder, "exportedOnnx"), exist_ok=True)
                shutil.move(
                    args.input_network,
                    os.path.join(
                        output_folder,
                        "exportedOnnx",
                        f"{model_name.split('_')[0]}.onnx",
                    ),
                )
            else:
                # Remove the temp file
                os.remove(args.input_network)

        # Post-conversion validation; `validator is not None` guards the
        # network-specialization path where no validator exists in this scope.
        if (framework == 'onnx' or framework == 'tensorflow') and args.validate_models and validator is not None:
            try:
                results = validator.validate()
                for result in results:
                    log_info(result)
            except Exception as e:
                log_warning(
                    "Model conversion is completed but error "
                    "encountered during validation : {}".format(str(e))
                )

    except Exception as e:
        # When using torchscript-to-onnx converter,
        # we create an onnx file in a temporary location.
        # If the conversion fails, we should delete this temporary file
        # to avoid polluting the temporary directory.
        model_name = pathlib.Path(args.input_network).stem
        # Check if the file still exists and if this is an onnx file exported from Torchscript
        if os.path.exists(args.input_network) and model_name.endswith("_exportedOnnx_"):
            os.remove(args.input_network)
        log_error("Encountered Error: {}".format(str(e)))
        traceback.print_exc()
        sys.exit(1)

    sys.exit(0)

# Standard script entry guard: run the converter only when executed directly.
if __name__ == '__main__':
    main()
