//==============================================================================
//
//  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
//  All Rights Reserved.
//  Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

#include <exception>
#include <set>
#include <sstream>

#include "Dialog.hpp"
#include "Exception.hpp"
#include "Macro.hpp"
#include "qualla/detail/json.hpp"
#include "qualla/env.hpp"

using namespace genie;

// Platform-specific decoration that turns a backend base name into the file
// name of its shared library.
#if defined(_WIN32)
inline std::string libPrefix{};
inline std::string libSuffix{".dll"};
#else
inline std::string libPrefix{"lib"};
inline std::string libSuffix{".so"};
#endif

// Returns baseName surrounded by the platform's library prefix and suffix.
inline std::string getLibName(std::string baseName) {
  baseName.insert(0, libPrefix);
  baseName.append(libSuffix);
  return baseName;
}

//=============================================================================
// Context::Config functions
//=============================================================================

// Validates the "context" section of a dialog config.
//
// Throws Exception with GENIE_STATUS_ERROR_JSON_SCHEMA when the section is not
// an object, a mandatory key is missing, or an unknown key is present, and with
// GENIE_STATUS_ERROR_JSON_VALUE when "version" is not 1. Per-key type checks are
// delegated to the JSON_ENFORCE_* macros.
static void validateContextConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "context config is not an object");
  }

  const std::set<std::string> requiredKeys{"version", "bos-token", "eos-token", "size", "n-vocab"};
  for (const std::string& required : requiredKeys) {
    if (config.contains(required)) {
      continue;
    }
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing context field: " + required);
  }

  // component is used in the "ENFORCE" macros
  std::string component = "context";

  for (auto& item : config.items()) {
    const std::string& key = item.key();

    if (key == "version") {
      JSON_ENFORCE_NUMERIC();
      const bool supported = item.value().get<int>() == 1;
      if (!supported) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid context config: unsupported version: " + item.value().dump());
      }
      continue;
    }

    if (key == "eos-token") {
      // Unlike the other token ids, eos-token may also be given as an array.
      JSON_ENFORCE_ARRAY_OR_NUMERIC();
      continue;
    }

    const bool numericKey = key == "bos-token" || key == "eot-token" || key == "size" ||
                            key == "n-vocab" || key == "pad-token";
    if (numericKey) {
      JSON_ENFORCE_NUMERIC();
      continue;
    }

    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown context config key: " + key);
  }
}

// Copies every context key present in genieConfig["dialog"]["context"] into
// quallaConfig["context"] under the same name. Nothing is written when the
// dialog has no "context" section.
static void translateContextConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
  const qualla::json& dialog = genieConfig["dialog"];
  if (!dialog.contains("context")) {
    return;
  }

  const qualla::json& source = dialog["context"];
  const char* const forwardedKeys[] = {
      "bos-token", "eos-token", "eot-token", "size", "n-vocab", "pad-token"};
  for (const char* key : forwardedKeys) {
    if (source.contains(key)) {
      quallaConfig["context"][key] = source[key];
    }
  }
}

//=============================================================================
// Tokenizer::Config functions
//=============================================================================

// Validates the "tokenizer" section of a dialog config: it must be an object
// containing "version" (equal to 1) and "path", and no other keys. Violations
// are reported by throwing Exception.
static void validateTokenizerConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "tokenizer config is not an object");
  }

  const std::set<std::string> requiredKeys{"version", "path"};
  for (const std::string& required : requiredKeys) {
    if (config.contains(required)) {
      continue;
    }
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing tokenizer field: " + required);
  }

  // component is used in the "ENFORCE" macros
  std::string component = "tokenizer";

  for (auto& item : config.items()) {
    const std::string& key = item.key();

    if (key == "version") {
      JSON_ENFORCE_NUMERIC();
      const bool supported = item.value().get<int>() == 1;
      if (!supported) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid tokenizer config: unsupported version: " + item.value().dump());
      }
      continue;
    }

    if (key == "path") {
      JSON_ENFORCE_STRING();
      // Note: the existence of this file is checked by qualla
      continue;
    }

    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown tokenizer config key: " + key);
  }
}

// Stores the tokenizer path from the genie dialog config as the qualla
// "tokenizer" entry.
static void translateTokenizerConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
  const qualla::json& tokenizer = genieConfig["dialog"]["tokenizer"];
  quallaConfig["tokenizer"]     = tokenizer["path"];
}

//=============================================================================
// Embedding::Config functions
//=============================================================================

// Validates the "embedding" section of a dialog config.
//
// The section must be an object containing "version" (equal to 1) and "size";
// "datatype" is optional and, when present, must be "float32" or "native".
// Throws Exception with GENIE_STATUS_ERROR_JSON_SCHEMA for structural problems
// (not an object, missing or unknown key) and GENIE_STATUS_ERROR_JSON_VALUE for
// an unsupported version or datatype.
static void validateEmbeddingConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "embedding config is not an object");
  }

  const std::set<std::string> mandatoryFields{"version", "size"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing embedding field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "embedding";

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid embedding config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "size") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "datatype") {
      JSON_ENFORCE_STRING();
      // Extract the string once and use the set's own lookup instead of a
      // linear std::find that compares each entry against a json value.
      const std::string datatype = item.value().get<std::string>();
      const std::set<std::string> supportedTypes{"float32", "native"};
      if (supportedTypes.count(datatype) == 0) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, "Unknown embedding datatype: " + datatype);
      }
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "Unknown embedding config key: " + item.key());
    }
  }
}

// Maps the genie dialog "embedding" section onto the qualla context: "size"
// becomes "n-embd" and an optional "datatype" becomes "embedding-datatype".
// Nothing is written when the dialog has no "embedding" section.
static void translateEmbeddingConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
  const qualla::json& dialog = genieConfig["dialog"];
  if (!dialog.contains("embedding")) {
    return;
  }

  const qualla::json& embedding     = dialog["embedding"];
  quallaConfig["context"]["n-embd"] = embedding["size"];

  if (embedding.contains("datatype")) {
    quallaConfig["context"]["embedding-datatype"] = embedding["datatype"];
  }
}

// Global flags recording that the QnnHtp backend section supplied "pos-id-dim"
// or "rope-theta". validateBackendHtpConfig sets them and
// validatePositionalEncodingConfig reads them to reject a config that also
// provides a "positional-encoding" section.
bool position_dim_set{false};
bool rope_theta_set{false};

//=============================================================================
// Backend::Config functions
//=============================================================================

// Validates the "QnnHtp" backend section of an engine config.
//
// Throws Exception with GENIE_STATUS_ERROR_JSON_SCHEMA when the section is not
// an object, a mandatory key is missing, or an unknown key is present, and with
// GENIE_STATUS_ERROR_JSON_VALUE for an unsupported version (or use-mmap=true on
// Windows). Per-key type checks are delegated to the JSON_ENFORCE_* macros.
//
// Side effect: sets the globals position_dim_set / rope_theta_set when the
// "pos-id-dim" / "rope-theta" keys are encountered.
static void validateBackendHtpConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "QnnHtp config is not an object");
  }

  // "kv-dim" is mandatory as well: translateEngineConfig reads it from the
  // const config unconditionally, which is only valid when the key exists.
  const std::set<std::string> mandatoryFields{
      "version", "spill-fill-bufsize", "mmap-budget", "use-mmap", "cpu-mask", "poll", "kv-dim"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing QnnHtp field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "QnnHtp";

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid QnnHtp config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "spill-fill-bufsize") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "mmap-budget") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "use-mmap") {
      JSON_ENFORCE_BOOLEAN();
#ifdef _WIN32
      if (item.value() == true) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid QnnHtp config. use-mmap not supported on target");
      }
#endif
    } else if (item.key() == "pos-id-dim") {
      // Recorded so a conflicting "positional-encoding" section can be rejected.
      position_dim_set = true;
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "cpu-mask") {
      JSON_ENFORCE_STRING();
    } else if (item.key() == "poll") {
      JSON_ENFORCE_BOOLEAN();
    } else if (item.key() == "kv-dim") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "kv-update-method") {
      JSON_ENFORCE_STRING();
    } else if (item.key() == "allow-async-init") {
      JSON_ENFORCE_BOOLEAN();
    } else if (item.key() == "rope-theta") {
      // Recorded so a conflicting "positional-encoding" section can be rejected.
      rope_theta_set = true;
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "enable-graph-switching") {
      JSON_ENFORCE_BOOLEAN();
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown QnnHtp config key: " + item.key());
    }
  }
}

