#!/usr/bin/env python3
# -*- mode: python -*-
# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================

import sys
import traceback
from qti.aisw.lora.lora_mapper_app import resolve_attach_point_name, LoraMapperAppConfig
from qti.aisw.converters.common.utils.argparser_util import ArgParserWrapper, CustomHelpFormatter
from qti.aisw.converters.common.utils.converter_utils import log_error, setup_logging
from qti.aisw.converters.common.utils import validation_utils
from qti.aisw.converters.common.utils.io_utils import get_default_output_directory

class LoraMapperArgParser(ArgParserWrapper):
    """Command-line argument parser for the LoRA mapper tool.

    Registers three options:

    * ``--lora_config`` (required): path to the LoRA YAML config; the
      filename validator is told it must already exist.
    * ``--output_dir`` / ``-o`` (optional): directory for the updated adapter
      configs and updated LoRA config; the validator is asked to treat it as a
      directory and create it if missing.
    * ``--debug`` (optional flag, default ``False``): run the mapper in debug
      mode.
    """

    def __init__(self):
        super().__init__(
            formatter_class=CustomHelpFormatter,
            conflict_handler='resolve',
            parents=[],
        )

        # Required input: the LoRA YAML config file.
        lora_config_help = 'Path to the YAML config file for LoRA.'
        self.add_required_argument(
            '--lora_config',
            metavar="LORA_CONFIG_YAML",
            type=str,
            action=validation_utils.validate_filename_arg(must_exist=True),
            help=lora_config_help,
        )

        # Optional output location; when omitted, the caller picks a default.
        output_dir_action = validation_utils.validate_filename_arg(
            must_exist=False,
            is_directory=True,
            create_missing_directory=True,
        )
        output_dir_help = (
            'Path to store the updated adapter configs and updated lora config. '
            'Output files will have the same name as the original file but with suffix, "_updated". '
            'If output_dir is not given, outputs will be saved in a new directory, '
            'qairt_lora_mapper_outputs/ located in the same directory as LORA_CONFIG_YAML.'
        )
        self.add_optional_argument(
            '--output_dir', '-o',
            dest='output_dir',
            type=str,
            action=output_dir_action,
            help=output_dir_help,
        )

        # Boolean switch for verbose logging and extra debug output.
        debug_help = (
            'Run the mapper in debug mode. '
            'If set, debug logs will be printed during execution and the mapper will output lora_mapper_debug_info.json '
            'which contains mappings of pytorch modules to onnx attach points for each adapter.'
        )
        self.add_optional_argument(
            "--debug",
            action='store_true',
            default=False,
            help=debug_help,
        )





def main():
    """Entry point for the LoRA mapper command-line tool.

    Parses the command-line arguments and sets up logging (level 0 with
    ``--debug``, otherwise -1). It then builds a ``LoraMapperAppConfig`` and
    passes it to ``resolve_attach_point_name``.

    When ``--output_dir`` is not supplied, the output directory comes from
    ``get_default_output_directory(lora_config, "qairt_lora_mapper_outputs")``.

    Any ``Exception`` raised while building the config or resolving attach
    points is reported via ``log_error``. Its traceback is then printed, and
    the process exits with status 1. Argument-parsing errors are handled by
    the parser itself before this point.
    """
    parser = LoraMapperArgParser()
    args = parser.parse_args()

    # Set up logging before doing any work, so that the error handler below
    # also runs with logging configured if config construction fails.
    # `args.debug` is the same value handed to the config as `debug=`.
    setup_logging(0 if args.debug else -1)

    try:
        # Config construction (including default output-dir resolution) is
        # inside the guarded region so its failures get the same
        # log-and-exit treatment as failures in the mapper itself.
        output_dir = args.output_dir or get_default_output_directory(
            args.lora_config, "qairt_lora_mapper_outputs")
        app_config = LoraMapperAppConfig(
            lora_config=args.lora_config,
            output_dir=output_dir,
            debug=args.debug,
        )
        resolve_attach_point_name(app_config)
    except Exception as e:
        log_error("Encountered Error: {}".format(e))
        traceback.print_exc()
        sys.exit(1)

# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == '__main__':
    main()
