#!/usr/bin/env python3
# =============================================================================
#
#  Copyright (c) 2020-2024 Qualcomm Technologies, Inc.
#  All Rights Reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================

import sys
import traceback


from qti.aisw.converters.tensorflow.tf_to_ir import TFConverterFrontend
from qti.aisw.converters.tensorflow.util import ConverterError
from qti.aisw.converters.common.utils.converter_utils import log_error, log_warning, log_info
from qti.aisw.converters.common.utils.argparser_util import ArgParserWrapper, CustomHelpFormatter
from qti.aisw.converters.common.converter_ir.op_graph_optimizations import IROptimizations
from qti.aisw.converters.common.arch_linter.arch_linter import ArchLinter
from qti.aisw.converters.backend.ir_to_qnn import QnnConverterBackend
from qti.aisw.converters.backend.qnn_quantizer import QnnQuantizer
from qti.aisw.converters.qnn_backend.custom_ops.op_factory import QnnCustomOpFactory
from qti.aisw.converters.common.model_validator import Validator

from qti.aisw.converters.common.graph_optimizer import GraphOptimizer, OptimizationStage

class TFtoQNNArgParser(ArgParserWrapper):
    """Command-line argument parser for the TF-to-QNN conversion script.

    Combines the argument parsers exposed by the TensorFlow frontend, the IR
    optimizations, the QNN quantizer, the QNN backend, the architecture linter
    and the graph optimizer, and adds the script-specific
    ``--validate_models`` flag.
    """

    def __init__(self):
        super().__init__(formatter_class=CustomHelpFormatter,
                         conflict_handler='resolve',
                         parents=[TFConverterFrontend.ArgParser(),
                                  IROptimizations.ArgParser(),
                                  QnnQuantizer.ArgParser(),
                                  QnnConverterBackend.ArgParser(),
                                  ArchLinter.ArgParser(),
                                  GraphOptimizer.ArgParser()
                                  ])
        # argparse expands help text with the % operator, so a literal percent
        # sign has to be written as "%%"; a bare "%" is read as the start of a
        # conversion specifier and corrupts (or breaks) the --help output.
        # NOTE(review): assumes add_optional_argument forwards `help` to
        # argparse unchanged - confirm against ArgParserWrapper.
        self.add_optional_argument(
            "--validate_models", action="store_true",
            help="Validate the original TF model against optimized TF model.\n"
                 "Constant inputs with all value 1s will be generated and will be used \n"
                 "by both models and their outputs are checked against each other.\n"
                 "The %% average error and 90th percentile of output differences will be calculated for this.\n"
                 "Note: Usage of this flag will incur extra time due to inference of the models.")

        self.parser.description = 'Script to convert TF model into QNN'


def main():
    """Run the TF-to-QNN conversion pipeline from the command line.

    Parses the command-line arguments, converts the TensorFlow model to an IR
    graph, applies the IR optimizations with the QNN-specific flag overrides,
    and saves the optimized graph through the QNN backend. When
    ``--validate_models`` is active, the validator results are logged after
    the model has been saved; a failure during validation is only reported as
    a warning.

    Any other failure is logged and the process exits with status 1.
    """
    args = TFtoQNNArgParser().parse_args()

    try:
        # Validation is switched off when a converter op package library is
        # supplied; otherwise a Validator is created only if it was requested.
        if args.validate_models and args.converter_op_package_lib:
            log_warning("Model is having custom ops skipping validation.")
            args.validate_models = False
        validator = Validator() if args.validate_models else None

        frontend = TFConverterFrontend(args,
                                       custom_op_factory=QnnCustomOpFactory(),
                                       validator=validator)
        ir_graph = frontend.convert()

        # Override optimizer flags for QNN backend
        qnn_flag_overrides = (
            ("perform_axes_to_spatial_first_order", False),
            ("squash_box_decoder", True),
            ("match_caffe_ssd_to_tf", True),
            ("adjust_nms_features_dims", True),
            ("extract_color_transform", True),
            ("unroll_lstm_time_steps", True),
            ("inject_cast_for_gather", True),
            ("force_prune_cast_ops", False),
            ("align_matmul_ranks", True),
            ("handle_gather_negative_indices", True),
        )
        for flag_name, flag_value in qnn_flag_overrides:
            setattr(args, flag_name, flag_value)

        optimized_graph = IROptimizations(args).optimize(ir_graph)

        backend = QnnConverterBackend(args)
        backend.save(optimized_graph)

        # To be removed in future releases: the arch linter is no longer run
        # from here, only a deprecation message is shown for --arch_checker.
        # archLinter = ArchLinter(args)
        # archLinter.run_linter(optimized_graph, backend)
        if args.arch_checker:
            # The continuation lines below are part of the string literal, so
            # their leading whitespace is emitted as-is in the warning text.
            log_warning("WARNING: The usage of --arch_checker from the conversion tool will be deprecated. \n\
                Use the following command to run the architecture checker.\n\
                Command: qnn-architecture checker -i <path to json> -b <path to bin> -o <output location>\n\
                -i <required> <path>/model.json\n\
                -b <optional> <path>/model.bin\n\
                -o <optional> <output_path>\n\
                Please refer to the documentation for further details.\n")

        if args.validate_models:
            # The model is already saved at this point, so a validation
            # failure is reported as a warning instead of aborting.
            try:
                for validation_result in validator.validate():
                    log_info(validation_result)
            except Exception as validation_err:
                log_warning(
                    "Model conversion is completed but error "
                    "encountered during validation : {}".format(str(validation_err))
                )

    except ConverterError as conv_err:
        log_error("Conversion failed: {}".format(str(conv_err)))
        sys.exit(1)
    except Exception as err:
        # Unexpected failure: log it and print the traceback before exiting.
        log_error("Encountered Error: {}".format(str(err)))
        traceback.print_exc()
        sys.exit(1)


# Run the converter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