// Validates the "QnnGenAiTransformer" backend section of an engine config:
// an object with a mandatory "version" equal to 1 and a fixed set of optional
// keys. Violations are reported by throwing Exception.
static void validateBackendGenaiConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "QnnGenAiTransformer config is not an object");
  }

  const std::set<std::string> requiredKeys{"version"};
  for (const std::string& required : requiredKeys) {
    if (config.contains(required)) {
      continue;
    }
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "Missing QnnGenAiTransformer field: " + required);
  }

  // component is used in the "ENFORCE" macros
  std::string component = "QnnGenAiTransformer";

  for (auto& item : config.items()) {
    const std::string& key = item.key();

    if (key == "version") {
      JSON_ENFORCE_NUMERIC();
      const bool supported = item.value().get<int>() == 1;
      if (!supported) {
        throw Exception(
            GENIE_STATUS_ERROR_JSON_VALUE,
            "Invalid QnnGenAiTransformer config: unsupported version: " + item.value().dump());
      }
      continue;
    }

    if (key == "use-mmap") {
      JSON_ENFORCE_BOOLEAN();
#ifdef _WIN32
      if (item.value() == true) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid QnnGenAiTransformer config. use-mmap not supported on target");
      }
#endif
      continue;
    }

    const bool numericKey =
        key == "n-logits" || key == "n-layer" || key == "n-embd" || key == "n-heads";
    if (numericKey) {
      JSON_ENFORCE_NUMERIC();
      continue;
    }

    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "Unknown QnnGenAiTransformer config key: " + key);
  }
}

// Validates the "backend" section of an engine config.
//
// "version" (equal to 1) and "type" are mandatory; "type" must be "QnnHtp",
// "QnnGenAiTransformer" or "QnnGpu". A backend-specific sub-section is required
// for the QnnHtp / QnnGenAiTransformer types and rejected when it does not
// match the selected type. Violations are reported by throwing Exception.
static void validateBackendConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "backend config is not an object");
  }

  const std::set<std::string> requiredKeys{"version", "type"};
  for (const std::string& required : requiredKeys) {
    if (config.contains(required)) {
      continue;
    }
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing backend field: " + required);
  }

  // component is used in the "ENFORCE" macros
  std::string component = "backend";

  std::string type;
  qualla::json htpConfig;
  qualla::json genaiConfig;

  for (auto& item : config.items()) {
    const std::string& key = item.key();

    if (key == "version") {
      JSON_ENFORCE_NUMERIC();
      const bool supported = item.value().get<int>() == 1;
      if (!supported) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid backend config: unsupported version: " + item.value().dump());
      }
      continue;
    }

    if (key == "type") {
      JSON_ENFORCE_STRING();
      type = item.value().get<std::string>();
      const bool knownType = type == "QnnHtp" || type == "QnnGenAiTransformer" || type == "QnnGpu";
      if (!knownType) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid backend config: unsupported type: " + item.value().dump());
      }
      continue;
    }

    if (key == "extensions") {
      JSON_ENFORCE_STRING();
      continue;
    }

    if (key == "QnnHtp") {
      JSON_ENFORCE_OBJECT();
      htpConfig = item.value();
      continue;
    }

    if (key == "QnnGenAiTransformer") {
      JSON_ENFORCE_OBJECT();
      genaiConfig = item.value();
      continue;
    }

    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown backend config key: " + key);
  }

  // QnnHtp sub-section: required for the QnnHtp type, forbidden otherwise.
  const bool wantsHtp      = (type == "QnnHtp");
  const bool hasHtpSection = htpConfig.is_object();
  if (wantsHtp && !hasHtpSection) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing QnnHtp dialog config");
  }
  if (!wantsHtp && hasHtpSection) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "QnnHtp backend config for incorrect backend type: " + type);
  }
  if (wantsHtp) {
    validateBackendHtpConfig(htpConfig);
  }

  // QnnGenAiTransformer sub-section: same rule for the QnnGenAiTransformer type.
  const bool wantsGenai      = (type == "QnnGenAiTransformer");
  const bool hasGenaiSection = genaiConfig.is_object();
  if (wantsGenai && !hasGenaiSection) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing QnnGenAiTransformer dialog config");
  }
  if (!wantsGenai && hasGenaiSection) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "QnnGenAiTransformer backend config for incorrect backend type: " + type);
  }
  if (wantsGenai) {
    validateBackendGenaiConfig(genaiConfig);
  }
}

//=============================================================================
// Model::Config functions
//=============================================================================

// Validates one entry of the lora "adapters" array.
//
// The entry must be an object with "version" (equal to 1) and "name". A
// "bin-sections" array of strings marks it as a V2 adapter, a "path" string as
// V1 weights; the kind found must agree with specifiedLoraVersion and at least
// one of the two keys must be present. Violations are reported by throwing
// Exception.
static void validateLoraAdapterConfig(const qualla::json& config,
                                      LORA_VERSION& specifiedLoraVersion) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "lora adapter config is not an object");
  }

  const std::set<std::string> requiredKeys{"version", "name"};
  for (const std::string& required : requiredKeys) {
    if (config.contains(required)) {
      continue;
    }
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing lora adapter field: " + required);
  }

  // component is used in the "ENFORCE" macros
  const std::string component = "lora adapter";

  LORA_VERSION configuredLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_UNDEFINED;
  for (auto& item : config.items()) {
    const std::string& key = item.key();

    if (key == "version") {
      JSON_ENFORCE_NUMERIC();
      const bool supported = item.value().get<int>() == 1;
      if (!supported) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid lora config: unsupported version: " + item.value().dump());
      }
      continue;
    }

    if (key == "name") {
      JSON_ENFORCE_STRING();
      continue;
    }

    if (key == "bin-sections") {
      JSON_ENFORCE_ARRAY();
      configuredLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_V2;  // Adapter occurs with V2
      for (auto& section : item.value()) {
        if (section.is_string()) {
          continue;
        }
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, "bin-sections must be an array of strings");
      }
      continue;
    }

    if (key == "path") {
      configuredLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_V1;  // Weights are V1
      JSON_ENFORCE_STRING();
      // Note: all directory validations will be done by the NSP engine
      continue;
    }

    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown lora adapter config key: " + key);
  }

  // Cross-check what the adapter entry provides against the requested version.
  switch (configuredLoraVersion) {
    case LORA_VERSION::GENIE_LORA_VERSION_V2:
      if (specifiedLoraVersion == LORA_VERSION::GENIE_LORA_VERSION_V1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                        "LoRA Adapters must be used with lora version: 2");
      }
      break;
    case LORA_VERSION::GENIE_LORA_VERSION_V1:
      if (specifiedLoraVersion == LORA_VERSION::GENIE_LORA_VERSION_V2) {
        throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                        "LoRA Weights must be used with lora version: 1");
      }
      break;
    case LORA_VERSION::GENIE_LORA_VERSION_UNDEFINED:
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Invalid lora config.");
    default:
      break;
  }
}

// Validates a "lora" section of a model config.
//
// The section must be an object with "version" (equal to 1) and an "adapters"
// array; "alpha-tensor-name" and "lora-version" are optional. "lora-version"
// defaults to 2 and must be exactly 1 or 2. Every adapter entry is validated
// against the requested lora version. Violations are reported by throwing
// Exception.
static void validateLoraConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "lora config is not an object");
  }

  const std::set<std::string> mandatoryFields{"version", "adapters"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing lora field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  const std::string component = "lora";

  // The version has to be resolved before the loop because the adapters may be
  // visited before the "lora-version" key itself.
  LORA_VERSION specifiedLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_V2;  // Default is loraV2
  const auto loraVersionIt          = config.find("lora-version");
  if (loraVersionIt != config.end()) {
    // Anything other than an exact 1 or 2 is left UNDEFINED and rejected below.
    // The value is deliberately not narrowed to a small integer type: doing so
    // would let out-of-range values such as 257 or -255 wrap around to 1, and
    // converting a non-numeric value would raise a JSON library exception
    // instead of a genie Exception.
    specifiedLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_UNDEFINED;
    if (loraVersionIt->is_number()) {
      const double requestedVersion = loraVersionIt->get<double>();
      if (requestedVersion == 1.0) {
        specifiedLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_V1;
      } else if (requestedVersion == 2.0) {
        specifiedLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_V2;
      }
    }
  }

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid lora config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "alpha-tensor-name") {
      JSON_ENFORCE_STRING();
    } else if (item.key() == "adapters") {
      JSON_ENFORCE_ARRAY();
      for (auto& elem : item.value()) {
        validateLoraAdapterConfig(elem, specifiedLoraVersion);
      }
    } else if (item.key() == "lora-version") {  // Optional
      JSON_ENFORCE_NUMERIC();
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown lora config key: " + item.key());
    }
  }

  if (specifiedLoraVersion == LORA_VERSION::GENIE_LORA_VERSION_UNDEFINED) {
    // UNDEFINED is only reachable when the key exists, so the iterator is valid.
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "Unsupported lora version: " + loraVersionIt->dump());
  }
}

// Validates the "binary" section of a model config: an object with "version"
// (equal to 1), a "ctx-bins" array of strings and an optional "lora" section,
// which is checked by validateLoraConfig. Violations are reported by throwing
// Exception.
static void validateModelBinaryConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "binary config is not an object");
  }

  const std::set<std::string> requiredKeys{"version", "ctx-bins"};
  for (const std::string& required : requiredKeys) {
    if (config.contains(required)) {
      continue;
    }
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing binary field: " + required);
  }

  // component is used in the "ENFORCE" macros
  std::string component = "binary";

  for (auto& item : config.items()) {
    const std::string& key = item.key();

    if (key == "version") {
      JSON_ENFORCE_NUMERIC();
      const bool supported = item.value().get<int>() == 1;
      if (!supported) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid binary config: unsupported version: " + item.value().dump());
      }
      continue;
    }

    if (key == "ctx-bins") {
      JSON_ENFORCE_ARRAY();
      for (auto& bin : item.value()) {
        if (bin.is_string()) {
          continue;
        }
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, "ctx-bins must be an array of strings");
      }
      continue;
    }

    if (key == "lora") {
      JSON_ENFORCE_OBJECT();
      validateLoraConfig(item.value());
      continue;
    }

    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown binary config key: " + key);
  }
}

