#!/usr/bin/env python3
# -*- mode: python -*-
# ==============================================================================
#
#  Copyright (c) 2021-2023 Qualcomm Technologies, Inc.
#  All Rights Reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
# ==============================================================================

import sys
import traceback

try:
    import qti.aisw
except ImportError as ie1:
    print("Failed to find necessary python package")
    print(str(ie1))
    print("Please ensure that $SNPE_ROOT/lib/python is in your PYTHONPATH")
    sys.exit(1)

# Do Not Remove. Need to import this first so that the tvm override is initialized
import qti.tvm

# Common Imports
from qti.aisw.converters.qnn_backend.custom_ops.op_factory import QnnCustomOpFactory
from qti.aisw.converters.common.utils.converter_utils import log_error
from qti.aisw.converters.common.converter_ir.op_graph_optimizations import IROptimizations
from qti.aisw.converters.common.utils.argparser_util import ArgParserWrapper, CustomHelpFormatter

# TFLite Converter
from qti.aisw.converters.tflite.tflite_to_ir import TFLiteConverterFrontend

# Backend Imports
from qti.aisw.converters.qnn_backend.ir_to_dlc import DLCBackend as NativeBackend

from qti.aisw.converters.common.graph_optimizer import GraphOptimizer

class TFLiteToDLCArgParser(ArgParserWrapper):
    """Command-line argument parser for the TFLite-to-DLC conversion script.

    Combines, as parent parsers, the argument parsers exposed by the TFLite
    frontend, the IR optimizations, the DLC backend and the graph optimizer.
    """

    def __init__(self):
        # Parent parsers, in the order they are handed to the wrapper.
        parent_parsers = [
            TFLiteConverterFrontend.ArgParser(),
            IROptimizations.ArgParser(),
            NativeBackend.ArgParser(),
            GraphOptimizer.ArgParser(),
        ]
        # NOTE(review): 'resolve' is presumably forwarded to argparse so that
        # options declared by more than one parent do not raise a conflict
        # error -- confirm against ArgParserWrapper.
        super().__init__(
            formatter_class=CustomHelpFormatter,
            conflict_handler='resolve',
            parents=parent_parsers,
        )
        self.parser.description = 'Script to convert TFLite model into DLC'


def main():
    """Run the TFLite-to-DLC conversion driven by command-line arguments.

    Steps, in order:
      1. Parse the command line with ``TFLiteToDLCArgParser``.
      2. Convert the model to an IR graph with ``TFLiteConverterFrontend``.
      3. Optimize the graph with ``IROptimizations``.
      4. Save the optimized graph through ``NativeBackend``.

    The process always terminates through ``sys.exit``: status 0 on success,
    status 1 when any ``Exception`` is raised (the error is logged and the
    traceback printed first).
    """
    try:
        arg_parser = TFLiteToDLCArgParser()
        cli_args = arg_parser.parse_args()

        # Frontend: translate the TFLite model into the converter IR.
        frontend = TFLiteConverterFrontend(cli_args)
        ir_graph = frontend.convert()

        # Both options are forced off before the optimizer sees the arguments.
        # NOTE(review): presumably they are not meaningful for the TFLite
        # flow -- confirm with the IROptimizations owners.
        cli_args.prepare_inputs_as_params = False
        cli_args.force_prune_cast_ops = False

        ir_optimizer = IROptimizations(cli_args)
        final_graph = ir_optimizer.optimize(ir_graph)

        # Backend: write the optimized graph out as a native model.
        dlc_backend = NativeBackend(cli_args)
        dlc_backend.save(final_graph)
    except Exception as err:
        log_error("Encountered Error: {}", str(err))
        traceback.print_exc()
        sys.exit(1)
    else:
        # Reached only when the whole pipeline completed without raising.
        sys.exit(0)

# Script entry point: run the conversion only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
