#!/usr/bin/env python3
# -*- mode: python -*-
#==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
#  All rights reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
#==============================================================================

from __future__ import print_function

import logging
import argparse
from argparse import RawTextHelpFormatter
import os
import sys
import traceback

try:
   from qti.aisw.dlc_utils import snpe_dlc_utils
   from qti.aisw.converters.common import json_serializer
except ImportError as ie:
   print("Failed to find necessary packages:")
   print(str(ie))
   print("Please ensure that $QAIRT_SDK_ROOT/lib/python is in your PYTHONPATH")
   sys.exit(1)

def main():
   """Command-line entry point: read a DLC file and write its graph as JSON.

   Arguments are taken from ``sys.argv``:

   * ``-i/--input_dlc``   (required) path to the DLC file to read.
   * ``-o/--output_json`` (optional) path of the JSON file to write. When
     omitted, the output is ``<input path without extension>_net.json``,
     placed next to the input DLC.

   The function never returns an error value; it terminates the process:

   * exit status ``1``  - no command-line arguments were supplied (help is
     printed to stderr first).
   * exit status ``-1`` - the input DLC path does not exist.
   * exit status ``-2`` - any ``Exception`` was raised while parsing,
     reading, serializing or writing (the traceback is printed first).

   On success the output path is printed to stdout and the function returns
   ``None``.
   """
   try:
      parser = argparse.ArgumentParser(description="Reads in a DLC file and outputs a JSON file",
         formatter_class=RawTextHelpFormatter)
      # Discard the last built-in argparse group so the help output is driven
      # by the explicit "required"/"optional" groups declared below.
      parser._action_groups.pop()
      required = parser.add_argument_group('required arguments')
      # NOTE(review): the "\b" (backspace) metavar appears intended to hide
      # the value placeholder in the help text - confirm before changing.
      required.add_argument('-i', '--input_dlc', metavar="\b", required=True, type=str,
                           help="Path to a DLC file.")
      optional = parser.add_argument_group('optional arguments')
      optional.add_argument('-o', '--output_json', metavar="\b", required=False, type=str,
                           help="Path to output Json file.")

      # Invoked with no arguments at all: show usage instead of argparse's
      # terse "required argument missing" error.
      if len(sys.argv) == 1:
         parser.print_help(sys.stderr)
         sys.exit(1)

      args = parser.parse_args()

      snpe_dlc_utils.setUpLogger(True)

      logger = logging.getLogger()
      if not os.path.exists(args.input_dlc):
         # This is a fatal condition, so report it at ERROR level; a DEBUG
         # record can be filtered out, leaving a silent non-zero exit.
         logger.error("Cannot find archive DLC file %s", args.input_dlc)
         sys.exit(-1)

      json_file = args.output_json
      if json_file is None:
         # dump the Json file at path of input DLC
         dlc_abs_path = os.path.abspath(args.input_dlc)
         json_file = f"{os.path.splitext(dlc_abs_path)[0]}_net.json"

      # DLC --> QNN-IR
      model_reader = snpe_dlc_utils.modeltools.IrDlcReader()
      model_reader.open(args.input_dlc)
      ir_graph = model_reader.get_ir_graph()

      # Serialize QNN-IR to Net JSON
      json_writer = json_serializer.IrJsonSerializer()
      json_writer.init_json_serializer("", "", "", "")
      json_writer.set_disable_sanitization(True)

      json_writer.serialize(ir_graph)
      ir_json = json_writer.get_graph_json()
      # Write UTF-8 explicitly so the JSON output does not depend on the
      # platform's locale encoding.
      with open(json_file, "w", encoding="utf-8") as js_file:
         js_file.write(ir_json)

      print(f"Model Json saved at: {json_file}")

   except Exception:
      # Top-level boundary: sys.exit() above raises SystemExit, which is not
      # an Exception subclass, so the explicit exit codes pass through here.
      traceback.print_exc()
      sys.exit(-2)

# Run the converter only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
   main()