// Validates the "library" section of a model config: an object with "version"
// (equal to 1), a "model-bin" string and an optional "lora" section, which is
// checked by validateLoraConfig. Violations are reported by throwing Exception.
static void validateModelLibraryConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "library config is not an object");
  }

  const std::set<std::string> requiredKeys{"version", "model-bin"};
  for (const std::string& required : requiredKeys) {
    if (config.contains(required)) {
      continue;
    }
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing library field: " + required);
  }

  // component is used in the "ENFORCE" macros
  std::string component = "library";

  for (auto& item : config.items()) {
    const std::string& key = item.key();

    if (key == "version") {
      JSON_ENFORCE_NUMERIC();
      const bool supported = item.value().get<int>() == 1;
      if (!supported) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid library config: unsupported version: " + item.value().dump());
      }
      continue;
    }

    if (key == "model-bin") {
      JSON_ENFORCE_STRING();
      continue;
    }

    if (key == "lora") {
      JSON_ENFORCE_OBJECT();
      validateLoraConfig(item.value());
      continue;
    }

    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown library config key: " + key);
  }
}

// Validates a "rope-scaling" section.
//
// Accepted keys: "rope-type" (one of "llama3", "default", "longrope"), the
// numeric "factor", "low-freq-factor", "high-freq-factor" and
// "original-max-position-embeddings", and the array-valued "short-factor" and
// "long-factor". Any other key, or an unsupported rope type, throws Exception
// with GENIE_STATUS_ERROR_JSON_SCHEMA. A config that is not an object is
// ignored without error.
static void validateRopeScalingConfig(const qualla::json& config) {
  // component is used in the "ENFORCE" macros
  std::string component = "rope-scaling";
  if (config.is_object()) {
    std::string ropeType;
    for (auto& item : config.items()) {
      if (item.key() == "rope-type") {
        JSON_ENFORCE_STRING();
        ropeType = item.value().get<std::string>();
        if (ropeType != "llama3" && ropeType != "default" && ropeType != "longrope") {
          // A separator keeps the offending value readable in the message.
          throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Rope type not supported: " + ropeType);
        }
      } else if (item.key() == "factor" || item.key() == "low-freq-factor" ||
                 item.key() == "high-freq-factor" ||
                 item.key() == "original-max-position-embeddings") {
        JSON_ENFORCE_NUMERIC();
      } else if (item.key() == "short-factor" || item.key() == "long-factor") {
        JSON_ENFORCE_ARRAY();
      } else {
        throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                        "Rope scaling parameter not supported " + item.key());
      }
    }
  }
}

// Validates the "positional-encoding" section of a model config.
//
// Accepted keys: "type" ("rope", "absolute" or "alibi"), the numeric "rope-dim"
// and "rope-theta", and a "rope-scaling" object that is checked by
// validateRopeScalingConfig. The section is rejected when the globals
// position_dim_set / rope_theta_set are set, i.e. when the QnnHtp backend
// section already supplied "pos-id-dim" / "rope-theta". All violations throw
// Exception with GENIE_STATUS_ERROR_JSON_SCHEMA (type checks are delegated to
// the JSON_ENFORCE_* macros).
static void validatePositionalEncodingConfig(const qualla::json& config) {
  // component is used in the "ENFORCE" macros
  std::string component = "positional-encoding";
  qualla::json ropeScalingConfig;
  if (config.is_object()) {
    for (auto& item : config.items()) {
      if (item.key() == "type") {
        // Check the type first, as every other string key in this file does, so
        // a non-string value is reported through the ENFORCE macro rather than
        // by the JSON library's conversion error.
        JSON_ENFORCE_STRING();
        std::string positionEncodingType = item.value().get<std::string>();
        if (positionEncodingType != "rope" && positionEncodingType != "absolute" &&
            positionEncodingType != "alibi") {
          throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "positional-encoding type not supported");
        }
      } else if (item.key() == "rope-dim") {
        JSON_ENFORCE_NUMERIC();
      } else if (item.key() == "rope-theta") {
        JSON_ENFORCE_NUMERIC();
      } else if (item.key() == "rope-scaling") {
        JSON_ENFORCE_OBJECT();
        ropeScalingConfig = item.value();
      } else {
        throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                        "Unknown positional encoding config key: " + item.key());
      }
    }
  }
  if (position_dim_set) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "Specify one config from pos-id-dim and positional-encoding");
  }
  if (rope_theta_set) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "Specify one config from rope-theta and positional-encoding");
  }
  if (ropeScalingConfig.is_object()) {
    validateRopeScalingConfig(ropeScalingConfig);
  }
}

// Validates the "model" section of an engine config.
//
// "version" (equal to 1) and "type" ("binary" or "library") are mandatory. The
// sub-section named by "type" must be present and is validated by the matching
// helper; the sub-section of the other type must be absent. An optional
// "positional-encoding" object is validated by validatePositionalEncodingConfig.
// Violations are reported by throwing Exception.
static void validateModelConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "model config is not an object");
  }

  const std::set<std::string> mandatoryFields{"version", "type"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing model field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "model";

  std::string type;
  bool binary = false;
  qualla::json binaryConfig;
  bool library = false;
  qualla::json libraryConfig;
  qualla::json positionalEncodingConfig;
  bool positionalEncoding = false;

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid model config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "type") {
      JSON_ENFORCE_STRING();
      type = item.value().get<std::string>();
      if (type == "binary") {
        binary = true;
      } else if (type == "library") {
        library = true;
      } else {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid model config: unsupported type: " + item.value().dump());
      }
    } else if (item.key() == "binary") {
      JSON_ENFORCE_OBJECT();
      binaryConfig = item.value();
    } else if (item.key() == "library") {
      JSON_ENFORCE_OBJECT();
      libraryConfig = item.value();
    } else if (item.key() == "positional-encoding") {
      JSON_ENFORCE_OBJECT();
      positionalEncodingConfig = item.value();
      positionalEncoding       = true;
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown model config key: " + item.key());
    }
  }

  if (binary) {
    if (!binaryConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing binary model config");
    }
    validateModelBinaryConfig(binaryConfig);
  } else {
    if (binaryConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "binary model config for incorrect model type: " + type);
    }
  }

  if (library) {
    if (!libraryConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing library model config");
    }
    validateModelLibraryConfig(libraryConfig);
  } else {
    if (libraryConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "library model config for incorrect model type: " + type);
    }
  }

  // positional-encoding is optional and independent of the model type. When the
  // key was never seen, positionalEncodingConfig is still a default-constructed
  // (non-object) value, so there is nothing to reject in that case.
  if (positionalEncoding) {
    if (!positionalEncodingConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing Positional encoding config");
    }
    validatePositionalEncodingConfig(positionalEncodingConfig);
  }
}

//=============================================================================
// Engine::Config functions
//=============================================================================

// Validates a single engine config for the given dialog type.
//
// "version" (equal to 1), "backend", "model" and "n-threads" are mandatory;
// "role" is additionally mandatory for "kv-share" dialogs ("primary" or
// "secondary") and, when GENIE_SPD_FEATURE is defined, for "spd" dialogs
// ("draft" or "target"). The "backend" and "model" sub-sections are checked by
// validateBackendConfig / validateModelConfig. Violations are reported by
// throwing Exception.
static void validateEngineConfig(const qualla::json& config, std::string dialogType) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "engine config is not an object");
  }

  std::set<std::string> mandatoryFields{"version", "backend", "model", "n-threads"};
#if defined(GENIE_SPD_FEATURE)
  if (dialogType == "spd") {
    mandatoryFields.insert("role");
  }
#endif
  if (dialogType == "kv-share") {
    mandatoryFields.insert("role");
  }

  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing engine field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "engine";

  // The pos-id-dim / rope-theta flags are globals that the backend validation
  // sets and the model validation reads. Start every engine from a clean state
  // so that a previously validated engine (or dialog) cannot cause a spurious
  // "Specify one config ..." rejection here.
  // NOTE(review): the conflict check relies on "backend" being visited before
  // "model" in the loop below - confirm the iteration order of qualla::json.
  position_dim_set = false;
  rope_theta_set   = false;

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid engine config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "backend") {
      JSON_ENFORCE_OBJECT();
      validateBackendConfig(item.value());
    } else if (item.key() == "model") {
      JSON_ENFORCE_OBJECT();
      validateModelConfig(item.value());
    } else if (item.key() == "n-threads") {
      JSON_ENFORCE_NUMERIC();
#if defined(GENIE_SPD_FEATURE)
    } else if (item.key() == "role" && dialogType == "spd") {
      JSON_ENFORCE_STRING();
      if (item.value() != "draft" && item.value() != "target") {
        throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                        "Unknown value: for engine config key: " + item.key());
      }
