#!/usr/bin/env python3
# -*- mode: python -*-
# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================

import argparse
import sys
import traceback

from qti.aisw.lora.lora_model_creator_app import *
from qti.aisw.converters.common.utils.io_utils import get_default_output_directory
from qti.aisw.converters.common.utils import validation_utils
from qti.aisw.converters.common.utils.converter_utils import *
from qti.aisw.converters.common.utils.argparser_util import ArgParserWrapper, CustomHelpFormatter


class LoraModelCreatorArgParser(ArgParserWrapper):
    """Command-line argument parser for the qairt-lora-model-creator tool.

    Registered options:
        --lora_config           (required) path to the LoRA YAML config; the file must exist.
        --output_dir / -o       output directory; created if it does not exist.
        --quant_updatable_mode  one of "none", "adapter_only" (default) or "all".
        --debug                 optional integer debug level; -1 when the flag is absent.
                                Because of nargs='?', passing the bare flag presumably
                                yields None (argparse's default const) -- main() maps that to 0.
        --skip_validation       hidden boolean flag (help suppressed).
        --dump_usecase_onnx     hidden boolean flag (help suppressed).
    """

    def __init__(self):
        super().__init__(formatter_class=CustomHelpFormatter,
                         conflict_handler='resolve',
                         parents=[])

        self.add_required_argument('--lora_config',
                                   metavar="LORA_CONFIG_YAML",
                                   type=str,
                                   action=validation_utils.validate_filename_arg(must_exist=True),
                                   help='Path to the YAML config file for LoRA.')

        self.add_optional_argument('--output_dir', '-o', dest='output_dir', type=str,
                                   action=validation_utils.validate_filename_arg(must_exist=False,
                                                                                 is_directory=True,
                                                                                 create_missing_directory=True),
                                   help='Path to store the output of the qairt-lora-model-creator tool. '
                                        'If --output_dir is not given, outputs will be saved in a new directory, '
                                        'qairt_lora_model_creator_outputs/ located in the same directory as LORA_CONFIG_YAML.')

        # Implicit string concatenation inside the call parentheses; no `\` continuations needed.
        self.add_optional_argument("--quant_updatable_mode",
                                   type=str,
                                   default="adapter_only",
                                   choices=["none", "adapter_only", "all"],
                                   help="Specify whether/for which tensors the quantization encodings change "
                                        "across use-cases. In none mode, no quantization encodings are updatable. "
                                        "In adapter_only mode, quantization encodings for "
                                        "only the lora/adapter branch (Conv->Mul->Conv) change across use-cases; "
                                        "the base branch quantization encodings remain the same. "
                                        "In all mode, all quantization encodings are updatable.")

        # nargs='?' lets the flag appear with or without a level value; default -1 applies
        # only when the flag is omitted entirely.
        self.add_optional_argument("--debug", type=int, nargs='?', default=-1,
                                   help="Run the qairt-lora-model-creator in debug mode.")

        # Internal/developer switches: hidden from --help output.
        self.add_optional_argument("--skip_validation", action="store_true", default=False,
                                   help=argparse.SUPPRESS)

        self.add_optional_argument("--dump_usecase_onnx", action="store_true", default=False,
                                   help=argparse.SUPPRESS)


def main():
    """Parse the command line, configure logging and run the LoRA model creator.

    When --output_dir is not supplied, a default directory named
    "qairt_lora_model_creator_outputs" is derived from the LoRA config path
    via get_default_output_directory. Any exception raised while building or
    running LoraModelCreatorApp is logged, its traceback printed, and the
    process exits with status 1.
    """
    cli_args = LoraModelCreatorArgParser().parse_args()

    if not cli_args.output_dir:
        cli_args.output_dir = get_default_output_directory(cli_args.lora_config,
                                                           "qairt_lora_model_creator_outputs")

    # A bare `--debug` (no value) parses to None; treat it as debug level 0.
    debug_level = 0 if cli_args.debug is None else cli_args.debug
    setup_logging(debug_level)

    try:
        creator = LoraModelCreatorApp(cli_args.lora_config,
                                      cli_args.output_dir,
                                      cli_args.skip_validation,
                                      cli_args.quant_updatable_mode,
                                      cli_args.dump_usecase_onnx)
        creator.run()
    except Exception as err:
        # Top-level boundary: report the failure and exit non-zero.
        log_error(f"Encountered Error: {err!s}")
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    # Only run the CLI when executed as a script, not when imported.
    main()
