# ==============================================================================
#
#  Copyright (c) Qualcomm Technologies, Inc.
#  All Rights Reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# ==============================================================================
import os
import json
import onnx
import logging

from onnx import save_model
from onnxruntime_genai.models.builder import create_model
from onnxruntime.quantization.onnx_quantizer import ONNXModel
from . import update_onnx_graph_helper

logger = logging.getLogger(__name__)


class GraphBuilder:
    """
    Class to build GenAI model graph from the provided model and configuration.

    Typical usage is to call :meth:`build_genai_model` first (which exports
    ``model.onnx`` into the output directory) and then :meth:`update_onnx_graph`
    (which rewrites that graph and cleans up intermediate files).
    """
    def __init__(self, input_model: str, config_path: str, output_dir: str = None,
                 cache_dir: str = None, batch_size: int = 1, arn_value: int = 1,
                 filename_prefix: str = None):
        """
        Constructor
        :param input_model: Path to the input model file.
        :param config_path: Path to the directory containing the model configuration.
        :param output_dir: Path to the output directory for saving the onnx model.
            Defaults to the directory that contains the input model.
        :param cache_dir: Path to the cache directory. Defaults to the output directory.
        :param batch_size: batch size for onnx graph.
        :param arn_value: Max number of autoregressive tokens at a time.
        :param filename_prefix: Prefix for the exported onnx graph.
        """
        self.input_model = os.path.abspath(input_model)
        # Derive the default from the normalized path: for a bare filename such as
        # "model.gguf", os.path.dirname() of the raw argument is the empty string,
        # which is not a usable directory for the export step.
        self.output_dir = output_dir if output_dir else os.path.dirname(self.input_model)
        self.config_path = config_path
        self.cache_dir = cache_dir if cache_dir else self.output_dir
        self.batch_size = batch_size
        self.arn_value = arn_value
        self.filename_prefix = filename_prefix
        # Set by build_genai_model(); updated by update_onnx_graph().
        self.onnx_model_path = None

    @staticmethod
    def _remove_if_exists(path: str) -> None:
        """Delete the file at ``path``; do nothing when it is already absent."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    def build_genai_model(self):
        """
        Build the GenAI model graph.

        Invokes ``create_model`` with fp32 precision on the CPU execution provider
        and records ``<output_dir>/model.onnx`` as the current model path.
        """
        precision = "fp32"
        execution_provider = "cpu"

        logger.info("Generating ONNX graph")
        create_model(self.config_path, self.input_model, self.output_dir,
                     precision, execution_provider, self.cache_dir)
        self.onnx_model_path = os.path.join(self.output_dir, "model.onnx")

    def update_onnx_graph(self):
        """
        Apply different transformations to the onnx graph such as
        unpack packed (QKV) attention, decompose custom ops such as
        GroupQueryAttention.

        The transformed graph is saved next to the original as ``model.onnx`` (or
        ``model_<filename_prefix>.onnx``) with all tensors in a single
        ``<name>.data`` external file. Afterwards the input model file and the
        ``genai_config.json`` in the output directory are deleted.

        :raises RuntimeError: if called before :meth:`build_genai_model`.
        """
        if self.onnx_model_path is None:
            raise RuntimeError("build_genai_model() must be called before update_onnx_graph()")

        source_model_path = self.onnx_model_path
        onnx_model = onnx.load(source_model_path)

        with open(os.path.join(self.output_dir, "config.json"), encoding="utf-8") as config_file:
            model_config = json.load(config_file)

        update_onnx_graph_helper(onnx_model, model_config, self.arn_value, self.batch_size)
        ortmodel = ONNXModel(onnx_model)
        ortmodel.topological_sort()
        onnx_model = ortmodel.model

        onnx_filename = "model.onnx" if not self.filename_prefix else f"model_{self.filename_prefix}.onnx"
        target_model_path = os.path.join(os.path.dirname(source_model_path), onnx_filename)

        # Only delete files once the transformed graph is ready to be written, so a
        # failure above leaves the originally generated model on disk.
        # The target files are cleared as well: a leftover "<target>.data" from an
        # earlier run would otherwise be reused by the external-data writer
        # (NOTE(review): onnx is understood to append to an existing data file).
        for stale_path in {source_model_path, target_model_path}:
            self._remove_if_exists(stale_path)
            self._remove_if_exists(stale_path + ".data")

        save_model(onnx_model, target_model_path, save_as_external_data=True, all_tensors_to_one_file=True,
                   location=os.path.basename(target_model_path + ".data"), size_threshold=0, convert_attribute=False)
        self.onnx_model_path = target_model_path

        # Remove dequantized GGUF file
        self._remove_if_exists(self.input_model)

        # Remove config generated by ORT GenAI
        self._remove_if_exists(os.path.join(self.output_dir, "genai_config.json"))
