#!/usr/bin/env python3
# -*- mode: python -*-
# ==============================================================================
#
#  Copyright (c) 2024 Qualcomm Technologies, Inc.
#  All Rights Reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
# ==============================================================================

import sys
import traceback
import numpy as np
from qti.aisw.converters.common.utils.converter_utils import log_error, log_warning
from qti.aisw.converters.common.utils.argparser_util import ArgParserWrapper, CustomHelpFormatter
from qti.aisw.converters.common.dlc_quantizer import DLCQuantizer
from qti.aisw.converters.common.utils.converter_utils import *
from qti.aisw.converters.common.utils import validation_utils
from qti.aisw.converters.common.utils import code_to_message
from qti.aisw.converters.common.backend_awareness import BackendInfo
from qti.aisw.converters.common.graph_optimizer import GraphOptimizer
import os

class QuantizerArgParser(ArgParserWrapper):
    """Argument parser for the DLC quantizer command-line tool.

    The parser is assembled from the argument parsers exposed by
    ``DLCQuantizer``, ``BackendInfo`` and ``GraphOptimizer`` (passed as
    parents) and additionally registers a ``--debug`` option.
    """

    def __init__(self):
        """Build the combined parser and register the ``--debug`` option."""
        parent_parsers = [
            DLCQuantizer.ArgParser(),
            BackendInfo.ArgParser(),
            GraphOptimizer.ArgParser(),
        ]
        # 'resolve' is forwarded as the conflict handler; presumably it lets
        # the parents declare overlapping options -- TODO confirm against
        # ArgParserWrapper.
        super().__init__(
            formatter_class=CustomHelpFormatter,
            conflict_handler='resolve',
            parents=parent_parsers,
        )
        # The default is -1; main() separately handles a parsed value of None.
        self.add_optional_argument(
            "--debug",
            type=int,
            nargs='?',
            default=-1,
            help="Run the quantizer in debug mode.",
        )


def main():
    """Run the DLC quantizer from the command line.

    Parses ``sys.argv``, configures logging, validates the arguments,
    resolves the backend information object, then constructs a
    ``DLCQuantizer``, runs ``quantize()`` and ``save()``.

    Exits the process with status 1 when no command-line arguments are
    given (after printing the help text) or when constructing, running or
    saving the quantizer raises an exception (after logging the error and
    printing the traceback).
    """
    # DLCQuantizer keywords, in the order they are supplied to the constructor.
    ctor_keywords = (
        'input_dlc',
        'output_dlc',
        'input_list',
        'float_fallback',
        'param_quantizer',
        'act_quantizer',
        'algorithms',
        'bias_bitwidth',
        'act_bitwidth',
        'weights_bitwidth',
        'float_bitwidth',
        'float_bias_bitwidth',
        'ignore_encodings',
        'use_per_channel_quantization',
        'use_per_row_quantization',
        'enable_per_row_quantized_bias',
        'preserve_io_datatype',
        'use_native_input_files',
        'use_native_output_files',
        'restrict_quantization_steps',
        'use_dynamic_16_bit_weights',
        'disable_dynamic_16_bit_weights',
        'pack_4_bit_weights',
        'keep_weights_quantized',
        'adjust_bias_encoding',
        'act_quantizer_calibration',
        'param_quantizer_calibration',
        'act_quantizer_schema',
        'param_quantizer_schema',
        'percentile_calibration_value',
        'use_aimet_quantizer',
        'op_package_lib',
        'disable_legacy_quantizer',
        'dump_encoding_json',
        'include_data_invariant_ops',
        'aimet_config',
        'backend_info_obj',
        'export_stripped_dlc',
        'use_quantize_v2',
        'should_compute_statics',
        'log_file_name',
        'log_level',
    )
    # Keywords whose value is stored in the validated argument dictionary
    # under a different key; every other keyword uses its own name as the key.
    key_overrides = {
        'aimet_config': 'config_file',
        'should_compute_statics': 'calc_static_encodings',
        'log_file_name': 'quantizer_log',
        'log_level': 'quantizer_log_level',
    }

    arg_parser = QuantizerArgParser()

    # Show the help text in case no arguments are provided
    if len(sys.argv) == 1:
        arg_parser.parser.print_help()
        sys.exit(1)

    args = arg_parser.parse_args()

    # A parsed value of None is replaced by level 3 (presumably this is what
    # a bare `--debug` without a value yields -- TODO confirm).
    verbosity = 3 if args.debug is None else args.debug
    setup_logging(verbosity)

    # Sanitized form of the invocation, later handed to save(); the
    # input/output DLC arguments are excluded from it.
    ignored_args = ['input_dlc', 'i', 'output_dlc', 'o']
    quantizer_command = sanitize_args(args, args_to_ignore=ignored_args)
    args_dict = DLCQuantizer.ArgParser.validate_and_convert_args(args)

    # Backend Awareness
    backend_info_obj = BackendInfo.get_instance(args.backend, args.soc_model)

    try:
        # The keyword arguments are assembled inside the try block so that a
        # missing key is reported through the same error path as any other
        # quantizer failure.
        ctor_kwargs = {}
        for keyword in ctor_keywords:
            if keyword == 'backend_info_obj':
                # Comes from BackendInfo rather than from the parsed arguments.
                ctor_kwargs[keyword] = backend_info_obj
            else:
                ctor_kwargs[keyword] = args_dict[key_overrides.get(keyword, keyword)]

        quantizer = DLCQuantizer(**ctor_kwargs)
        quantizer.quantize()
        quantizer.save(quantizer_command)

    except Exception as err:
        log_error("Encountered Error: {}".format(str(err)))
        traceback.print_exc()
        sys.exit(1)


# Run the quantizer only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
