#!/usr/bin/env python3
# -*- mode: python -*-
# =============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidaries
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================

import argparse
import json
import logging
import os
import webbrowser
import sys
import hashlib
import tempfile

# Python 2 only: force the interpreter-wide default string encoding to UTF-8.
# sys.setdefaultencoding is removed from the sys module at interpreter start-up
# on Python 2, so sys has to be reloaded to get it back before calling it.
# On Python 3 the default encoding is already UTF-8 and reloading sys is
# discouraged, so nothing is done there.
if sys.version_info[0] == 2:
    reload(sys)  # noqa: F821 - 'reload' is a builtin on Python 2
    sys.setdefaultencoding('utf8')

# The DLC helper package is not part of the standard library; when it cannot
# be imported, print a hint about PYTHONPATH and stop with exit status 1.
try:
    from qti.aisw.dlc_utils import snpe_dlc_utils
except ImportError as import_error:
    for message in ("Failed to find necessary package:",
                    str(import_error),
                    "Please ensure that $SNPE_ROOT/lib/python is in your PYTHONPATH"):
        print(message)
    sys.exit(1)

# Module-level shorthand used when formatting parameter and MAC counts.
get_si_notation = snpe_dlc_utils.get_si_notation


def model_info_to_graph_info(rows_all, total_macs, total_params):
    """
    Convert per-layer model rows into the node and link lists of a graph.

    Every row becomes one node (or vertex) carrying an HTML-formatted list
    of its parameters. Connections between a row and its input/output names
    become links (or edges). Rows that expose an ``op`` attribute are
    treated as DLCv4 rows: their app-write input tensors additionally get a
    'Network Input' node (one per tensor name), and their static input
    tensors are not linked.

    :param rows_all: iterable of layer rows; each row must provide ``name``,
        ``id``, ``type``, ``layer_affinity``, ``output_names``,
        ``output_dims_list``, ``get_parm_list()``, ``get_input_list()``,
        ``get_output_list()``, ``get_parm(i)``, ``get_num_params()``,
        ``get_macs()`` and either ``op`` (DLCv4) or ``input_names``.
    :param total_macs: passed to ``get_si_notation`` together with a row's
        MAC count.
    :param total_params: passed to ``get_si_notation`` together with a row's
        parameter count.
    :return: tuple ``(nodes, links)``; ``nodes`` is a list of dicts with the
        keys 'name', 'parameters', 'type' and 'color', ``links`` is a list
        of dicts with the keys 'source' and 'target' (output links also
        carry 'dummy').
    """

    # List that will store dictionaries of layer names and their parameters
    nodes = []

    # List that will store dictionaries of links between layers and their input/output layers
    links = []

    # A network input may feed several layers; remember which ones already
    # have a node so that each input is emitted only once.
    seen_network_inputs = set()

    for row in rows_all:
        is_dlcv4 = hasattr(row, 'op')

        in_names_list = []
        if is_dlcv4:
            for tensor in row.op.inputs():
                tensor_name = tensor.name()

                # Show APP_WRITE tensors as nodes in the displayed graph
                if tensor.is_app_write_tensor():
                    if tensor_name not in seen_network_inputs:
                        seen_network_inputs.add(tensor_name)
                        input_params = [
                            'Data Type: ' + tensor.data_type_string(),
                            '<br/>Dimensions: ' + 'x'.join(str(dim) for dim in tensor.dims()),
                        ]

                        # Grab the network input's encoding information
                        # TODO: find a way to avoid the magic string here
                        encoding_prefix = tensor_name + ' encoding:'
                        for parm in row.get_parm_list():
                            if parm.startswith(encoding_prefix):
                                input_params.append('<br/>' + parm)
                                break

                        nodes.append({'name': tensor_name, 'parameters': input_params,
                                      'type': 'Network Input', 'color': 0})

                    # NOTE(review): the same input->layer link is emitted again
                    # below through in_names_list whenever the tensor is not
                    # static; kept as-is until the template's handling of
                    # duplicate links is confirmed.
                    links.append({'source': tensor_name, 'target': row.name})

                # Ignore static tensors (weights weren't shown in dlcv3-viewer)
                if not tensor.is_static_tensor():
                    in_names_list.append(tensor_name)
        else:
            in_names_list.extend(row.input_names)

        # Output dimensions of the first output as 'AxBxC'; empty when the
        # layer reports no outputs (previously an IndexError).
        first_output_dims = row.output_dims_list[0] if len(row.output_dims_list) else ()
        dims = 'x'.join(str(dim) for dim in first_output_dims)

        parms = [
            'ID: ' + str(row.id),
            '<br/>Output Dims: ' + dims,
            '<br/>Layer Type:' + row.type,
        ]

        # Get parameter values for the layer. The number of slots probed is
        # the longest of the parameter, input and output lists.
        slot_count = max(len(row.get_parm_list()),
                         len(row.get_input_list()),
                         len(row.get_output_list()))
        for slot in range(slot_count):
            parm = row.get_parm(slot)
            if parm:
                parms.append('<br/>' + parm.replace('\n', ''))

        # Get param_count per inference if it exists
        param_count = row.get_num_params()
        if param_count > 0:
            parms.append('<br/>param count: ' + get_si_notation(param_count, total_params))

        # Get MACs per inference if it exists
        macs = row.get_macs()
        if macs > 0:
            parms.append('<br/>MACs per inference: ' + get_si_notation(macs, total_macs))

        # Insert Affinity here
        if row.layer_affinity != 'UNSET':
            parms.append('<br/>Layer Affinity: ' + row.layer_affinity)

        # Create a node for the layer
        type_hash = int(hashlib.md5(row.type.encode()).hexdigest(), 16)
        if is_dlcv4:
            # There are 16 defined colors in the template html, with the zeroth one being reserved for network inputs
            # For DLCV3, the first 10 are hardcoded to layer types in a messy way
            # We can hardcode the dlcv4 colors later, if desired, but right now they'll be assigned a color based
            # on a hash of the type
            colour = (type_hash % 15) + 1
        else:
            # DLCv3 rows use one of the five colours 10..14, picked by type hash
            colour = (type_hash % 5) + 10

        nodes.append({'name': row.name, 'parameters': parms, 'type': row.type, 'color': colour})

        # Links to outputs first, then from inputs; names equal to the layer's
        # own name are skipped to avoid self loops.
        for out_name in row.output_names:
            if out_name != row.name:
                links.append({"source": row.name, "target": out_name, 'dummy': row.name})
        for in_name in in_names_list:
            if in_name != row.name:
                links.append({"source": in_name, "target": row.name})

    return nodes, links

