#!/usr/bin/env python3
# =============================================================================
#
#  Copyright (c) 2015-2016,2018-2020, 2022-2023 Qualcomm Technologies, Inc.
#  All Rights Reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================


import traceback

import tensorflow as tf
import sys

try:
    import qti.aisw
except ImportError as ie1:
    print("Failed to find necessary python package")
    print(str(ie1))
    print("Please ensure that $SNPE_ROOT/lib/python is in your PYTHONPATH")
    sys.exit(1)

from qti.aisw.converters.qnn_backend.custom_ops.op_factory import QnnCustomOpFactory
from qti.aisw.converters.tensorflow.tf_to_ir import TFConverterFrontend
from qti.aisw.converters.tensorflow.util import ConverterError
from qti.aisw.converters.qnn_backend.ir_to_dlc import DLCBackend as NativeBackend
from qti.aisw.converters.common.converter_ir.op_graph_optimizations import IROptimizations
from qti.aisw.converters.common.utils.argparser_util import ArgParserWrapper, CustomHelpFormatter
from qti.aisw.converters.common.utils.converter_utils import *

from qti.aisw.converters.common.graph_optimizer import GraphOptimizer


class TFtoDLCArgParser(ArgParserWrapper):
    """Command-line argument parser for the TF-to-DLC conversion script.

    Builds a single parser out of the argument parsers provided by the
    TensorFlow frontend, the IR optimizations, the DLC backend and the graph
    optimizer, and sets the script description on it.
    """

    def __init__(self):
        # Parent parsers, kept in the order they are handed to the wrapper.
        parent_parsers = [
            TFConverterFrontend.ArgParser(),
            IROptimizations.ArgParser(),
            NativeBackend.ArgParser(),
            GraphOptimizer.ArgParser(),
        ]
        super().__init__(
            formatter_class=CustomHelpFormatter,
            conflict_handler='resolve',
            parents=parent_parsers,
        )
        self.parser.description = 'Script to convert TF model into DLC'


def main():
    """Convert a TensorFlow model to DLC, driven by command-line arguments.

    Parses the command line, runs the TensorFlow frontend to obtain an IR
    graph, applies the IR optimizations with a fixed set of option values,
    and saves the optimized graph through the DLC backend.

    Raises:
        SystemExit: with status 1 after a ``ConverterError`` has been logged
            and its traceback printed.
        Exception: any other failure is logged and then re-raised unchanged.
    """
    parser = TFtoDLCArgParser()
    args = parser.parse_args()

    try:
        converter = TFConverterFrontend(args, custom_op_factory=QnnCustomOpFactory())
        ir_graph = converter.convert()

        # Option values fixed by this script before the optimizer is built;
        # each assignment overwrites any attribute of the same name that is
        # already present on ``args``.
        args.prepare_inputs_as_params = False
        args.perform_axes_to_spatial_first_order = False
        args.match_caffe_ssd_to_tf = True
        args.adjust_nms_features_dims = True
        args.extract_color_transform = True
        args.unroll_lstm_time_steps = True
        args.inject_cast_for_gather = True
        args.force_prune_cast_ops = False
        args.align_matmul_ranks = True
        args.handle_gather_negative_indices = True
        args.squash_box_decoder = True

        optimizer = IROptimizations(args)
        optimized_graph = optimizer.optimize(ir_graph)

        # save native model
        backend = NativeBackend(args)
        backend.save(optimized_graph)
    except ConverterError as e:
        # Known converter failure: report it, show the traceback, and exit
        # with a non-zero status instead of propagating the exception.
        log_error("Conversion FAILED: {}".format(str(e)))
        traceback.print_exc()
        sys.exit(1)
    except Exception:
        log_error("Conversion FAILED!")
        # Bare ``raise`` re-raises the active exception with its original
        # traceback, without the extra entry that ``raise e`` would add.
        raise


# Run the conversion only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()
