#!/usr/bin/env python3
# =============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All Rights Reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================

import argparse
import time

from qti.aisw.genai.qnn_genai_transformer_composer_backend import *

def parse_args(argv: "list[str] | None" = None) -> dict[str, Any]:
    """Parse the command-line options into keyword arguments for the composer.

    Args:
        argv: Argument strings to parse, without the program name. When
            ``None`` (the default) argparse falls back to ``sys.argv[1:]``,
            which is what the script entry point relies on.

    Returns:
        A dict with the keys ``quantize``, ``export_tokenizer_json``,
        ``dump_lut``, ``outfile``, ``config_file``, ``model``, ``lora`` and
        ``lm_head_precision``. The two flags are booleans, ``lora`` is a list
        of paths when given, and every option that was not supplied is
        ``None``.
    """
    parser = argparse.ArgumentParser(description="Convert a LLaMa model to binary file")
    parser.add_argument(
        "--quantize",
        choices=["Z4", "Z4_FP16", "Q4", "Z8"],
        help="Quantization type",
    )
    parser.add_argument(
        "--export_tokenizer_json",
        action="store_true",
        help="Export the tokenizer as a HuggingFace tokenizer.json file",
    )
    parser.add_argument(
        "--outfile",
        type=Path,
        help="Path to write to; default: path provided in --model parameter",
    )
    parser.add_argument(
        "--config_file",
        type=Path,
        help="Path to base model configuration.json for Generic Transformer",
    )
    parser.add_argument(
        "--model",
        type=Path,
        help="Path to the base model directory",
    )
    parser.add_argument(
        "--lora",
        type=Path,
        nargs='+',
        help="Paths to the LoRA adapter directories",
    )
    parser.add_argument(
        "--lm_head_precision",
        choices=["FP_32"],
        help="Quantize the lm_head tensor; default: quantization type provided in --quantize",
    )
    parser.add_argument(
        "--dump_lut",
        action="store_true",
        help="Dumps the token embedding weight as LUT.bin",
    )
    args = parser.parse_args(argv)

    # Store all the arguments as key/value pairs that can be passed to the
    # composer. "store_true" flags are already booleans (False when absent),
    # so they need no extra defaulting.
    return {
        'quantize': args.quantize,
        'export_tokenizer_json': args.export_tokenizer_json,
        'dump_lut': args.dump_lut,
        'outfile': args.outfile,
        'config_file': args.config_file,
        'model': args.model,
        'lora': args.lora,
        'lm_head_precision': args.lm_head_precision,
    }

if __name__ == '__main__':
    # Collect the command-line options as keyword arguments for the composer.
    kwargs = parse_args()

    # perf_counter() is a monotonic clock, so the reported duration cannot be
    # skewed by system clock adjustments made while the composer is running.
    start = time.perf_counter()
    # NOTE(review): run_composer is presumably provided by the wildcard import
    # at the top of the file -- confirm.
    run_composer(**kwargs)
    elapsed = time.perf_counter() - start

    print(f"Time {elapsed:8.4f} s")