def _meta_value(meta_entry):
    """Return the text following the first ':' of a 'label:value' meta-data string."""
    # Split only once: the value itself may contain ':' and must not be
    # truncated at it. An entry without any ':' is returned unchanged.
    return meta_entry.split(":", 1)[-1]


def _find_shared_dir(script_dir):
    """Return the first existing dlcviewer share directory near *script_dir*, or None."""
    candidates = (
        ('..', '..', 'share', 'dlcviewer'),
        ('..', 'share', 'dlcviewer'),
        ('..', '..', 'share', 'SNPE', 'dlcviewer'),
    )
    for parts in candidates:
        candidate = os.path.abspath(os.path.join(script_dir, *parts))
        if os.path.exists(candidate):
            return candidate
    return None


def _build_aix_table(model_info):
    """Return the HTML table rows describing the AIP records of *model_info*.

    Any exception raised while querying or formatting the records, and any
    record flagged as incompatible, is reported as an error paragraph
    appended to whatever part of the table was built so far.
    """
    headers = ["AIP Record Name", "nnc_version", "record_version", "hta_blob_id", "record_size", "Subnets"]
    ordered_subnet_keys = ("start_layer_Id", "end_layer_Id")
    aix_table = ""
    warning_msgs = ""
    failure = None
    try:
        aix_records = model_info.get_aix_records()
        aix_table = "<tr>" + "".join("<th class='info_headers'>" + header + "</th>"
                                     for header in headers) + "</tr>"
        for aix_record_name, aix_meta_info in aix_records.items():
            aix_table += "<tr><td>" + aix_record_name + "</td>"  # add the record name column

            # add everything after name but before Subnets(since Subnets have further info)
            for header in headers[1:-1]:
                aix_table += "<td>" + str(aix_meta_info[header]) + "</td>"

            aix_table += "<td>num_of_subnets: " + str(aix_meta_info['num_of_subnets'])

            # Add subnets meta info for record
            if aix_meta_info['compatibility']:
                for j in range(aix_meta_info['num_of_subnets']):
                    subnet_name = "subnet_" + str(j)
                    aix_table += "<p style='margin:0'>" + subnet_name + ':</p>'
                    subnet_info = aix_meta_info[subnet_name]

                    # Start/end ids are printed first, one after the other,
                    # for better visual; the remaining keys follow in the
                    # mapping's own order.
                    for key in ordered_subnet_keys:
                        if key in subnet_info:
                            aix_table += "<p style='padding-left:15px; margin:0'> " + key + ": " \
                                         + str(subnet_info[key]) + "</p>"

                    for subnet_key, subnet_value in subnet_info.items():
                        if subnet_key in ordered_subnet_keys:
                            continue
                        if isinstance(subnet_value, list):
                            aix_table += "<p style='padding-left:15px; margin:0'>" + str(subnet_key) + ":</p>"
                            for value in subnet_value:
                                aix_table += "<p style='padding-left:25px; margin:0'>" + str(value) + "</p>"
                        else:
                            aix_table += "<p style='padding-left:15px; margin:0'>" + str(subnet_key) \
                                         + ": " + str(subnet_value) + "</p>"
            else:
                # add warning message if record is not compatible with current version of snpe
                warning_msgs += "- Record " + aix_record_name + " is incompatible with the latest version of SNPE\n"

            # Close the Subnets cell and the row for every record
            aix_table += "</td></tr>"
    except Exception as e:  # best effort: show the problem in the page instead of aborting
        failure = str(e)
    else:
        failure = warning_msgs or None

    if failure is not None:
        aix_table += "</table><p class='error'>Error querying AIP data:\n" + failure + "</p>"
    return aix_table