#endif
    } else if (item.key() == "role" && dialogType == "kv-share") {
      JSON_ENFORCE_STRING();
      if (item.value() != "primary" && item.value() != "secondary") {
        throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                        "Unknown value: for engine config key: " + item.key());
      }
    } else {
      // Also reached for a "role" key on dialog types that do not use roles.
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown engine config key: " + item.key());
    }
  }
}

// Validates the "engine" entry of a dialog, which is either a single engine
// object or, for multi-engine dialog types, an array of exactly two engines.
//
// A single object is validated as one engine and then rejected for the
// "kv-share" (and, with GENIE_SPD_FEATURE, "spd") dialog types, which require
// an array. An array is accepted only for those dialog types and must contain
// one engine for each of the two roles. Everything else throws Exception with
// GENIE_STATUS_ERROR_JSON_SCHEMA.
static void validateMultiEngineConfig(const qualla::json& configs, std::string dialogType) {
  if (configs.is_object()) {
    validateEngineConfig(configs, dialogType);
#if defined(GENIE_SPD_FEATURE)
    if (dialogType == "spd") {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "engine config for spd is not an array");
    }
#endif
    if (dialogType == "kv-share") {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "engine config for kv-share is not an array");
    }
    return;
  }

  // Checks that `configs` holds exactly two valid engines, one per role.
  const auto validateEnginePair = [&configs, &dialogType](const char* firstRole,
                                                          const char* secondRole,
                                                          const char* wrongCountMessage,
                                                          const char* missingFirstMessage,
                                                          const char* missingSecondMessage) {
    if (configs.size() != 2) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, wrongCountMessage);
    }

    bool sawFirstRole  = false;
    bool sawSecondRole = false;
    for (auto& engine : configs) {
      validateEngineConfig(engine, dialogType);
      if (engine["role"] == firstRole) {
        sawFirstRole = true;
      } else if (engine["role"] == secondRole) {
        sawSecondRole = true;
      }
    }

    if (!sawFirstRole) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, missingFirstMessage);
    }
    if (!sawSecondRole) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, missingSecondMessage);
    }
  };

  const bool isArray = configs.is_array();

#if defined(GENIE_SPD_FEATURE)
  if (isArray && dialogType == "spd") {
    validateEnginePair("draft",
                       "target",
                       "engine config for spd contain invalid number of engines",
                       "engine config for spd does not contain draft engine",
                       "engine config for spd does not contain target engine");
    return;
  }
#endif

  if (isArray && dialogType == "kv-share") {
    validateEnginePair("primary",
                       "secondary",
                       "engine config for kv-share contain invalid number of engines",
                       "engine config for kv-share does not contain primary",
                       "engine config for kv-share does not contain secondary");
    return;
  }

  throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "engine config is not an object or an array");
}

// Translates a single Genie engine config into the key layout consumed by Qualla.
//
// Only version 1 configs are translated; for any other "version" value the
// destination is left untouched.
//
// genieEngineConfig  : Genie engine config (read only). Keys that are read without a
//                      contains() check are assumed to be present - presumably guaranteed
//                      by the validation pass that runs before translation.
// quallaEngineConfig : destination JSON that receives the translated keys.
static void translateEngineConfig(const qualla::json& genieEngineConfig,
                                  qualla::json& quallaEngineConfig) {
  if (!(genieEngineConfig["version"] == 1)) {
    return;
  }

  // Copies src[srcKey] into dst[dstKey] only when the optional source key is present.
  const auto copyIfPresent = [](const qualla::json& src,
                                const std::string& srcKey,
                                qualla::json& dst,
                                const std::string& dstKey) {
    if (src.contains(srcKey)) {
      dst[dstKey] = src[srcKey];
    }
  };

  // Engines without an explicit role are treated as the "primary" engine.
  if (genieEngineConfig.contains("role")) {
    quallaEngineConfig["role"] = genieEngineConfig["role"];
  } else {
    quallaEngineConfig["role"] = "primary";
  }

  quallaEngineConfig["n-threads"] = genieEngineConfig["n-threads"];

  //---------------------------------------------------------------------------
  // Backend section
  //---------------------------------------------------------------------------
  const qualla::json& backend = genieEngineConfig["backend"];
  if (backend["type"] == "QnnHtp") {
    quallaEngineConfig["type"]        = "qnn-htp";
    quallaEngineConfig["backend-lib"] = getLibName("QnnHtp");

    const qualla::json& htp = backend["QnnHtp"];
    quallaEngineConfig["mmap-budget"]        = htp["mmap-budget"];
    quallaEngineConfig["use-mmap"]           = htp["use-mmap"];
    quallaEngineConfig["spill-fill-bufsize"] = htp["spill-fill-bufsize"];
    copyIfPresent(htp, "pos-id-dim", quallaEngineConfig, "pos-id-dim");
    // Note the key rename: Genie "cpu-mask" becomes Qualla "cpumask".
    quallaEngineConfig["cpumask"] = htp["cpu-mask"];
    quallaEngineConfig["poll"]    = htp["poll"];
    quallaEngineConfig["kv-dim"]  = htp["kv-dim"];
    copyIfPresent(htp, "rope-theta", quallaEngineConfig, "rope-theta");
    copyIfPresent(htp, "kv-update-method", quallaEngineConfig, "kv-update-method");
    // By default, Qualla will default to the async init path.
    // For now, we are forcing async init off unless explicitly
    // specified in the Genie config. It is HTP specific feature only.
    quallaEngineConfig["use-async-Init"] = false;
    copyIfPresent(htp, "allow-async-init", quallaEngineConfig, "use-async-Init");
    copyIfPresent(htp, "enable-graph-switching", quallaEngineConfig, "enable-graph-switching");
  } else if (backend["type"] == "QnnGenAiTransformer") {
    quallaEngineConfig["type"]        = "qnn-cpu";
    quallaEngineConfig["backend-lib"] = getLibName("QnnGenAiTransformer");

    // Genie uses dashes, Qualla uses underscores for most of these keys.
    const qualla::json& cpu = backend["QnnGenAiTransformer"];
    copyIfPresent(cpu, "n-logits", quallaEngineConfig, "n_logits");
    copyIfPresent(cpu, "use-mmap", quallaEngineConfig, "use-mmap");
    copyIfPresent(cpu, "n-layer", quallaEngineConfig, "n_layer");
    copyIfPresent(cpu, "n-embd", quallaEngineConfig, "n_embd");
    copyIfPresent(cpu, "n-heads", quallaEngineConfig, "n_heads");
  } else if (backend["type"] == "QnnGpu") {
    quallaEngineConfig["type"] = "qnn-gpu";
  }

  copyIfPresent(backend, "extensions", quallaEngineConfig, "backend-ext-conf");

  //---------------------------------------------------------------------------
  // Model section
  //---------------------------------------------------------------------------
  const qualla::json& model = genieEngineConfig["model"];
  if (model["type"] == "binary") {
    const qualla::json& binary       = model["binary"];
    quallaEngineConfig["model-list"] = binary["ctx-bins"];
    if (binary.contains("lora")) {
      const qualla::json& lora = binary["lora"];
      // LoRA defaults to V2; an explicit "lora-version": 1 selects the V1 layout.
      const bool isLoraV1 = lora.contains("lora-version") && lora["lora-version"] == 1;
      quallaEngineConfig["lora-version"] =
          static_cast<uint8_t>(LORA_VERSION::GENIE_LORA_VERSION_V2);
      if (isLoraV1) {
        quallaEngineConfig["lora-version"] = lora["lora-version"];
      }

      const qualla::json& adapters = lora["adapters"];
      for (qualla::json::size_type i = 0; i < adapters.size(); ++i) {
        qualla::json& quallaAdapter       = quallaEngineConfig["lora"][i];
        quallaAdapter["adapter-name"]      = adapters[i]["name"];
        quallaAdapter["alpha-tensor-name"] = "";
        copyIfPresent(lora, "alpha-tensor-name", quallaAdapter, "alpha-tensor-name");
        quallaAdapter["alpha-tensor-value"] = 1.0f;
        quallaAdapter["binsection-basedir"] = "";
        // V1 adapters are addressed by "path", later versions by "bin-sections".
        if (isLoraV1) {
          quallaAdapter["path"] = adapters[i]["path"];
        } else {
          quallaAdapter["bin-sections"] = adapters[i]["bin-sections"];
        }
      }
    }
  } else if (model["type"] == "library") {
    const qualla::json& library          = model["library"];
    quallaEngineConfig["model"]          = getLibName("QnnGenAiTransformerModel");
    quallaEngineConfig["model-bin-path"] = library["model-bin"];
    quallaEngineConfig["op-package"] =
        getLibName("QnnGenAiTransformerCpuOpPkg") + ":QnnOpPackage_interfaceProvider";
    if (library.contains("lora")) {
      const qualla::json& lora     = library["lora"];
      const qualla::json& adapters = lora["adapters"];
      for (qualla::json::size_type i = 0; i < adapters.size(); ++i) {
        qualla::json& quallaAdapter  = quallaEngineConfig["lora"][i];
        quallaAdapter["adapter-name"] = adapters[i]["name"];
        // Unlike the binary path, no empty default is written for "alpha-tensor-name".
        copyIfPresent(lora, "alpha-tensor-name", quallaAdapter, "alpha-tensor-name");
        quallaAdapter["alpha-tensor-value"] = 1.0f;
        quallaAdapter["binsection-basedir"] = "";
        quallaAdapter["bin-sections"]       = adapters[i]["bin-sections"];
      }
    }
  }

  //---------------------------------------------------------------------------
  // Positional encoding section
  //---------------------------------------------------------------------------
  if (model.contains("positional-encoding")) {
    const qualla::json& posEnc                        = model["positional-encoding"];
    quallaEngineConfig["positional-encoding"]["type"] = posEnc["type"];
    if (posEnc["type"] == "rope") {
      quallaEngineConfig["positional-encoding"]["rope-dim"] = posEnc["rope-dim"];
      copyIfPresent(posEnc, "rope-theta", quallaEngineConfig["positional-encoding"], "rope-theta");

      if (posEnc.contains("rope-scaling") && posEnc["rope-scaling"].contains("rope-type")) {
        const qualla::json& scaling = posEnc["rope-scaling"];
        qualla::json& quallaScaling = quallaEngineConfig["positional-encoding"]["rope-scaling"];
        quallaScaling["rope-type"]  = scaling["rope-type"];

        // Each rope-type forwards only its own set of optional tuning keys.
        if (scaling["rope-type"] == "llama3") {
          for (const char* key : {"factor",
                                  "low-freq-factor",
                                  "high-freq-factor",
                                  "original-max-position-embeddings"}) {
            copyIfPresent(scaling, key, quallaScaling, key);
          }
        }
        if (scaling["rope-type"] == "longrope") {
          for (const char* key : {"factor",
                                  "short-factor",
                                  "long-factor",
                                  "original-max-position-embeddings"}) {
            copyIfPresent(scaling, key, quallaScaling, key);
          }
        }
      }
    }
  }
}

