#!/usr/bin/env python3
# =============================================================================
#
#  Copyright (c) 2022 Qualcomm Technologies, Inc.
#  All Rights Reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================
import argparse
import logging
import os
import sys
import json

try:
    from qti.aisw.converters.common import json_deserializer as py_ir_json_deserializer
    from qti.aisw.arch_checker.qnn_arch_checker import QnnArchChecker
    from qti.aisw.arch_checker.qnn_modifier import QnnModifier
    from qti.aisw.arch_checker.arch_checker_cmd_options import MainCmdOptions
    import qti.aisw.arch_checker.config as config
except ImportError as ie:
    print("Failed to find necessary package:")
    print(str(ie))
    print("Please ensure that $QNN_SDK_ROOT/lib/python is in your PYTHONPATH")
    sys.exit(1)


def main() -> None:
    """Run the architecture checker command-line flow end to end.

    Steps, in order:
      1. Parse the command line with ``MainCmdOptions``.
      2. Verify that the input json (and the optional bin file) exist.
      3. Deserialize the model into an IR graph.
      4. Load the constraints json located under ``$QNN_SDK_ROOT``.
      5. Run ``QnnArchChecker.run_checks()``.
      6. If ``args.modify`` was supplied, run ``QnnModifier`` and, when it
         reports modifications, save the modified graph as json.
      7. Save the checker results as csv and html.

    Output files share a common base path: ``args.output_path`` when given,
    otherwise the input json path without its extension. The suffixes
    ``_architecture_checker.csv``, ``_architecture_checker.json`` and
    ``_architecture_checker.html`` are appended to that base.

    Raises:
        SystemExit: via ``sys.exit(-1)`` when an input file is missing, the
            model cannot be deserialized, ``QNN_SDK_ROOT`` is not set, the
            constraints file is missing, or the modified graph cannot be
            saved.
    """
    system_args = sys.argv[1:]
    command = os.path.basename(sys.argv[0])
    args = MainCmdOptions(command, system_args).parse()

    logging.basicConfig(format='%(asctime)s - %(lineno)d - %(levelname)s - %(message)s', level=logging.DEBUG)
    logger = logging.getLogger()

    if not os.path.exists(args.input_json):
        logger.error("Cannot find archive json file %s", args.input_json)
        sys.exit(-1)
    if args.bin and not os.path.exists(args.bin):
        logger.error("Cannot find archive bin file %s", args.bin)
        sys.exit(-1)

    try:
        json_deserializer = py_ir_json_deserializer.IrJsonDeserializer()
        json_deserializer.open(args.input_json)
        if args.bin:
            json_deserializer.open(args.bin)
        ir_graph_ptr = json_deserializer.deserialize()
    except Exception as err:
        logger.error("Unable to run Architecture checker on this model")
        # Keep the underlying cause visible; otherwise the failure cannot be
        # diagnosed from the log.
        logger.debug("Deserialization failed: %s", err)
        if "Missing static data" in str(err):
            logger.error("Please provide the bin file for running the architecture checker on this model.")
        sys.exit(-1)

    # os.environ.get() returns None when the variable is unset, and
    # os.path.join(None, ...) would raise TypeError, so fail with a clear
    # message like every other error path in this function.
    sdk_root = os.environ.get('QNN_SDK_ROOT')
    if sdk_root is None:
        logger.error("QNN_SDK_ROOT environment variable is not set. Please set it to the root of the QNN SDK.")
        sys.exit(-1)
    constraints_file = os.path.join(sdk_root, config.CONSTRAINTS_PATH)
    if not os.path.exists(constraints_file):
        logger.error("Failed to find necessary file. Please ensure $QNN_SDK_ROOT/lib/python/qti/aisw/arch_checker/constraints.json exists")
        sys.exit(-1)
    with open(constraints_file, encoding="utf-8") as f:
        constraints = json.load(f)

    # Common base path for every output file produced below.
    if args.output_path:
        out_file_base = os.path.abspath(args.output_path)
    else:
        out_file_base = os.path.splitext(os.path.abspath(args.input_json))[0]
    out_file_csv = out_file_base + '_architecture_checker.csv'
    qnn_arch_checker = QnnArchChecker(ir_graph_ptr, constraints, logger, out_file_csv, args.input_json)
    logger.info("Running qnn-architecture-checker...")
    df = qnn_arch_checker.run_checks()

    if args.modify is not None:
        qnn_modifier = QnnModifier(ir_graph_ptr, constraints, logger, args.modify, df)
        check_modification = qnn_modifier.do_modification()
        if check_modification:
            out_file_json = out_file_base + '_architecture_checker.json'
            res = qnn_modifier.save_modified_graph(out_file_json)
            if not res:
                logger.error("Failed to apply modifications on this model. Unable to run Architecture checker on this model.")
                sys.exit(-1)
        else:
            logger.info("No rules found for applying modifications.")

    qnn_arch_checker.save_to_csv(df)
    logger.info("Saved output at: %s", out_file_csv)

    html_file = out_file_base + "_architecture_checker.html"
    qnn_arch_checker.get_html(html_file, args.input_json)
    logger.info("Saved html output at: %s", html_file)

# Run the tool only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
