#!/usr/bin/env python3
##############################################################################
#
# Copyright (c) Qualcomm Technologies, Inc.
# All Rights Reserved.
# Confidential and Proprietary - Qualcomm Technologies, Inc.
#
##############################################################################

import argparse
import os
import shutil
import sys


# Pick the location of the python packages: if a '../python/' directory exists relative to
# sys.path[0] (the evaluator is being run from the bin directory of the source), use it;
# otherwise fall back to the 'lib/python' directory under QNN_SDK_ROOT.
_local_python_dir = os.path.abspath(os.path.join(sys.path[0], '../python/'))
if os.path.isdir(_local_python_dir):
    sys.path.insert(0, _local_python_dir)
else:
    sys.path.insert(0, os.path.join(os.environ['QNN_SDK_ROOT'], 'lib', 'python'))

import qti.aisw.accuracy_evaluator.common.exceptions as ce
from qti.aisw.accuracy_evaluator.evaluator_module import (
    EvaluatorInputs,
    EvaluatorModule,
)
from qti.aisw.accuracy_evaluator.qacc import qacc_logger
from qti.aisw.accuracy_evaluator.qacc.config_definitions import EvaluatorPipelineConfig


def remove_work_dir(work_dir, prompt):
    """Delete the work directory, if it exists, before execution starts.

    A warning naming ``work_dir`` is always logged first. When ``prompt`` is truthy the user is
    asked to confirm on stdin; the answers "yes" and "y" (case-insensitive, surrounding
    whitespace ignored) confirm, anything else aborts. When ``prompt`` is falsy the deletion
    proceeds without asking.

    Args:
        work_dir: Path of the directory to delete.
        prompt: Whether to ask the user for confirmation before deleting.

    Raises:
        ce.UserAbort: If the user does not confirm the execution.
    """
    qacc_logger.warning('Directory {} will be deleted if already exists. Take backup before '
                        'execution.'.format(work_dir))
    if prompt:
        # Strip the answer so that stray whitespace (e.g. "yes ") is not treated as a refusal.
        user_input = input('Do you want to start execution? (yes/no) :').strip().lower()
    else:
        user_input = 'y'

    if user_input not in ('yes', 'y'):
        raise ce.UserAbort("User intends to discontinue the execution")

    temp_dir = os.fspath(work_dir)
    if os.path.exists(temp_dir):
        # Log before deleting so the message is present even if the removal itself fails.
        qacc_logger.warning("Removing " + str(temp_dir) + "\n")
        shutil.rmtree(temp_dir)


# Source:
# https://stackoverflow.com/questions/72741663/argument-parser-from-a-pydantic-model
def add_arguments(parser, pydantic_class):
    """Register one CLI option per field of a Pydantic model, plus -debug and -set_global.

    Every entry of ``pydantic_class.model_fields`` becomes a single-dash option ``-<name>``:
      * a field whose annotation text contains "bool" becomes a ``store_true`` flag;
      * the field default is used as the option default only for optional, non-bool fields
        whose default is not None; in every other case the default is ``argparse.SUPPRESS``,
        so the attribute is absent from the parsed namespace unless given on the command line;
      * the field description is used as help text and required fields are required options.
    """
    for field_name, field_info in pydantic_class.model_fields.items():
        is_flag = "bool" in str(field_info.annotation)
        is_mandatory = field_info.is_required()

        if is_mandatory or field_info.default is None or is_flag:
            option_default = argparse.SUPPRESS
        else:
            option_default = field_info.default

        parser.add_argument(
            f"-{field_name}",
            dest=field_name,
            action="store_true" if is_flag else None,
            default=option_default,
            help=field_info.description,
            required=is_mandatory,
        )

    # Options that exist only on the CLI and are not fields of the Pydantic model.
    parser.add_argument(
        "-debug",
        dest="debug",
        action="store_true",
        required=False,
        help="Enable debug logs on console and the file.",
    )
    set_global_help = ("Option used to set a global variable. It can be repeated. " +
                       "Example: -set_global count:10 -set_global calib:5")
    parser.add_argument(
        "-set_global",
        dest="set_global",
        action="append",
        nargs="+",
        required=False,
        help=set_global_help,
    )