def _resolve_output_path(save_arg, modelname):
    """Return the absolute path of the HTML file to write.

    *save_arg* is the value of --save (may be None/empty): a value ending in
    '.html' is taken as the file itself, anything else as a directory that
    receives '<modelname>.html'. Without --save the system temporary
    directory is used.
    """
    default_name = '%s.html' % modelname
    if not save_arg:
        # No specifications provided; store file at a temporary location for rendering
        return os.path.join(os.path.abspath(tempfile.gettempdir()), default_name)
    if save_arg.endswith('.html'):
        # File name with or without a directory part, relative or absolute
        return os.path.abspath(save_arg)
    # Only path specified; append file name taken from model name
    return os.path.join(os.path.abspath(save_arg), default_name)


def main():
    """Render a DLC archive as an HTML graph viewer page and open it in a browser.

    Parses the command line (-i/--input_dlc is required, -s/--save is
    optional), fills the 'snpe_dlc_viewer_template.html' template with the
    model's meta-data, copyright, AIP records and per-graph nodes/links,
    writes the result to disk and opens it with the default web browser.
    Exits with status -1 when the DLC file or the share directory cannot be
    found, or when the HTML file cannot be written.
    """
    parser = argparse.ArgumentParser()
    required = parser.add_argument_group('required arguments')
    required.add_argument('-i', '--input_dlc', type=str, required=True,
                          help="Path to a DLC archive")
    parser.add_argument("-s", "--save", type=str, required=False,
                        help="Save HTML file. Specify a file name and/or target save path")
    args = parser.parse_args()

    snpe_dlc_utils.setUpLogger(True)

    logger = logging.getLogger()
    if not os.path.exists(args.input_dlc):
        logger.error("Cannot find archive DLC file " + args.input_dlc)
        sys.exit(-1)

    # Load input dlc file and extract model information layer by layer
    model_info = snpe_dlc_utils.ModelInfo(args.input_dlc)
    # get meta-data from ModelInfo class (the copyright is fetched separately below)
    (model_version, converter_command, quantizer_command, converter_version, _) = model_info.get_meta_data()
    input_file_path = os.path.abspath(args.input_dlc)

    # Set path to location of the script, rather than current working directory
    script_dir = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
    # Set path of the shared directory
    sharedpath = _find_shared_dir(script_dir)
    if sharedpath is None:
        logger.error("Failed to access the dependency packages in share location.")
        sys.exit(-1)

    # Read in dlc info viewer template
    template_path = os.path.abspath(os.path.join(sharedpath, 'snpe_dlc_viewer_template.html'))
    with open(template_path, 'r', encoding='utf8') as template_file:
        html = template_file.read()

    # Replace target strings with meta-data. Note: Only need values here since we have
    # the keys for each of these already defined in the html template
    html = html.replace('??model??', input_file_path)
    html = html.replace('??model_version??', _meta_value(model_version))
    html = html.replace('??converter_command??', _meta_value(converter_command))
    # NOTE(review): this placeholder uses a space where the others use an
    # underscore; left unchanged - confirm against the template.
    html = html.replace('??quantizer command??', _meta_value(quantizer_command))
    html = html.replace('??converter_version??', _meta_value(converter_version))

    # get model copyright info and emit one table row per line of it;
    # class no_border so that we dont add lines in the copyright statement
    model_copyright_table = "".join(
        "<tr><td class='no_border'>" + line + "</td></tr>"
        for line in model_info.get_model_copyright().split('\n'))

    # add to the html template
    html = html.replace('??copyright_table??', model_copyright_table)

    # Retrieve name of DLC file
    modelname = os.path.splitext(os.path.basename(args.input_dlc))[0]

    if model_info.is_aix_enabled():
        # get aix info if it exists
        if model_info.is_aix_record_present():
            aix_table = _build_aix_table(model_info)
        else:
            aix_table = "<tr><td> No AIP Records Found in model. </td></tr>"

        # add to the html template
        html = html.replace('??aix_table??', aix_table)
        html = html.replace('??is_use_aix_set??', "block")  # enable AIX info tab
    else:
        html = html.replace('??is_use_aix_set??', "none")  # Disable displaying AIX info tab in HTML output

    # One [nodes, links, graph_name] entry per graph in the model
    graphs = []
    for graph_name in model_info.graph_names:
        rows_all, total_params, total_macs, _, _, _, _ = model_info.extract_model_info(graph_name)
        nodes, links = model_info_to_graph_info(rows_all, total_macs, total_params)
        graphs.append([nodes, links, graph_name])

    # Replace target strings with nodes and links info
    html = html.replace('??graphs??', json.dumps(graphs))
    html = html.replace('??graph_num??', json.dumps(len(graphs)))

    html_out = _resolve_output_path(args.save, modelname)

    # Write out to an HTML file, specific to DLC name
    try:
        with open(html_out, 'w', encoding='utf8') as out_file:
            out_file.write(html)
    except IOError:
        logger.error("IOError: Cannot write HTML file " + html_out)
        sys.exit(-1)

    print("Network Model HTML file saved at %s" % html_out)

    # Open HTML in browser in try order
    webbrowser.open('file://' + os.path.abspath(html_out))

# Run the viewer only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
