#!/usr/bin/env python3
# -*- mode: python -*-
#==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
#==============================================================================
from __future__ import print_function
import argparse
import logging
import os
import sys
import csv
from collections import OrderedDict
import numpy as np

from qti.aisw.converters.common.utils.converter_utils import log_error

try:
    from qti.aisw.dlc_utils import snpe_dlc_utils
    from qti.aisw.dlc_utils import snpe_diff_utils
except ImportError as ie:
    print("Failed to find necessary package:")
    print(str(ie))
    print("Please ensure that $SNPE_ROOT/lib/python is in your PYTHONPATH")
    sys.exit(1)

def _build_parser():
    """Build the command-line parser for the DLC diff tool.

    Returns:
        argparse.ArgumentParser: parser with the two required input DLC paths,
        the optional comparison flags and the optional ``--save`` CSV path.
    """
    parser = argparse.ArgumentParser()
    required = parser.add_argument_group('required arguments')
    required.add_argument('-i1', '--input_dlc_one', required=True, type=str,
                          help="Path to the first DLC archive")
    required.add_argument('-i2', '--input_dlc_two', required=True, type=str,
                          help="Path to the second DLC archive")
    parser.add_argument("-c", "--copyrights", action="store_true", required=False,
                        help="Compare copyrights between models")
    parser.add_argument("-l", "--layers", action="store_true", required=False,
                        help="Compare unique layers between models")
    parser.add_argument("-p", "--parameters", action="store_true", required=False,
                        help="Compare parameter differences between identically named layers")
    parser.add_argument("-d", "--dimensions", action="store_true", required=False,
                        help="compare dimension differences between identically named layers")
    parser.add_argument("-w", "--weights", action="store_true", required=False,
                        help="compare weight differences between identically named layers.")
    parser.add_argument("-o", "--outputs", action="store_true", required=False,
                        help="compare output_tensor name differences names between identically named layers")
    parser.add_argument("-i", "--diff_by_id", action="store_true", required=False,
                        help="Overrides the default comparison strategy for diffing 2 models components. By "
                             "default comparison is made between identically named layers. With this option the "
                             "models are ordered by id and diff is done in order as long as no more than 1 "
                             "consecutive layers have different layer types.")
    parser.add_argument("-s", "--save", required=False, type=str,
                        help="Save the output to a csv file. Specify a target file path.")
    return parser


def _save_csv(path, rows):
    """Write ``rows`` to ``path`` as CSV; exit with status -1 on an I/O failure.

    Args:
        path: destination file path (overwritten if it exists).
        rows: iterable of row sequences, each written as one CSV record.
    """
    try:
        # newline='' is required by the csv module; without it every record is
        # followed by an extra blank line on platforms that translate '\n'.
        with open(path, 'w', newline='') as csv_file:
            csv.writer(csv_file).writerows(rows)
    except IOError:
        print(f"IOError: Cannot open CSV file: {path}", file=sys.stderr)
        sys.exit(-1)


def main():
    """Diff two DLC archives and optionally save the result as CSV.

    Parses the command line, checks that both input DLC paths exist, loads them
    with ``snpe_dlc_utils.ModelInfo`` and hands them to
    ``snpe_diff_utils.display_dlc_differences``. When ``--save`` is given, the
    rows collected into ``csv_content`` are written to that path.

    Exits with status -1 if an input path is missing, the CSV cannot be
    written, or any other exception escapes the comparison.
    """
    try:
        args = _build_parser().parse_args()

        # Rows are appended by display_dlc_differences; written only with --save.
        csv_content = []
        if args.save is not None:
            # NOTE(review): presumably tells snpe_dlc_utils to collect CSV rows — confirm.
            snpe_dlc_utils.csv_file_flag = True

        snpe_dlc_utils.setUpLogger(True)

        logger = logging.getLogger()
        for dlc_path in (args.input_dlc_one, args.input_dlc_two):
            if not os.path.exists(dlc_path):
                # Fatal condition: log at error level so the reason for the
                # non-zero exit is visible regardless of the debug threshold.
                logger.error(f"Cannot find archive DLC file {dlc_path}")
                sys.exit(-1)

        m1 = snpe_dlc_utils.ModelInfo(args.input_dlc_one)
        m2 = snpe_dlc_utils.ModelInfo(args.input_dlc_two)
        snpe_diff_utils.display_dlc_differences(m1, m2, csv_content, args, "snpe")

        if args.save is not None:
            _save_csv(args.save, csv_content)

    except Exception as e:
        # sys.exit() raises SystemExit (a BaseException), so deliberate exits
        # above are not caught here. Anything else is a failure: report it on
        # stderr and exit non-zero instead of silently returning status 0.
        print(e, file=sys.stderr)
        sys.exit(-1)


if __name__ == '__main__':
    # Run the CLI; main() returns None on success, which sys.exit maps to status 0.
    sys.exit(main())