def transform_set_global(set_global):
    """Convert the parsed -set_global values into the dict required by the Evaluator Module.

    ``parse_args()`` yields ``set_global`` as a list with one inner list of tokens per
    occurrence of the option (or None when the option was not given). The tokens of each
    occurrence are concatenated without a separator and must form exactly one ``key:val``
    pair; key and value are stripped of surrounding whitespace. A later occurrence of a key
    overrides an earlier one.

    Args:
        set_global: List of token lists, or None.

    Returns:
        Dict mapping each key to its value, both as strings. Empty when ``set_global`` is
        None or empty.

    Raises:
        ce.InvalidSetGlobalFormat: If an entry is not in ``key:val`` format.
    """
    global_dict = {}
    for str_list in set_global or []:
        try:
            entry = ''.join(str_list).strip()
        except TypeError as err:
            # Non-string tokens cannot be joined; report them as a format error.
            raise ce.InvalidSetGlobalFormat("set_global arguments are not in key:val format" +
                                            str(set_global)) from err
        parts = entry.split(':')
        # Explicit check instead of assert, so the validation still runs under `python -O`.
        if len(parts) != 2:
            raise ce.InvalidSetGlobalFormat("set_global arguments are not in key:val format" +
                                            str(set_global))
        key, val = parts[0].strip(), parts[1].strip()
        global_dict[key] = val
    return global_dict


def dump_evaluator_outputs(evaluator_outputs):
    """Write the metric results, qacc log and config yaml of the outputs to the info log."""
    for attr_name in ("metric_results", "qacc_log", "config_yaml"):
        attr_value = getattr(evaluator_outputs, attr_name)
        qacc_logger.info(f"{attr_name} = {attr_value!s}")


def main():
    """CLI entry point of the evaluator.

    Builds the argument parser from the EvaluatorInputs fields, clears the work directory
    (asking for confirmation unless -silent was given), builds the pipeline config and runs
    the evaluation, then logs the outputs.
    """
    cli_parser = argparse.ArgumentParser(
        description='qairt-accuracy-evaluator options',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add_arguments(cli_parser, EvaluatorInputs)
    args = cli_parser.parse_args()

    # vars() is a live view of the namespace: suppressed options are simply absent from it.
    supplied = vars(args)

    # Fall back to the model default when no work directory is present in the namespace.
    work_dir = supplied.get("work_dir", EvaluatorInputs.model_fields["work_dir"].default)
    ask_confirmation = "silent" not in supplied

    # Flags use SUPPRESS as default, so their presence in the namespace means they were given.
    for flag in ("use_memory_plugins", "use_memory_pipeline"):
        setattr(args, flag, flag in supplied)
    use_memory_plugins = args.use_memory_plugins or args.use_memory_pipeline

    remove_work_dir(work_dir, ask_confirmation)

    # -set_global and -debug are CLI-only options; take them out of the namespace before it
    # is expanded into EvaluatorInputs.
    raw_set_global, debug_enabled = args.set_global, args.debug
    del args.set_global, args.debug

    # Convert the parsed -set_global tokens into the dict form used by the config.
    global_overrides = transform_set_global(raw_set_global)

    # Replace the config path with the parsed pipeline configuration.
    args.config = EvaluatorPipelineConfig(
        config_path=args.config,
        set_global=global_overrides,
        use_memory_plugins=use_memory_plugins,
    )
    evaluator_inputs = EvaluatorInputs(**vars(args))

    # Run the evaluation, with debug logging enabled first if requested.
    evaluator = EvaluatorModule()
    if debug_enabled:
        evaluator.enable_debug()
    dump_evaluator_outputs(evaluator.evaluate(evaluator_inputs))


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