// Translates the "engine" entry of the Genie dialog config into quallaConfig["engine"].
// A single engine object is translated in place; an array of engines is translated
// element by element into a JSON array of the same length.
static void translateMultiEngineConfig(const qualla::json& genieConfig,
                                       qualla::json& quallaConfig) {
  const qualla::json& genieEngines = genieConfig["dialog"]["engine"];

  if (!genieEngines.is_array()) {
    translateEngineConfig(genieEngines, quallaConfig["engine"]);
    return;
  }

  quallaConfig["engine"] = qualla::json::array();
  for (auto& genieEngine : genieEngines) {
    qualla::json translated;
    translateEngineConfig(genieEngine, translated);
    quallaConfig["engine"].push_back(translated);
  }
}

//=============================================================================
// Dialog::Config functions
//=============================================================================

// Definition of the static handle manager used by Dialog::Config::add/get/remove.
qnn::util::HandleManager<Dialog::Config> Dialog::Config::s_manager;

// Registers a dialog config with the static handle manager and returns the resulting
// handle converted to the public GenieDialogConfig_Handle_t type.
GenieDialogConfig_Handle_t Dialog::Config::add(std::shared_ptr<Dialog::Config> config) {
  const auto managedHandle = s_manager.add(config);
  return (GenieDialogConfig_Handle_t)managedHandle;
}

// Looks up the dialog config registered under the given handle in the static handle manager.
std::shared_ptr<Dialog::Config> Dialog::Config::get(GenieDialogConfig_Handle_t handle) {
  const auto managedHandle = (qnn::util::Handle_t)handle;
  return s_manager.get(managedHandle);
}

// Removes the dialog config registered under the given handle from the static handle manager.
void Dialog::Config::remove(GenieDialogConfig_Handle_t handle) {
  const auto managedHandle = (qnn::util::Handle_t)handle;
  s_manager.remove(managedHandle);
}

#if defined(GENIE_SSD_FEATURE)
// Validates the "ssd-q1" section of a dialog config.
//
// Checks that the section is an object, that all mandatory fields are present, that every
// key is known and of the expected JSON type, and that the cross-field constraints hold:
//   - p-threshold > 0 is only allowed together with n-streams > 1
//   - the branches array may not have more entries than forecast-token-count
//
// Throws Exception (GENIE_STATUS_ERROR_JSON_SCHEMA or GENIE_STATUS_ERROR_JSON_VALUE) on the
// first violation found.
static void validateDialogSsdConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "ssd-q1 config is not an object");
  }

  // std::set iterates in sorted order, which determines which missing field is reported first.
  std::set<std::string> mandatoryFields{"version",
                                        "ssd-version",
                                        "forecast-token-count",
                                        "branches",
                                        "forecast-prefix",
                                        "forecast-prefix-name"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing ssd-q1 field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "ssd-q1";

  // Values collected during the scan for the cross-field checks below.
  int branchesSize       = 0;
  int forecastTokenCount = 0;

  // Defaults applied when the optional keys are absent.
  int nStreams     = 1;
  float pThreshold = 0.0;

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid ssd-q1 config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "ssd-version") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "forecast-token-count") {
      JSON_ENFORCE_NUMERIC();
      forecastTokenCount = item.value();
    } else if (item.key() == "branches") {
      JSON_ENFORCE_ARRAY();
      for (auto& elem : item.value()) {
        if (!elem.is_number_integer()) {
          throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, "branches must be an array of integers");
        }
      }
      branchesSize = static_cast<int>(item.value().size());
    } else if (item.key() == "forecast-prefix") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "forecast-prefix-name") {
      JSON_ENFORCE_STRING();
    } else if (item.key() == "n-streams") {
      JSON_ENFORCE_NUMERIC();
      nStreams = item.value();
    } else if (item.key() == "p-threshold") {
      JSON_ENFORCE_NUMERIC();
      pThreshold = item.value();
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown ssd-q1 config key: " + item.key());
    }
  }

  if ((pThreshold > 0.0) && (nStreams <= 1)) {
    throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                    "p-threshold can only be used with multistream (n-streams > 1)");
  }

  // A branches array exactly as long as forecast-token-count is accepted, so the message
  // says "must not exceed" rather than "must be less than".
  if (branchesSize > forecastTokenCount) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "Size of branches array must not exceed forecast-token-count");
  }
}
#endif

#if defined(GENIE_LADE_FEATURE)
// Validates the "lade" section of a dialog config: it must be an object, carry every
// mandatory field, contain only known keys of the expected JSON type, use version 1 and
// name one of the supported update modes. Throws Exception on the first violation.
static void validateDialogLadeConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "lade config is not an object");
  }

  // Sorted iteration of std::set decides which missing field is reported first.
  const std::set<std::string> requiredKeys{"version", "update-mode", "window", "ngram", "gcap"};
  for (const auto& requiredKey : requiredKeys) {
    if (!config.contains(requiredKey)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing lade field: " + requiredKey);
    }
  }

  // component is used in the "ENFORCE" macros (which presumably also read the loop
  // variable `item`, so both names are kept as-is)
  std::string component = "lade";

  for (auto& item : config.items()) {
    const auto& key = item.key();

    if (key == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid lade config: unsupported version: " + item.value().dump());
      }
      continue;
    }

    if (key == "update-mode") {
      JSON_ENFORCE_STRING();
      const std::string updateMode = item.value().get<std::string>();
      const bool isSupportedMode   = (updateMode == "FWD_MAX_HIT") ||
                                   (updateMode == "FWD_LEVEL") ||
                                   (updateMode == "ALWAYS_FWD_ONE");
      if (!isSupportedMode) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid lade config: unsupported update-mode: " + item.value().dump());
      }
      continue;
    }

    if ((key == "window") || (key == "ngram") || (key == "gcap")) {
      JSON_ENFORCE_NUMERIC();
      continue;
    }

    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown lade config key: " + key);
  }
}
#endif

#if defined(GENIE_SPD_FEATURE)
// Validates the "spd" section of a dialog config: it must be an object with the mandatory
// "version" (== 1) and "draft-len" numeric fields and no other keys. Throws Exception on
// the first violation.
static void validateDialogSpdConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "spd config is not an object");
  }

  // Sorted iteration of std::set decides which missing field is reported first.
  const std::set<std::string> requiredKeys{"version", "draft-len"};
  for (const auto& requiredKey : requiredKeys) {
    if (!config.contains(requiredKey)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing spd field: " + requiredKey);
    }
  }

  // component is used in the "ENFORCE" macros (which presumably also read the loop
  // variable `item`, so both names are kept as-is)
  std::string component = "spd";

  for (auto& item : config.items()) {
    const auto& key = item.key();

    if (key == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid spd config: unsupported version: " + item.value().dump());
      }
      continue;
    }

    if (key == "draft-len") {
      JSON_ENFORCE_NUMERIC();
      continue;
    }

    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown spd config key: " + key);
  }
}
#endif

#if defined(GENIE_MULTISTREAM_FEATURE)
// Validates the "multistream" section of a dialog config: it must be an object with the
// mandatory "version" (== 1) and "n-streams" numeric fields, an optional numeric
// "p-threshold", and no other keys. Throws Exception on the first violation.
static void validateDialogMultistreamConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "multistream config is not an object");
  }

  // Sorted iteration of std::set decides which missing field is reported first.
  const std::set<std::string> requiredKeys{"version", "n-streams"};
  for (const auto& requiredKey : requiredKeys) {
    if (!config.contains(requiredKey)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing multistream field: " + requiredKey);
    }
  }

  // component is used in the "ENFORCE" macros (which presumably also read the loop
  // variable `item`, so both names are kept as-is)
  std::string component = "multistream";

  for (auto& item : config.items()) {
    const auto& key = item.key();

    if (key == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid multistream config: unsupported version: " + item.value().dump());
      }
      continue;
    }

    if ((key == "n-streams") || (key == "p-threshold")) {
      JSON_ENFORCE_NUMERIC();
      continue;
    }

    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown multistream config key: " + key);
  }
}
#endif

