#!/usr/bin/env python3
# -*- mode: python -*-
#==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
#==============================================================================
from __future__ import print_function
import argparse
from argparse import RawTextHelpFormatter
import logging
import os
import sys
import csv
from collections import OrderedDict
import numpy as np

from qti.aisw.converters.common.utils.converter_utils import log_error

# The DLC helper modules are imported under a guard so that a missing package
# produces a short hint about PYTHONPATH and a non-zero exit status.
try:
    from qti.aisw.dlc_utils import snpe_dlc_utils
    from qti.aisw.dlc_utils import snpe_diff_utils
except ImportError as import_failure:
    hint_lines = (
        "Failed to find necessary package:",
        str(import_failure),
        "Please ensure that $SNPE_ROOT/lib/python is in your PYTHONPATH",
    )
    for hint in hint_lines:
        print(hint)
    sys.exit(1)

def _build_arg_parser():
    """Build the command-line parser for the DLC diff tool.

    Returns:
        argparse.ArgumentParser: parser with the two required DLC paths and the
        optional comparison / save switches registered.
    """
    parser = argparse.ArgumentParser(
        description="Reads in two DLC files and outputs differences between the models to stdout.",
        formatter_class=RawTextHelpFormatter)
    # Remove the last of argparse's built-in argument groups (private attribute)
    # so that the help output lists the custom groups defined below instead.
    parser._action_groups.pop()

    # NOTE(review): metavar is a backspace character, presumably to hide the
    # metavar in the generated help text - confirm before changing.
    required = parser.add_argument_group('required arguments')
    required.add_argument('-i1', '--input_dlc_one', metavar="\b", required=True, type=str,
                          help="Path to the first DLC archive.")
    required.add_argument('-i2', '--input_dlc_two', metavar="\b", required=True, type=str,
                          help="Path to the second DLC archive.")

    optional = parser.add_argument_group('optional arguments')
    optional.add_argument("-c", "--compare_copyrights", action="store_true", required=False,
                          help="Compare copyrights between models.")
    optional.add_argument("-l", "--compare_layers", action="store_true", required=False,
                          help="Compare unique layers between models.")
    optional.add_argument("-p", "--compare_parameters", action="store_true", required=False,
                          help="Compare parameter differences between identically named layers.")
    optional.add_argument("-d", "--compare_dimensions", action="store_true", required=False,
                          help="Compare dimension differences between identically named layers.")
    optional.add_argument("-w", "--compare_weights", action="store_true", required=False,
                          help="Compare weight differences between identically named layers.")
    optional.add_argument("-o", "--compare_outputs", action="store_true", required=False,
                          help="Compare output_tensor name differences names between identically named layers.")
    optional.add_argument("-i", "--enable_diff_by_id", action="store_true", required=False,
                          help="Overrides the default comparison strategy for diffing 2 models components. \n"
                               "By default comparison is made between identically named layers. \n"
                               "With this option the models are ordered by id and diff is done \n"
                               "in order as long as no more than 1 consecutive layers have different layer types.")
    # optional.add_argument("-x", "--compare_hta", action="store_true",
    #                       help="Compare HTA records differences in Models.")
    optional.add_argument("-s", "--save", metavar="\b", type=str, required=False,
                          help="Save the output to a csv file. Specify a target file path.")
    return parser


def _write_csv(path, rows):
    """Write every entry of *rows* as one CSV row to the file at *path*."""
    # newline='' is required by the csv module documentation; without it the
    # writer's own "\r\n" terminators get translated again on platforms that
    # use "\r\n" line endings, which produces an empty line after every row.
    with open(path, 'w', newline='') as csv_file:
        csv.writer(csv_file).writerows(rows)


def main():
    """Diff two DLC archives and optionally save the result as a CSV file.

    Parses the command line, loads both DLC files through
    ``snpe_dlc_utils.ModelInfo`` and passes them, together with the parsed
    arguments and the ``csv_content`` list, to
    ``snpe_diff_utils.display_dlc_differences``. When ``--save`` is given, the
    contents of ``csv_content`` (presumably filled in by that call) are written
    to the requested path.

    Exit status:
        1  - invoked without any arguments (help is printed to stderr).
        -1 - an input DLC path does not exist, the CSV file cannot be written,
             or any other unexpected error occurs.
    """
    parser = _build_arg_parser()

    if len(sys.argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(1)

    args = parser.parse_args()

    try:
        csv_content = []
        if args.save is not None:
            snpe_dlc_utils.csv_file_flag = True

        snpe_dlc_utils.setUpLogger(True)

        logger = logging.getLogger()
        for dlc_path in (args.input_dlc_one, args.input_dlc_two):
            if not os.path.exists(dlc_path):
                # A missing input is fatal, so report it at error level rather
                # than debug so it is not hidden when debug output is filtered.
                logger.error("Cannot find archive DLC file %s", dlc_path)
                sys.exit(-1)

        m1 = snpe_dlc_utils.ModelInfo(args.input_dlc_one)
        m2 = snpe_dlc_utils.ModelInfo(args.input_dlc_two)
        snpe_diff_utils.display_dlc_differences(m1, m2, csv_content, args, "qairt")

        if args.save is not None:
            try:
                _write_csv(args.save, csv_content)
            except IOError:
                print(f"IOError: Cannot open CSV file {args.save}", file=sys.stderr)
                sys.exit(-1)

    except Exception as e:
        # sys.exit() raises SystemExit, which is not an Exception subclass, so
        # the deliberate exits above pass through this handler untouched.
        # Anything else is an unexpected failure: report it on stderr and exit
        # non-zero so that calling scripts can detect that the diff failed.
        print(e, file=sys.stderr)
        sys.exit(-1)

# Run the command-line tool only when this file is executed as a script;
# importing it as a module does not call main().
if __name__ == '__main__':
    main()
