#!/usr/bin/env python3
# =============================================================================
#
#  Copyright (c) 2021-2024 Qualcomm Technologies, Inc.
#  All Rights Reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================

import sys
import traceback

import qti.tvm

from qti.aisw.converters.tflite.tflite_to_ir import TFLiteConverterFrontend
from qti.aisw.converters.common.utils.converter_utils import log_error, log_warning
from qti.aisw.converters.common.utils.argparser_util import ArgParserWrapper, CustomHelpFormatter
from qti.aisw.converters.common.converter_ir.op_graph_optimizations import IROptimizations
from qti.aisw.converters.common.arch_linter.arch_linter import ArchLinter
from qti.aisw.converters.backend.ir_to_qnn import QnnConverterBackend
from qti.aisw.converters.backend.qnn_quantizer import QnnQuantizer
from qti.aisw.converters.qnn_backend.custom_ops.op_factory import QnnCustomOpFactory

from qti.aisw.converters.common.graph_optimizer import GraphOptimizer, OptimizationStage

class TFLitetoQNNArgParser(ArgParserWrapper):
    """Command-line argument parser for the TFLite-to-QNN converter.

    Combines the argument parsers exposed by the TFLite frontend, the IR
    optimizations, the QNN quantizer, the QNN backend, the architecture
    linter and the graph optimizer into a single parser.
    """

    def __init__(self):
        # Parent parsers are instantiated in this fixed order. Conflicts are
        # handled with 'resolve'; presumably the wrapper forwards these
        # keyword arguments to argparse -- confirm against ArgParserWrapper.
        parent_parsers = [
            TFLiteConverterFrontend.ArgParser(),
            IROptimizations.ArgParser(),
            QnnQuantizer.ArgParser(),
            QnnConverterBackend.ArgParser(),
            ArchLinter.ArgParser(),
            GraphOptimizer.ArgParser(),
        ]
        super().__init__(
            formatter_class=CustomHelpFormatter,
            conflict_handler='resolve',
            parents=parent_parsers,
        )
        self.parser.description = 'Script to convert TFLite model into QNN'


def main():
    """Run the TFLite-to-QNN conversion from the command line.

    Parses the command-line arguments, converts the model with the TFLite
    frontend, runs the IR optimizations on the resulting graph and hands the
    optimized graph to the QNN backend to be saved.

    The process always terminates through ``sys.exit``: status 0 when every
    step completes, status 1 (after logging the error and printing the
    traceback) when any step raises an ``Exception``.
    """
    try:
        cli_args = TFLitetoQNNArgParser().parse_args()

        frontend = TFLiteConverterFrontend(cli_args, custom_op_factory=QnnCustomOpFactory())
        ir_graph = frontend.convert()

        # This flag is forced off before the optimizer is constructed.
        cli_args.force_prune_cast_ops = False
        final_graph = IROptimizations(cli_args).optimize(ir_graph)

        qnn_backend = QnnConverterBackend(cli_args)
        qnn_backend.save(final_graph)

        # To be removed in future releases. The in-converter ArchLinter run is
        # disabled; requesting --arch_checker only shows the deprecation notice.
        if cli_args.arch_checker:
            log_warning("WARNING: The usage of --arch_checker from the conversion tool will be deprecated. \n\
                Use the following command to run the architecture checker.\n\
                Command: qnn-architecture checker -i <path to json> -b <path to bin> -o <output location>\n\
                -i <required> <path>/model.json\n\
                -b <optional> <path>/model.bin\n\
                -o <optional> <output_path>\n\
                Please refer to the documentation for further details.\n")

    except Exception as err:
        # Top-level boundary: report the failure and exit with a non-zero status.
        log_error("Encountered Error: {}", str(err))
        traceback.print_exc()
        sys.exit(1)
    else:
        sys.exit(0)


# Run the converter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