// Validates the "dialog" section of a Genie dialog config.
//
// The section must be an object containing the mandatory fields "version", "type",
// "context", "tokenizer" and "engine". Every key is checked for its JSON type and, where a
// dedicated validator exists, handed to it. Feature-specific sub-configs (ssd-q1, lade,
// spd, multistream) are only accepted when the matching feature is compiled in, and each
// one must be present if and only if the dialog "type" selects that feature.
//
// Throws Exception (GENIE_STATUS_ERROR_JSON_SCHEMA or GENIE_STATUS_ERROR_JSON_VALUE) on the
// first violation found.
static void validateDialogConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Dialog config is not an object");
  }

  // std::set iterates in sorted order, which determines which missing field is reported first.
  std::set<std::string> mandatoryFields{"version", "type", "context", "tokenizer", "engine"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing dialog field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "dialog";

  // For each optional feature: whether "type" selected it, and a copy of its sub-config
  // (left as a null JSON value when the key is absent).
  std::string dialogType = "basic";
#if defined(GENIE_SSD_FEATURE)
  bool ssdq1 = false;
  qualla::json ssdq1Config;
#endif
#if defined(GENIE_LADE_FEATURE)
  bool lade = false;
  qualla::json ladeConfig;
#endif
#if defined(GENIE_SPD_FEATURE)
  bool spd = false;
  qualla::json spdConfig;
#endif
#if defined(GENIE_MULTISTREAM_FEATURE)
  bool multistream = false;
  qualla::json multistreamConfig;
#endif

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid dialog config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "type") {
      JSON_ENFORCE_STRING();
      dialogType = item.value();
      if (dialogType == "basic" || dialogType == "kv-share") {
        // Do nothing
#if defined(GENIE_SSD_FEATURE)
      } else if (dialogType == "ssd-q1") {
        ssdq1 = true;
#endif
#if defined(GENIE_LADE_FEATURE)
      } else if (dialogType == "lade") {
        lade = true;
#endif
#if defined(GENIE_SPD_FEATURE)
      } else if (dialogType == "spd") {
        spd = true;
#endif
#if defined(GENIE_MULTISTREAM_FEATURE)
      } else if (dialogType == "multistream") {
        multistream = true;
#endif
      } else {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, "Invalid dialog type: " + dialogType);
      }
#if defined(GENIE_SSD_FEATURE)
    } else if (item.key() == "ssd-q1") {
      JSON_ENFORCE_OBJECT();
      ssdq1Config = item.value();
      // ssd-q1 validation is done below
#endif
#if defined(GENIE_LADE_FEATURE)
    } else if (item.key() == "lade") {
      JSON_ENFORCE_OBJECT();
      ladeConfig = item.value();
      // lade validation is done below
#endif
#if defined(GENIE_SPD_FEATURE)
    } else if (item.key() == "spd") {
      JSON_ENFORCE_OBJECT();
      spdConfig = item.value();
      // spd validation is done below
#endif
#if defined(GENIE_MULTISTREAM_FEATURE)
    } else if (item.key() == "multistream") {
      JSON_ENFORCE_OBJECT();
      multistreamConfig = item.value();
      // multistream validation is done below
#endif
    } else if (item.key() == "stop-sequence") {
      JSON_ENFORCE_ARRAY();
      for (auto& elem : item.value()) {
        if (!elem.is_string()) {
          throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                          "stop-sequence must be an array of strings");
        }
      }
    } else if (item.key() == "max-num-tokens") {
      JSON_ENFORCE_NUMERIC();
      // Zero is accepted; only negative values are rejected, and the message says so.
      if (item.value().get<int>() < 0) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "number of tokens must be >= 0. provided: " + item.value().dump());
      }
    } else if (item.key() == "context") {
      JSON_ENFORCE_OBJECT();
      validateContextConfig(item.value());
    } else if (item.key() == "tokenizer") {
      JSON_ENFORCE_OBJECT();
      validateTokenizerConfig(item.value());
    } else if (item.key() == "sampler") {
      JSON_ENFORCE_OBJECT();
      Sampler::SamplerConfig::validateSamplerConfig(item.value());
    } else if (item.key() == "engine") {
      // Only the JSON type is checked here; the engine contents are validated after the loop.
      JSON_ENFORCE_ARRAY_OR_OBJECT();
    } else if (item.key() == "embedding") {
      JSON_ENFORCE_OBJECT();
      validateEmbeddingConfig(item.value());
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown dialog config key: " + item.key());
    }
  }

  // Engine verification requires dialogType for the engine roles. Since "type" is
  // encountered later than "engine" in the loop, engine validation is done after the loop.
  validateMultiEngineConfig(config["engine"], dialogType);

  // For every compiled-in feature: the sub-config is required when the dialog type selects
  // the feature and rejected when it does not.
#if defined(GENIE_SSD_FEATURE)
  if (ssdq1) {
    if (!ssdq1Config.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing ssd-q1 dialog config");
    }
    validateDialogSsdConfig(ssdq1Config);
  } else if (ssdq1Config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "ssd-q1 dialog config for incorrect dialog type: " + dialogType);
  }
#endif
#if defined(GENIE_LADE_FEATURE)
  if (lade) {
    if (!ladeConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing lade dialog config");
    }
    validateDialogLadeConfig(ladeConfig);
  } else if (ladeConfig.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "lade dialog config for incorrect dialog type: " + dialogType);
  }
#endif
#if defined(GENIE_SPD_FEATURE)
  if (spd) {
    if (!spdConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing spd dialog config");
    }
    validateDialogSpdConfig(spdConfig);
  } else if (spdConfig.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "spd dialog config for incorrect dialog type: " + dialogType);
  }
#endif
#if defined(GENIE_MULTISTREAM_FEATURE)
  if (multistream) {
    if (!multistreamConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing multistream dialog config");
    }
    validateDialogMultistreamConfig(multistreamConfig);
  } else if (multistreamConfig.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "multistream dialog config for incorrect dialog type: " + dialogType);
  }
#endif
}

// Translates the Genie dialog config into the Qualla dialog config.
//
// For a version 1 dialog the dialog type is mapped to its Qualla name and the settings of
// the selected (compiled-in) feature are copied to the top level of quallaConfig. The
// stop-sequence, context, tokenizer, sampler, engine and embedding parts are translated
// regardless of the dialog version.
static void translateDialogConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
  const qualla::json& dialog = genieConfig["dialog"];

  if (dialog["version"] == 1) {
    const qualla::json& dialogType = dialog["type"];

    // Genie and Qualla use different names for some dialog types.
    if (dialogType == "lade") {
      quallaConfig["type"] = "lhd-dec";
    } else if (dialogType == "spd") {
      quallaConfig["type"] = "spec-dec";
    } else if (dialogType == "multistream") {
      quallaConfig["type"] = "multistream";
    } else {
      quallaConfig["type"] = dialogType;
    }

#if defined(GENIE_SSD_FEATURE)
    if (dialogType == "ssd-q1") {
      const qualla::json& ssdq1 = dialog["ssd-q1"];

      quallaConfig["ssd-version"]          = ssdq1["ssd-version"];
      quallaConfig["forecast-token-count"] = ssdq1["forecast-token-count"];
      quallaConfig["branches"]             = ssdq1["branches"];
      quallaConfig["forecast-prefix"]      = ssdq1["forecast-prefix"];
      quallaConfig["forecast-prefix-name"] = ssdq1["forecast-prefix-name"];

      // Optional multistream settings of ssd-q1.
      if (ssdq1.contains("n-streams")) {
        quallaConfig["n-streams"] = ssdq1["n-streams"];
      }
      if (ssdq1.contains("p-threshold")) {
        quallaConfig["p-threshold"] = ssdq1["p-threshold"];
      }
    }
#endif
#if defined(GENIE_LADE_FEATURE)
    if (dialogType == "lade") {
      const qualla::json& ladeSection = dialog["lade"];

      quallaConfig["lhd-update-mode"] = ladeSection["update-mode"];
      quallaConfig["window"]          = ladeSection["window"];
      quallaConfig["ngram"]           = ladeSection["ngram"];
      quallaConfig["gcap"]            = ladeSection["gcap"];
    }
#endif
#if defined(GENIE_SPD_FEATURE)
    if (dialogType == "spd") {
      quallaConfig["draft-len"] = dialog["spd"]["draft-len"];
    }
#endif
#if defined(GENIE_MULTISTREAM_FEATURE)
    if (dialogType == "multistream") {
      const qualla::json& multistreamSection = dialog["multistream"];

      quallaConfig["n-streams"] = multistreamSection["n-streams"];
      if (multistreamSection.contains("p-threshold")) {
        quallaConfig["p-threshold"] = multistreamSection["p-threshold"];
      }
    }
#endif
  }

  // Qualla expects the stop sequences nested under "prompt".
  if (dialog.contains("stop-sequence")) {
    quallaConfig["prompt"]["stop-sequence"] = dialog["stop-sequence"];
  }

  translateContextConfig(genieConfig, quallaConfig);
  translateTokenizerConfig(genieConfig, quallaConfig);
  Sampler::SamplerConfig::translateSamplerConfig(genieConfig, quallaConfig);
  translateMultiEngineConfig(genieConfig, quallaConfig);
  translateEmbeddingConfig(genieConfig, quallaConfig);
}

