#!/usr/bin/env python3
# -*- mode: python -*-
#==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
#==============================================================================

import sys
import traceback
import argparse

# Common Imports
from qti.aisw.lora.lora_importer_app import apply_lora_updates
from qti.aisw.converters.common.utils import validation_utils
from qti.aisw.converters.common.utils.converter_utils import log_error, log_warning
from qti.aisw.converters.common.utils.argparser_util import ArgParserWrapper, CustomHelpFormatter

# @if AISW_ENABLE_IRGRAPH_TRANSFORMS
from qti.aisw.converters.common.graph_optimizer import GraphOptimizer
# @fi AISW_ENABLE_IRGRAPH_TRANSFORMS

class LoraArgParser(ArgParserWrapper):
    """Command-line argument parser for the qairt-lora-importer tool.

    Registers the tool's required, optional, mutually exclusive and hidden
    (``argparse.SUPPRESS``) arguments on top of the ``ArgParserWrapper`` base.
    """

    def __init__(self):
        # The build-time conditional markers around GraphOptimizer.ArgParser() are
        # kept verbatim; presumably a build step drops that parent parser when the
        # IR graph transforms feature is disabled -- TODO confirm with build tooling.
        super(LoraArgParser, self).__init__(formatter_class=CustomHelpFormatter,
                                                      conflict_handler='resolve',
                                                      parents=[
                                                               # @if AISW_ENABLE_IRGRAPH_TRANSFORMS
                                                               GraphOptimizer.ArgParser()
                                                               # @fi AISW_ENABLE_IRGRAPH_TRANSFORMS
                                                               ])

        # Required inputs.
        self.add_required_argument('--lora_config',
                                   metavar="LORA_CONFIG_YAML",
                                   type=str,
                                   action=validation_utils.validate_filename_arg(must_exist=True),
                                   help='Path to the YAML config file for LoRA.')

        self.add_required_argument('--input_dlc', type=str,
                                   action=validation_utils.validate_filename_arg(must_exist=True),
                                   help='Path to the Float or Quantized DLC.')

        self.add_optional_argument("--input_network", "-i", type=str,
                                   action=validation_utils.validate_pathname_arg(must_exist=True),
                                   help="Path to the source ONNX model.")

        # Calibration data and float fallback are mutually exclusive: argparse
        # rejects a command line that supplies more than one option of this group.
        group = self.parser.add_mutually_exclusive_group()
        group.add_argument('--input_list', type=str,
                           action=validation_utils.validate_filename_arg(must_exist=True),
                           help='Path to a file specifying the input data. This file should be a plain text '
                                'file, containing one or more absolute file paths per line. Each path is '
                                'expected to point to a binary file containing one input in the "raw" format, '
                                'ready to be consumed by the lora-importer without any further preprocessing. '
                                'See documentation for more details.')

        group.add_argument('--enable_float_fallback', action='store_true', default=False, dest='float_fallback',
                           help='Use this option to enable fallback to floating point (FP) instead of fixed point. '
                                'If this option is enabled, then ``--input_list`` must not be provided. '
                                'The external quantization encodings (encoding file/FakeQuant encodings) '
                                'might be missing quantization parameters for some interim tensors. '
                                'First it will try to fill the gaps by propagating across math-invariant '
                                'functions. If the quantization params are still missing, then it will '
                                'fall back to floating point for those nodes.')
        # Deprecated, hidden alias of --enable_float_fallback. Both options write
        # the same dest ('float_fallback'), so the parsed namespace alone cannot
        # tell which spelling was used.
        group.add_argument('--float_fallback', action='store_true', default=False,
                           help=argparse.SUPPRESS)

        # NOTE(review): with nargs='?' and no const, a bare "--debug" (no value)
        # stores None rather than an int; confirm downstream consumers accept None.
        self.add_optional_argument("--debug", type=int, nargs='?', default=-1,
                                   help="Run the qairt-lora-importer in debug mode.")

        self.add_optional_argument('--output_dir', '-o', dest='output_dir', type=str,
                                   action=validation_utils.validate_filename_arg(must_exist=False,
                                                                                 is_directory=True,
                                                                                 create_missing_directory=True),
                                   # Every fragment ends with a space: adjacent literals are
                                   # concatenated verbatim, and sentences would otherwise run
                                   # together in the --help output.
                                   help='Directory to store all output artifacts. '
                                        'This path will override all \'output_path\' fields in LORA_CONFIG_YAML. '
                                        'If --output_dir is not given, outputs will be saved in the directory '
                                        'given by \'output_path\' in LORA_CONFIG_YAML. '
                                        'If neither --output_dir nor \'output_path\' is given, outputs will be '
                                        'saved in a new directory, qairt_lora_importer_outputs/, located in the '
                                        'same directory as LORA_CONFIG_YAML.')

        # Hidden options below (help=argparse.SUPPRESS) are omitted from --help.
        self.add_optional_argument('--skip_validation', action='store_true', default=False,
                                   help=argparse.SUPPRESS)

        self.add_optional_argument('--dump_usecase_dlc', action='store_true', default=False,
                                   help=argparse.SUPPRESS)

        self.add_optional_argument("--dump_usecase_onnx", action="store_true", default=False,
                                   help=argparse.SUPPRESS)

        self.add_optional_argument("--skip_apply_graph_transforms", action="store_true", default=False,
                                   help=argparse.SUPPRESS)

def main():
    """Run the qairt-lora-importer command line tool.

    Parses the command line with ``LoraArgParser``, warns if the deprecated
    ``--float_fallback`` spelling was used, then calls ``apply_lora_updates``.
    The process always terminates via ``sys.exit``: status 0 on success, 1 if
    ``apply_lora_updates`` raises an ``Exception`` (after logging it and
    printing the traceback).
    """
    cli_args = LoraArgParser().parse_args()

    # The deprecated alias shares its dest with --enable_float_fallback, so
    # its use is detected from the raw argv instead of the parsed namespace.
    deprecated_spelling_used = '--float_fallback' in sys.argv
    if deprecated_spelling_used:
        log_warning("--float_fallback flag is deprecated, use --enable_float_fallback.")

    exit_status = 0
    try:
        apply_lora_updates(cli_args)
    except Exception as err:
        log_error("Encountered Error: {}".format(str(err)))
        traceback.print_exc()
        exit_status = 1
    sys.exit(exit_status)

# Only run the CLI when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