// Returns the "max-num-tokens" value of a version 1 dialog config, or UINT32_MAX when the
// dialog is not version 1 or does not set the key.
uint32_t getMaxNumTokens(const qualla::json& genieConfig) {
  const qualla::json& dialog = genieConfig["dialog"];

  const bool hasExplicitLimit = (dialog["version"] == 1) && dialog.contains("max-num-tokens");
  if (!hasExplicitLimit) {
    return UINT32_MAX;
  }
  return dialog["max-num-tokens"].get<uint32_t>();
}

// Parses and validates a Genie dialog config given as JSON text.
//
// The top level must be an object whose only key is "dialog"; a key repeated at depth 1
// is rejected while parsing. The "dialog" value is checked by validateDialogConfig().
// On success the parsed JSON is stored in m_config; any violation throws Exception.
Dialog::Config::Config(const char* configStr) {
  rope_theta_set   = false;
  position_dim_set = false;

  // Keys already seen at depth 1; used by the parser callback to detect duplicates.
  std::set<qualla::json> seenKeys;
  auto rejectDuplicateKeys =
      [&seenKeys](int depth, qualla::json::parse_event_t event, qualla::json& parsed) {
        const bool isDepthOneKey = (depth == 1) && (event == qualla::json::parse_event_t::key);
        // insert() reports through .second whether the key was new.
        if (isDepthOneKey && !seenKeys.insert(parsed).second) {
          throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                          "Multiple dialog config key: " + parsed.dump());
        }
        return true;
      };

  qualla::json config = qualla::json::parse(configStr, rejectDuplicateKeys);

  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Dialog config is not an object");
  }

  const std::set<std::string> requiredKeys{"dialog"};
  for (const auto& requiredKey : requiredKeys) {
    if (!config.contains(requiredKey)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing dialog field: " + requiredKey);
    }
  }

  // component is used in the "ENFORCE" macros (which presumably also read the loop
  // variable `item`, so both names are kept as-is)
  std::string component = "dialog";

  for (auto& item : config.items()) {
    if (item.key() != "dialog") {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown dialog config key: " + item.key());
    }
    JSON_ENFORCE_OBJECT();
    validateDialogConfig(item.value());
  }

  m_config = config;
}

// Gives mutable access to the config JSON stored in this object.
qualla::json& Dialog::Config::getJson() {
  return this->m_config;
}

//=============================================================================
// Dialog functions
//=============================================================================

// Definition of the static handle manager used by Dialog::add/get/remove.
qnn::util::HandleManager<Dialog> Dialog::s_manager;
// Counter consumed by the Dialog constructor to build the "dialog<N>" name passed to
// qualla::Dialog::create(); starts at 0 and is advanced with fetch_add.
std::atomic<std::uint32_t> Dialog::s_nameCounter{0u};

// Registers a dialog with the static handle manager and returns the resulting handle
// converted to the public GenieDialog_Handle_t type.
GenieDialog_Handle_t Dialog::add(std::shared_ptr<Dialog> dialog) {
  const auto managedHandle = s_manager.add(dialog);
  return (GenieDialog_Handle_t)managedHandle;
}

// Looks up the dialog registered under the given handle in the static handle manager.
std::shared_ptr<Dialog> Dialog::get(GenieDialog_Handle_t handle) {
  const auto managedHandle = (qnn::util::Handle_t)handle;
  return s_manager.get(managedHandle);
}

// Drops the handle manager's entry for the given handle.
void Dialog::remove(GenieDialog_Handle_t handle) {
  const auto managerHandle = (qnn::util::Handle_t)handle;
  s_manager.remove(managerHandle);
}

// Creates the underlying qualla dialog from a validated Genie configuration
// and registers a Sampler that wraps the dialog's qualla sampler(s).
//
// Throws Exception(GENIE_STATUS_ERROR_MEM_ALLOC, ...) when
// qualla::Dialog::create() returns a null dialog.
Dialog::Dialog(std::shared_ptr<Config> config) {
  auto env              = qualla::Env::create(qualla::json{});
  qualla::json& genieJson = config->getJson();

  // Translate the Genie configuration into the form qualla expects.
  qualla::json quallaConfig;
  translateDialogConfig(genieJson, quallaConfig);
  m_tokenLimit = getMaxNumTokens(genieJson);

  // Every instance gets a distinct name: "dialog" followed by the counter value.
  m_quallaDialog = qualla::Dialog::create(
      env, "dialog" + std::to_string(s_nameCounter.fetch_add(1u)), quallaConfig);
  if (!m_quallaDialog) {
    throw Exception(GENIE_STATUS_ERROR_MEM_ALLOC, "Could not create a dialog object");
  }

  /*
   * Collect references to the qualla samplers that the Dialog Sampler updates
   * through applyConfig(). spec-dec has a mandatory "target" sampler and an
   * optional "draft" sampler; every other dialog type has a single sampler.
   */
  std::vector<std::reference_wrapper<qualla::Sampler>> quallaSamplers;
  const bool isSpecDec = (quallaConfig["type"] == "spec-dec");
  if (isSpecDec) {
    quallaSamplers.push_back(m_quallaDialog->sampler("target"));
    if (m_quallaDialog->isSamplerPresent("draft")) {
      quallaSamplers.push_back(m_quallaDialog->sampler("draft"));
    }
  } else {
    quallaSamplers.push_back(m_quallaDialog->sampler());  // Default role is "primary"
  }

  std::shared_ptr<Sampler> sampler =
      std::make_shared<Sampler>(genieJson["dialog"], quallaSamplers);
  m_samplerHandle = Sampler::add(sampler);
}

// Returns the sampler handle stored in the given dialog's m_samplerHandle.
GenieSampler_Handle_t Dialog::getSamplerHandle(std::shared_ptr<Dialog> dialog) {
  const Dialog& owner = *dialog;
  return owner.m_samplerHandle;
}

// The query functions in this file convert between GenieDialog_SentenceCode_t
// and qualla::Sentence::Code with static_cast, so each pair of enumerators
// must compare equal after conversion.
static_assert(static_cast<qualla::Sentence::Code>(GENIE_DIALOG_SENTENCE_COMPLETE) ==
                  qualla::Sentence::Code::COMPLETE,
              "GENIE_DIALOG_SENTENCE_COMPLETE must match qualla::Sentence::Code::COMPLETE");
static_assert(static_cast<qualla::Sentence::Code>(GENIE_DIALOG_SENTENCE_BEGIN) ==
                  qualla::Sentence::Code::BEGIN,
              "GENIE_DIALOG_SENTENCE_BEGIN must match qualla::Sentence::Code::BEGIN");
static_assert(static_cast<qualla::Sentence::Code>(GENIE_DIALOG_SENTENCE_CONTINUE) ==
                  qualla::Sentence::Code::CONTINUE,
              "GENIE_DIALOG_SENTENCE_CONTINUE must match qualla::Sentence::Code::CONTINUE");
static_assert(static_cast<qualla::Sentence::Code>(GENIE_DIALOG_SENTENCE_END) ==
                  qualla::Sentence::Code::END,
              "GENIE_DIALOG_SENTENCE_END must match qualla::Sentence::Code::END");
static_assert(static_cast<qualla::Sentence::Code>(GENIE_DIALOG_SENTENCE_ABORT) ==
                  qualla::Sentence::Code::ABORT,
              "GENIE_DIALOG_SENTENCE_ABORT must match qualla::Sentence::Code::ABORT");

// Runs a text query against the qualla dialog and streams each response piece
// to `callback` together with `userData`.
//
// Generation is stopped once m_tokenLimit callback invocations have been made.
// If the piece that reached the limit was a BEGIN or CONTINUE piece, one extra
// empty GENIE_DIALOG_SENTENCE_END callback is issued.
// KPI figures reported by the qualla dialog are printed to stdout afterwards.
//
// Returns GENIE_STATUS_SUCCESS when the qualla query reports success and
// GENIE_STATUS_ERROR_QUERY_FAILED otherwise.
int32_t Dialog::query(const char* queryStr,
                      GenieDialog_SentenceCode_t sentenceCode,
                      GenieDialog_QueryCallback_t callback,
                      const void* userData) {
  std::string prompt(queryStr);
  uint32_t deliveredCount = 0u;

  auto onResponse = [&](const std::string& response, qualla::Sentence::Code code) {
    callback(response.c_str(), static_cast<GenieDialog_SentenceCode_t>(code), userData);
    ++deliveredCount;
    if (deliveredCount < m_tokenLimit) {
      return true;
    }
    // Limit reached: terminate a sentence that is still in progress.
    const bool sentenceStillOpen =
        (code == qualla::Sentence::Code::BEGIN) || (code == qualla::Sentence::Code::CONTINUE);
    if (sentenceStillOpen) {
      callback("", GENIE_DIALOG_SENTENCE_END, userData);
    }
    return false;
  };

  bool status = m_quallaDialog->query(
      prompt, static_cast<qualla::Sentence::Code>(sentenceCode), onResponse);

  qualla::Dialog::KPIs kpis = m_quallaDialog->kpis();
  printf(
      "\n\n[KPIS]:\nInit Time: %zu us\nPrompt Processing Time: %zu us, Prompt Processing Rate : "
      "%f toks/sec\n"
      "Token Generation Time: %zu us, Token Generation Rate: %f toks/sec\n",
      kpis.init.total_usec,
      kpis.prompt.last_usec,
      kpis.tps.prompt,
      kpis.generate.last_usec,
      kpis.tps.generate);

  if (!status) {
    return GENIE_STATUS_ERROR_QUERY_FAILED;
  }
  return GENIE_STATUS_SUCCESS;
}

// Asks the qualla dialog to save its state under `name`.
// Returns GENIE_STATUS_SUCCESS on success, GENIE_STATUS_ERROR_QUERY_FAILED otherwise.
int32_t Dialog::save(const std::string& name) {
  if (m_quallaDialog->save(name)) {
    return GENIE_STATUS_SUCCESS;
  }
  return GENIE_STATUS_ERROR_QUERY_FAILED;
}

// Asks the qualla dialog to restore the state stored under `name`.
// Returns GENIE_STATUS_SUCCESS on success, GENIE_STATUS_ERROR_QUERY_FAILED otherwise.
int32_t Dialog::restore(const std::string& name) {
  if (m_quallaDialog->restore(name)) {
    return GENIE_STATUS_SUCCESS;
  }
  return GENIE_STATUS_ERROR_QUERY_FAILED;
}

#if defined(GENIE_E2T_FEATURE)
// Runs an embedding query against the qualla dialog.
//
// embeddings / embeddingsSize  Raw embedding bytes; the size must be a whole
//                              multiple of the embedding vector size reported
//                              by the qualla dialog.
// sentenceCode                 Converted to qualla::Sentence::Code.
// t2eCallback                  Optional token-to-embedding callback; when null
//                              an empty qualla callback is passed instead.
// callback / userData          Receives every response piece.
//
// Generation is stopped once m_tokenLimit callback invocations have been made.
// If the piece that reached the limit was a BEGIN or CONTINUE piece, one extra
// empty GENIE_DIALOG_SENTENCE_END callback is issued.
// KPI figures reported by the qualla dialog are printed to stdout afterwards.
//
// Returns GENIE_STATUS_SUCCESS when the qualla query reports success and
// GENIE_STATUS_ERROR_QUERY_FAILED otherwise.
// Throws std::runtime_error when the reported embedding vector size is zero or
// when embeddingsSize is not a multiple of it.
int32_t Dialog::embeddingQuery(const void* embeddings,
                               const uint32_t embeddingsSize,
                               GenieDialog_SentenceCode_t sentenceCode,
                               GenieDialog_TokenToEmbeddingCallback_t t2eCallback,
                               GenieDialog_QueryCallback_t callback,
                               const void* userData) {
  uint32_t genTokenCount = 0u;

  // Guard the modulo below: a zero divisor is undefined behaviour.
  const auto embeddingVectorBytes = m_quallaDialog->getEmbeddingBufferSize();
  if (embeddingVectorBytes == 0) {
    throw std::runtime_error("The embedding vector size reported by the dialog is zero.");
  }
  if (embeddingsSize % embeddingVectorBytes != 0) {
    throw std::runtime_error(
        "The embeddings buffer size must be an integer multiple of the embedding vector size in "
        "bytes.");
  }

  // Copy the caller's bytes into a vector owned by this call.
  const uint8_t* embeddingsSrc = static_cast<const uint8_t*>(embeddings);
  std::vector<uint8_t> embeddingVector(embeddingsSrc, embeddingsSrc + embeddingsSize);

  // Adapt the C callback (which takes userData) to the qualla callback type.
  qualla::Dialog::T2ECallback t2eQuallaCallback{nullptr};
  if (t2eCallback) {
    t2eQuallaCallback = [&](const int32_t token, void* embedding, const uint32_t embd_size) {
      t2eCallback(token, embedding, embd_size, userData);
    };
  }

  bool status = m_quallaDialog->query(
      embeddingVector,
      static_cast<qualla::Sentence::Code>(sentenceCode),
      t2eQuallaCallback,
      [&](const std::string& response, qualla::Sentence::Code code) {
        callback(response.c_str(), static_cast<GenieDialog_SentenceCode_t>(code), userData);
        bool keepGoing = ++genTokenCount < m_tokenLimit;
        // Limit reached mid-sentence: close the sentence for the caller.
        if (!keepGoing && ((code == qualla::Sentence::Code::BEGIN) ||
                           (code == qualla::Sentence::Code::CONTINUE))) {
          callback("", GENIE_DIALOG_SENTENCE_END, userData);
        }
        return keepGoing;
      });
  qualla::Dialog::KPIs kpis = m_quallaDialog->kpis();
  printf(
      "\n\n[KPIS]:\nInit Time: %zu us\nPrompt Processing Time: %zu us, Prompt Processing Rate : "
      "%f toks/sec\n"
      "Token Generation Time: %zu us, Token Generation Rate: %f toks/sec\n",
      kpis.init.total_usec,
      kpis.prompt.last_usec,
      kpis.tps.prompt,
      kpis.generate.last_usec,
      kpis.tps.generate);
  return (status) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_QUERY_FAILED);
}
#endif

// Forwards the reset request to the underlying qualla dialog.
void Dialog::reset() {
  m_quallaDialog->reset();
}

#if defined(GENIE_LORA_FEATURE)

// Applies the named LoRA adapter on the given engine through the qualla dialog.
// Returns GENIE_STATUS_SUCCESS on success, GENIE_STATUS_ERROR_GENERAL otherwise.
int32_t Dialog::applyLora(std::string loraAdapterName, std::string engine) {
  if (!m_quallaDialog->applyLoraAdapter(loraAdapterName, engine)) {
    return GENIE_STATUS_ERROR_GENERAL;
  }
  return GENIE_STATUS_SUCCESS;
}

// Sets the LoRA strength `alpha` for `tensorName` on the given engine through
// the qualla dialog. Note that the qualla call takes (tensorName, alpha, engine).
// Returns GENIE_STATUS_SUCCESS on success, GENIE_STATUS_ERROR_GENERAL otherwise.
int32_t Dialog::applyLoraStrength(std::string tensorName, std::string engine, float alpha) {
  if (!m_quallaDialog->applyLoraStrength(tensorName, alpha, engine)) {
    return GENIE_STATUS_ERROR_GENERAL;
  }
  return GENIE_STATUS_SUCCESS;
}

#endif

int32_t Dialog::tokenQuery(const uint32_t* tokens,
                           const uint32_t sizeInputTokens,
                           GenieDialog_SentenceCode_t sentenceCode,
                           GenieDialog_TokenQueryCallback_t callback,
                           const void* userData) {
  std::vector<uint32_t> inputTokens;
  for (size_t i = 0; i < sizeInputTokens; i++) {
    inputTokens.push_back(tokens[i]);
  }
  uint32_t genTokenCount = 0u;
  dialogCallback.setCallBackType(qualla::QUALLA_CALLBACK_TYPE_TOKEN);
  dialogCallback.getTokenCbFunc() = std::make_shared<
      std::function<bool(const int32_t*, const uint32_t, qualla::Sentence::Code)>>();
  *(dialogCallback.getTokenCbFunc()) = [&](const int32_t* responseTokens,
                                           const uint32_t sizeResponseTokens,
                                           qualla::Sentence::Code code) {
    callback((const uint32_t*)responseTokens,
             sizeResponseTokens,
             static_cast<GenieDialog_SentenceCode_t>(code),
             userData);
    bool keepGoing = ++genTokenCount < m_tokenLimit;
    if (!keepGoing &&
        ((code == qualla::Sentence::Code::BEGIN) || (code == qualla::Sentence::Code::CONTINUE))) {
      callback(nullptr, 0, GENIE_DIALOG_SENTENCE_END, userData);
    }
    return keepGoing;
  };
  bool status               = m_quallaDialog->query((const std::vector<uint32_t>)inputTokens,
                                      static_cast<qualla::Sentence::Code>(sentenceCode),
                                      dialogCallback);
  qualla::Dialog::KPIs kpis = m_quallaDialog->kpis();
  printf(
      "\n\n[KPIS]:\nInit Time: %zu us\nPrompt Processing Time: %zu us, Prompt Processing Rate : "
      "%f toks/sec\n"
      "Token Generation Time: %zu us, Token Generation Rate: %f toks/sec\n",
      kpis.init.total_usec,
      kpis.prompt.last_usec,
      kpis.tps.prompt,
      kpis.generate.last_usec,
      kpis.tps.generate);
  return (status) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_QUERY_FAILED);
}

// Unregisters the sampler handle that the constructor obtained from Sampler::add().
Dialog::~Dialog() {
  Sampler::remove(m_samplerHandle);
}
