//=============================================================================
//
//  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
//  All Rights Reserved.
//  Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//=============================================================================

#include <algorithm>
#include <chrono>
#include <cmath>
#include <exception>
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "GenieCommon.h"
#include "GenieDialog.h"
#include "GenieSampler.h"

// ---------------------------------------------------------------------------
// Command-line state: written by parseCommandLineInput(), read by main().
// ---------------------------------------------------------------------------

// Contents of the dialog JSON configuration file given with --config.
std::string config;
// Prompt text from --prompt, or the contents of the --prompt_file file.
std::string prompt;
// Path passed to --save; left empty when the dialog state should not be saved.
std::string savePath;
// Path passed to --restore; left empty when no state should be restored.
std::string restorePath;

#if defined(GENIE_LORA_FEATURE)
// Parsed from --lora ADAPTER_NAME[,ALPHA_NAME,ALPHA_VALUE].
std::string loraAdapterName;
std::string loraAlphaName;
float loraAlphaValue = 1.0f;
#endif

#if defined(GENIE_E2T_FEATURE)
// --embedding_file: raw input embeddings and their optional quantization encodings.
std::shared_ptr<void> embeddingBuffer{};
size_t embeddingBufferSize = 0;
std::string inputDataType  = "N/A";
double inputScale          = 1.0;
int32_t inputOffset        = 0;

// --embedding_table: token-to-embedding lookup table and its optional encodings.
std::shared_ptr<void> embeddingLut{};
size_t embeddingLutSize = 0;
std::string lutDataType = "N/A";
double lutScale         = 1.0;
int32_t lutOffset       = 0;

// Affine LUT-to-input mapping, filled in by calculateRequantEncodings().
double requantScale  = 1.0;
double requantOffset = 0.0;
#endif

// Token IDs read from --tokens_file.
std::vector<uint32_t> tokens;

// Option words encountered on the command line (option values are not recorded).
std::unordered_set<std::string> commandLineArguments;
// Option name -> (set, required) flags; see addOption()/isSet()/isRequired().
std::unordered_map<std::string, std::pair<bool, bool>> m_options;

bool isSet(const std::string& name) {
  // Options that were never registered count as "not set".
  const auto entry = m_options.find(name);
  if (entry == m_options.end()) {
    return false;
  }
  // pair::first carries the "set" flag.
  return entry->second.first;
}

bool isRequired(const std::string& name) {
  // Options that were never registered count as "not required".
  const auto entry = m_options.find(name);
  if (entry == m_options.end()) {
    return false;
  }
  // pair::second carries the "required" flag.
  return entry->second.second;
}

void addOption(const std::string& name, bool set, bool required) {
  // Flags are stored as (set, required). emplace() leaves an existing entry
  // untouched, so the first registration of `name` wins.
  const std::pair<bool, bool> flags{set, required};
  m_options.emplace(name, flags);
}

std::streamsize getFileSize(const std::string& filename) {
  // Opening at end-of-file ("ate") makes the current read position equal to the
  // byte count. tellg() reports -1 when the file could not be opened.
  const std::ios::openmode mode = std::ios::in | std::ios::binary | std::ios::ate;
  std::ifstream stream{filename, mode};
  return stream.tellg();
}

// Reports whether `fileStream` (already opened on `fileName`) is usable.
// Returns true when the stream is in a good state. Otherwise it prints an error
// naming `fileName` to std::cerr, like the program's other diagnostics, and
// returns false.
bool checkFileExistsAndReadable(const std::ifstream& fileStream, const std::string& fileName) {
  const bool readable = fileStream.good();
  if (!readable) {
    std::cerr << "ERROR: File " << fileName << " doesn't exist or is not readable." << std::endl;
  }
  return readable;
}

// Prints the command-line help text to std::cout.
// Each option is printed as a flag column padded to `width` characters,
// followed by its description. Continuation lines pad an empty string to the
// same width so their text lines up with the description column.
// The --lora section is only compiled in when GENIE_LORA_FEATURE is defined.
// The --embedding_file / --embedding_table sections are only compiled in when
// GENIE_E2T_FEATURE is defined.
// NOTE: std::left is applied to std::cout and never reset here, so any later
// setw()-padded output on std::cout is left-aligned as well.
void printUsage(const char* program) {
  std::cout << "Usage:\n" << program << " [options]\n" << std::endl;
  std::cout << "Options:" << std::endl;

  // Width of the flag/argument column.
  int width = 88;

  // std::left is sticky: every following setw() in this function pads on the right.
  std::cout << std::left << std::setw(width) << "  -h, --help";
  std::cout << "Show this help message and exit.\n" << std::endl;

  std::cout << std::setw(width) << "  -c CONFIG_FILE or --config CONFIG_FILE";
  std::cout << "Dialog JSON configuration file.\n" << std::endl;

  std::cout << std::setw(width) << "  -p PROMPT or --prompt PROMPT";
  std::cout << "Prompt to query. Mutually exclusive with --prompt_file.\n" << std::endl;

  std::cout << std::setw(width) << "  --prompt_file PATH";
  std::cout << "Prompt to query provided as a file. Mutually exclusive with --prompt." << std::endl;

#if defined(GENIE_LORA_FEATURE)
  std::cout << std::endl;
  std::cout
      << std::setw(width)
      << "  -l ADAPTER_NAME,ALPHA_NAME,ALPHA_VALUE or --lora ADAPTER_NAME,ALPHA_NAME,ALPHA_VALUE";
  std::cout << "Apply a LoRA adapter to a dialog." << std::endl;
  // Continuation line: an empty, padded flag column keeps the text aligned.
  std::cout
      << std::setw(width) << ""
      << "ALPHA_NAME and ALPHA_VALUE are optional parameters, only for setting alpha strength."
      << std::endl;
#endif

#if defined(GENIE_E2T_FEATURE)
  std::cout << std::endl;
  std::cout << std::setw(width) << "  -e PATH or --embedding_file PATH[,TYPE,SCALE,OFFSET]";
  std::cout << "Input embeddings provided as a file. Mutually exclusive with --prompt, "
               "--prompt_file and --tokens_file."
            << std::endl;
  std::cout << std::setw(width) << ""
            << "TYPE, SCALE, and OFFSET are optional parameters representing the model's input "
               "quantization encodings. Required for lookup table requantization."
            << std::endl;
  std::cout << std::setw(width) << ""
            << "Valid values of TYPE are int8, int16, uint8, uint16. The signedness must be "
               "consistent with the lookup table encodings."
            << std::endl;
  std::cout << std::endl;
  std::cout << std::setw(width) << "  -t PATH or --embedding_table PATH[,TYPE,SCALE,OFFSET]";
  std::cout << "Token-to-Embedding lookup table provided as a file. Mutually exclusive with "
               "--prompt and --prompt_file."
            << std::endl;
  std::cout << std::setw(width) << ""
            << "TYPE, SCALE, and OFFSET are optional parameters representing the lookup table's "
               "quantization encodings. Required for lookup table requantization."
            << std::endl;
  std::cout << std::setw(width) << ""
            << "Valid values of TYPE are int8, int16, uint8, uint16. The signedness must be "
               "consistent with the input layer encodings."
            << std::endl;
#endif
  std::cout << std::endl;
  std::cout << std::setw(width) << "  -tok PATH or --tokens_file PATH";
  std::cout << "Input tokens provided as a file. Mutually exclusive with --prompt, --prompt_file "
               "and --embedding_file."
            << std::endl;
  std::cout << std::endl;
  std::cout << std::setw(width) << "  -s PATH or --save PATH";
  std::cout << "Saves the dialog state after the dialog is queried. PATH must be an existing path."
            << std::endl;
  std::cout << std::endl;
  std::cout << std::setw(width) << "  -r PATH or --restore PATH";
  std::cout << "Restores the dialog state before the dialog is queried. PATH must contain a "
               "previous save state."
            << std::endl;
}

// Splits `str` on commas and returns the non-empty fields in order.
// Empty fields (leading, trailing or doubled commas) are dropped, so
// ",a,,b," yields {"a", "b"} and "" yields an empty vector.
std::vector<std::string> split(const std::string& str) {
  std::vector<std::string> fields;
  std::istringstream source(str);
  std::string field;
  while (std::getline(source, field, ',')) {
    if (!field.empty()) {
      fields.push_back(field);
    }
  }
  return fields;
}

// Parses an embedding-related argument of the form PATH or PATH,TYPE,SCALE,OFFSET.
//
// arg      - the raw command-line value (comma separated; empty fields are ignored).
// filename - receives PATH.
// dataType - receives TYPE; must be one of int8, uint8, int16, uint16.
// scale    - receives SCALE; must parse completely and be strictly positive.
// offset   - receives OFFSET; must parse completely as an integer.
//
// Returns true on success. On any error it prints a diagnostic to std::cerr,
// returns false and leaves every output parameter unchanged.
bool parseE2TArguments(const std::string& arg,
                       std::string& filename,
                       std::string& dataType,
                       double& scale,
                       int32_t& offset) {
  const std::vector<std::string> fields = split(arg);

  if (fields.size() == 1) {
    filename = fields[0];
    return true;
  }
  if (fields.size() != 4) {
    std::cerr << "ERROR: Invalid embedding argument: " << arg << std::endl;
    return false;
  }

  const std::string& type = fields[1];
  if ((type != "int8") && (type != "uint8") && (type != "int16") && (type != "uint16")) {
    std::cerr << "ERROR: invalid datatype: " << type << std::endl;
    return false;
  }

  double parsedScale = 0.0;
  int parsedOffset   = 0;
  bool parsedOk      = true;
  try {
    size_t scaleEnd  = 0;
    size_t offsetEnd = 0;
    parsedScale      = std::stod(fields[2], &scaleEnd);
    parsedOffset     = std::stoi(fields[3], &offsetEnd);
    // stod/stoi stop at the first invalid character; reject inputs like "0.5abc".
    parsedOk = (scaleEnd == fields[2].size()) && (offsetEnd == fields[3].size());
  } catch (const std::exception&) {
    parsedOk = false;
  }
  // A zero, negative or NaN scale is not a usable quantization scale. A zero input
  // scale would also be a divisor when the requantization encodings are computed.
  if (!parsedOk || !(parsedScale > 0.0)) {
    std::cerr << "ERROR: Invalid quantization encodings: {" << fields[2] << ", " << fields[3] << "}"
              << std::endl;
    return false;
  }

  // Commit outputs only once everything has been validated.
  filename = fields[0];
  dataType = type;
  scale    = parsedScale;
  offset   = parsedOffset;
  return true;
}

bool parseCommandLineInput(int argc, char** argv) {
  bool invalidParam = false;
  std::string arg;
  if (argc == 1) {
    printUsage(argv[0]);
    std::exit(EXIT_SUCCESS);
  }
  for (int i = 1; i < argc; i++) {
    arg = argv[i];
    commandLineArguments.insert(arg);
    if (arg == "-h" || arg == "--help") {
      printUsage(argv[0]);
      std::exit(EXIT_SUCCESS);
    } else if (arg == "-c" || arg == "--config") {
      if (++i >= argc) {
        invalidParam = true;
        break;
      }
      std::ifstream configStream = std::ifstream(argv[i]);

      if (!checkFileExistsAndReadable(configStream,
                                      argv[i])) {  // Error encountered don't go further
        return false;
      }

      std::getline(configStream, config, '\0');
      addOption("--config", true, false);
    } else if (arg == "-s" || arg == "--save") {
      if (++i >= argc) {
        invalidParam = true;
        break;
      }
      savePath = argv[i];
      addOption("--save", true, false);
    } else if (arg == "-r" || arg == "--restore") {
      if (++i >= argc) {
        invalidParam = true;
        break;
      }
      restorePath = argv[i];
      addOption("--restore", true, false);
    } else if (arg == "-p" || arg == "--prompt") {
      if (++i >= argc) {
        invalidParam = true;
        break;
      }
      prompt = argv[i];
      addOption("--prompt", true, false);
    } else if (arg == "--prompt_file") {
      if (++i >= argc) {
        invalidParam = true;
        break;
      }
      std::ifstream promptStream(argv[i]);

      if (!checkFileExistsAndReadable(promptStream, argv[i])) {
        return false;
      }

      std::getline(promptStream, prompt, '\0');
      addOption("--prompt_file", true, false);
#if defined(GENIE_LORA_FEATURE)
    } else if (arg == "-l" || arg == "--lora") {
      if (++i >= argc) {
        invalidParam = true;
        break;
      }

      auto args = split(argv[i]);
      if (args.size() == 1)
        loraAdapterName = args[0];
      else if (args.size() == 3) {
        loraAdapterName = args[0];
        loraAlphaName   = args[1];
        try {
          loraAlphaValue = std::stof(args[2]);
        } catch (const std::exception& e) {
          std::cerr << "ERROR: Invalid LoRA alpha tensor strength: " << args[2] << std::endl;
          printUsage(argv[0]);
          return false;
        }
      } else {
        std::cerr << "ERROR: Invalid --lora argument: " << argv[i] << std::endl;
        printUsage(argv[0]);
        return false;
      }
      addOption("--lora", true, false);
#endif
#if defined(GENIE_E2T_FEATURE)
    } else if (arg == "-e" || arg == "--embedding_file") {
      if (++i >= argc) {
        invalidParam = true;
        break;
      }

      std::string filename;

      if (!parseE2TArguments(argv[i], filename, inputDataType, inputScale, inputOffset)) {
        return false;
      }

      uint32_t fileSize = getFileSize(filename);

      embeddingBuffer     = std::shared_ptr<void>(new int8_t[fileSize]);
      embeddingBufferSize = fileSize;
      std::ifstream embeddingStream(filename, std::ifstream::binary);

      if (!checkFileExistsAndReadable(embeddingStream,
                                      filename)) {  // Error encountered don't go further
        return false;
      }

      embeddingStream.read(static_cast<char*>(embeddingBuffer.get()), fileSize);
      addOption("--embedding_file", true, false);
    } else if (arg == "-t" || arg == "--embedding_table") {
      if (++i >= argc) {
        invalidParam = true;
        break;
      }

      std::string filename;

      if (!parseE2TArguments(argv[i], filename, lutDataType, lutScale, lutOffset)) {
        return false;
      }

      uint32_t fileSize = getFileSize(filename);

      embeddingLut     = std::shared_ptr<void>(new int8_t[fileSize]);
      embeddingLutSize = fileSize;
      std::ifstream embeddingTable(filename, std::ifstream::binary);

      if (!checkFileExistsAndReadable(embeddingTable,
                                      filename)) {  // Error encountered don't go further
        return false;
      }

      embeddingTable.read(static_cast<char*>(embeddingLut.get()), fileSize);
      addOption("--embedding_table", true, false);
#endif
    } else if (arg == "-tok" || arg == "--tokens_file") {
      if (++i >= argc) {
        invalidParam = true;
        break;
      }
      std::ifstream file(argv[i]);
      while (std::getline(file, prompt)) {
        std::istringstream iss(prompt);
        uint32_t token;
        while (iss >> token) {
          tokens.push_back(token);
        }
      }
      addOption("--prompt_file", true, false);
    } else {
      std::cerr << "Unknown option: " << arg << std::endl;
      printUsage(argv[0]);
      return false;
    }
  }
  if (invalidParam) {
    std::cerr << "ERROR: Invalid parameter for argument: " << arg << std::endl;
    printUsage(argv[0]);
    return false;
  }
#if defined(GENIE_E2T_FEATURE)
  if (isSet("--embedding_file")) {
    if (isSet("--prompt") || isSet("--prompt_file") || isSet("--tokens_file")) {
      std::cerr << "ERROR:: Please do not provide a text/token prompt and embedding prompt at the "
                   "same time."
                << std::endl;
      return false;
    }
  } else if (isSet("--embedding_table")) {
    std::cerr << "ERROR:: Please provide an embedding file using --embedding_file." << std::endl;
    return false;
  } else
#endif
      if (isSet("--tokens_file")) {
    if (isSet("--prompt") || isSet("--prompt_file") || isSet("--embedding_file")) {
      std::cerr << "ERROR:: Please do not provide a text prompt/embedding file and tokens file at "
                   "the same time."
                << std::endl;
      return false;
    }
  } else if (!isSet("--prompt") && !isSet("--prompt_file")) {
    std::cerr << "ERROR:: Please provide prompt using --prompt or --prompt_file." << std::endl;
    return false;
  } else if (isSet("--prompt") && isSet("--prompt_file")) {
    std::cerr << "ERROR:: Please provide only one of --prompt or --prompt_file." << std::endl;
    return false;
  }

  return true;
}

// Streams one chunk of a text response to std::cout.
// A tag describing the chunk's position in the sentence is printed first
// (CONTINUE chunks have no tag), then the response text if one was supplied.
void queryCallback(const char* responseStr,
                   const GenieDialog_SentenceCode_t sentenceCode,
                   const void*) {
  const char* tag = nullptr;
  switch (sentenceCode) {
    case GENIE_DIALOG_SENTENCE_COMPLETE:
      tag = "[COMPLETE]: ";
      break;
    case GENIE_DIALOG_SENTENCE_BEGIN:
      tag = "[BEGIN]: ";
      break;
    case GENIE_DIALOG_SENTENCE_CONTINUE:
      break;
    case GENIE_DIALOG_SENTENCE_END:
      tag = "[END]";
      break;
    case GENIE_DIALOG_SENTENCE_ABORT:
      tag = "[ABORT]: ";
      break;
    default:
      tag = "[UNKNOWN]: ";
      break;
  }

  if (tag != nullptr) {
    std::cout << tag << std::flush;
  }
  if (responseStr != nullptr) {
    std::cout << responseStr << std::flush;
  }
}

#if defined(GENIE_E2T_FEATURE)
// Token-to-embedding callback used when no requantization is needed.
// Copies row `token` (of `embeddingSize` bytes) of the global lookup table
// `embeddingLut` into `embedding`. Negative tokens and rows that would extend
// past `embeddingLutSize` are reported on std::cerr and leave `embedding`
// untouched.
void tokenToEmbedCallback(const int32_t token,
                          void* embedding,
                          const uint32_t embeddingSize,
                          const void* userData) {
  (void)userData;
  if (embeddingSize == 0) {
    return;  // Nothing to copy.
  }

  // Bounds-check in size_t without forming token * embeddingSize first: the
  // 32-bit product can wrap around and falsely pass a "fits in the LUT" test.
  const size_t rowBytes = embeddingSize;
  const size_t rowCount = embeddingLutSize / rowBytes;
  if ((token < 0) || (static_cast<size_t>(token) >= rowCount)) {
    std::cerr << "Error: T2E conversion overflow." << std::endl;
    return;
  }

  const int8_t* embeddingSrc =
      static_cast<const int8_t*>(embeddingLut.get()) + static_cast<size_t>(token) * rowBytes;
  int8_t* embeddingDst = static_cast<int8_t*>(embedding);
  std::copy(embeddingSrc, embeddingSrc + rowBytes, embeddingDst);
}

// Derives the affine map that takes LUT-quantized values into the input layer's
// quantized domain:
//   q_in = requantScale * q_lut + requantOffset
// This is consistent with the convention real = scale * (q + offset) for both encodings.
void calculateRequantEncodings() {
  const double scaleRatio    = lutScale / inputScale;
  const double shiftedOffset = lutScale * lutOffset / inputScale - inputOffset;
  requantScale               = scaleRatio;
  requantOffset              = shiftedOffset;
}

// Requantizes `length` elements from `from` (type F) into `to` (type T) using
// the global requantScale/requantOffset: to[i] = round(requantScale * from[i] + requantOffset).
// Results are rounded to nearest and saturated to T's representable range.
// Converting an out-of-range double to an integer type is undefined behaviour,
// and such values occur, e.g., when narrowing int16 data to int8.
template <class F, class T>
void requantEmbedding(const F* from, T* to, size_t length) {
  constexpr double kLowest  = static_cast<double>(std::numeric_limits<T>::lowest());
  constexpr double kHighest = static_cast<double>(std::numeric_limits<T>::max());
  for (size_t i = 0; i < length; ++i) {
    const double scaled = std::round(requantScale * static_cast<double>(from[i]) + requantOffset);
    // NaN (only possible with NaN encodings) has no integer value; map it to 0.
    const double bounded = std::isnan(scaled) ? 0.0 : std::clamp(scaled, kLowest, kHighest);
    to[i]                = static_cast<T>(bounded);
  }
}

// Token-to-embedding callback that requantizes a LUT row of element type F into
// the input embedding element type T.
// `embeddingSize` is the destination size in bytes, so a row has
// embeddingSize / sizeof(T) elements. Negative tokens and rows past the end of
// the LUT are reported on std::cerr and leave `embedding` untouched.
template <class F, class T>
void tokenToEmbedRequantCallback(const int32_t token,
                                 void* embedding,
                                 const uint32_t embeddingSize,
                                 const void* userData) {
  (void)userData;
  const size_t numElements = embeddingSize / sizeof(T);
  if (numElements == 0) {
    return;  // Nothing to requantize.
  }

  // Count whole LUT rows instead of forming token * numElements, which could
  // overflow (or start from a negative token) and slip past the bounds check.
  const size_t rowCount = embeddingLutSize / (numElements * sizeof(F));
  if ((token < 0) || (static_cast<size_t>(token) >= rowCount)) {
    std::cerr << "Error: T2E conversion overflow." << std::endl;
    return;
  }

  F* embeddingSrc = static_cast<F*>(embeddingLut.get()) + static_cast<size_t>(token) * numElements;
  T* embeddingDst = static_cast<T*>(embedding);
  requantEmbedding(embeddingSrc, embeddingDst, numElements);
}
#endif

// Streams one chunk of a token-ID response to std::cout.
// A sentence-position marker is printed first (CONTINUE chunks have none).
// Then each of the `tokensLength` token IDs is printed, followed by a space.
void tokenToTokenCallback(const uint32_t* token,
                          const uint32_t tokensLength,
                          const GenieDialog_SentenceCode_t sentenceCode,
                          const void*) {
  const char* marker = nullptr;
  switch (sentenceCode) {
    case GENIE_DIALOG_SENTENCE_COMPLETE:
      marker = "[COMPLETE]: ";
      break;
    case GENIE_DIALOG_SENTENCE_BEGIN:
      marker = "[BEGIN]: ";
      break;
    case GENIE_DIALOG_SENTENCE_CONTINUE:
      break;
    case GENIE_DIALOG_SENTENCE_END:
      marker = "[END]";
      break;
    case GENIE_DIALOG_SENTENCE_ABORT:
      marker = "[ABORT]: ";
      break;
    default:
      marker = "[UNKNOWN]: ";
      break;
  }

  if (marker != nullptr) {
    std::cout << marker << std::flush;
  }
  if (token == nullptr) {
    return;
  }
  for (const uint32_t* cursor = token; cursor != token + tokensLength; ++cursor) {
    std::cout << *cursor << " " << std::flush;
  }
}
/*
 * This class can be used to update sampler parameters in between queries
 * Usage:
    SamplerConfig sc = SamplerConfig();
    sc.createSamplerConfig(configPath);
    sc.setParam("top-p", "0.8"); // You can refer to sampler.json for the parameters that can be
 updated dialog.getSampler(); dialog.applyConfig(sc());
 */
class SamplerConfig {
 public:
  void createSamplerConfig(const std::string& configPath) {
    std::ifstream confStream(configPath);
    std::string config;
    std::getline(confStream, config, '\0');
    m_config             = config;
    const int32_t status = GenieSamplerConfig_createFromJson(config.c_str(), &m_handle);
    if (GENIE_STATUS_SUCCESS != status) {
      throw std::runtime_error("Failed to create sampler config.");
    }
  }

  std::string getConfigString() { return m_config; }

  void setParam(const std::string keyStr, const std::string valueStr) {
    const int32_t status = GenieSamplerConfig_setParam(m_handle, keyStr.c_str(), valueStr.c_str());
    if (GENIE_STATUS_SUCCESS != status) {
      throw std::runtime_error("Failed to setParam");
    }
  }

  ~SamplerConfig() {
    const int32_t status = GenieSamplerConfig_free(m_handle);
    if (GENIE_STATUS_SUCCESS != status) {
      std::cerr << "Failed to free the sampler config." << std::endl;
    }
  }

  GenieSamplerConfig_Handle_t operator()() const { return m_handle; }

 private:
  GenieSamplerConfig_Handle_t m_handle = NULL;
  std::string m_config;
};

class Dialog {
 public:
  class Config {
   public:
    Config(const std::string& config) {
      int32_t status = GenieDialogConfig_createFromJson(config.c_str(), &m_handle);
      if ((GENIE_STATUS_SUCCESS != status) || (!m_handle)) {
        throw std::runtime_error("Failed to create the dialog config.");
      }
    }

    ~Config() {
      int32_t status = GenieDialogConfig_free(m_handle);
      if (GENIE_STATUS_SUCCESS != status) {
        std::cerr << "Failed to free the dialog config." << std::endl;
      }
    }

    GenieDialogConfig_Handle_t operator()() const { return m_handle; }

   private:
    GenieDialogConfig_Handle_t m_handle = NULL;
  };

  Dialog(Config config) {
    int32_t status = GenieDialog_create(config(), &m_handle);
    if ((GENIE_STATUS_SUCCESS != status) || (!m_handle)) {
      throw std::runtime_error("Failed to create the dialog.");
    }
  }

  ~Dialog() {
    int32_t status = GenieDialog_free(m_handle);
    if (GENIE_STATUS_SUCCESS != status) {
      std::cerr << "Failed to free the dialog." << std::endl;
    }
  }

  void query(const std::string prompt) {
    int32_t status = GenieDialog_query(m_handle,
                                       prompt.c_str(),
                                       GenieDialog_SentenceCode_t::GENIE_DIALOG_SENTENCE_COMPLETE,
                                       queryCallback,
                                       nullptr);
    if (GENIE_STATUS_SUCCESS != status) {
      throw std::runtime_error("Failed to query.");
    }
  }

  void save(const std::string name) {
    int32_t status = GenieDialog_save(m_handle, name.c_str());
    if (GENIE_STATUS_SUCCESS != status) {
      throw std::runtime_error("Failed to save.");
    }
  }

  void restore(const std::string name) {
    int32_t status = GenieDialog_restore(m_handle, name.c_str());
    if (GENIE_STATUS_SUCCESS != status) {
      throw std::runtime_error("Failed to restore.");
    }
  }

  void getSampler() {
    const int32_t status = GenieDialog_getSampler(m_handle, &m_samplerHandle);
    if (GENIE_STATUS_SUCCESS != status) {
      throw std::runtime_error("Failed to get sampler.");
    }
  }

  void applyConfig(GenieSamplerConfig_Handle_t samplerConfigHandle) {
    const int32_t status = GenieSampler_applyConfig(m_samplerHandle, samplerConfigHandle);
    if (GENIE_STATUS_SUCCESS != status) {
      throw std::runtime_error("Failed to apply sampler config.");
    }
  }
#if defined(GENIE_E2T_FEATURE)
  void embeddingQuery(const void* embeddings, const uint32_t embeddingsSize) {
    GenieDialog_TokenToEmbeddingCallback_t t2eCallback{nullptr};
    if (embeddingLutSize > 0) {
      calculateRequantEncodings();
      if ((lutDataType == "N/A") && (inputDataType == "N/A")) {
        t2eCallback = tokenToEmbedCallback;
      } else if ((lutDataType == "int8") && (inputDataType == "int16")) {
        t2eCallback = tokenToEmbedRequantCallback<int8_t, int16_t>;
      } else if ((lutDataType == "int16") && (inputDataType == "int8")) {
        t2eCallback = tokenToEmbedRequantCallback<int16_t, int8_t>;
      } else if ((lutDataType == "int16") && (inputDataType == "int16")) {
        t2eCallback = tokenToEmbedRequantCallback<int16_t, int16_t>;
      } else if ((lutDataType == "uint8") && (inputDataType == "uint16")) {
        t2eCallback = tokenToEmbedRequantCallback<uint8_t, uint16_t>;
      } else if ((lutDataType == "uint16") && (inputDataType == "uint8")) {
        t2eCallback = tokenToEmbedRequantCallback<uint16_t, uint8_t>;
      } else if ((lutDataType == "uint16") && (inputDataType == "uint16")) {
        t2eCallback = tokenToEmbedRequantCallback<uint16_t, uint16_t>;
      } else {
        throw std::runtime_error("Unsupported LUT requantization: " + lutDataType + " -> " +
                                 inputDataType);
      }
    }
    int32_t status =
        GenieDialog_embeddingQuery(m_handle,
                                   embeddings,
                                   embeddingsSize,
                                   GenieDialog_SentenceCode_t::GENIE_DIALOG_SENTENCE_COMPLETE,
                                   t2eCallback,
                                   queryCallback,
                                   nullptr);
    if (GENIE_STATUS_SUCCESS != status) {
      throw std::runtime_error("Failed to query with embedding.");
    }
  }
#endif

  void tokenQuery(const uint32_t* tokens, const uint32_t tokensSize) {
    GenieDialog_TokenQueryCallback_t tokenCallback{nullptr};
    if (tokensSize > 0) {
      tokenCallback = tokenToTokenCallback;
    }
    int32_t status =
        GenieDialog_tokenQuery(m_handle,
                               tokens,
                               tokensSize,
                               GenieDialog_SentenceCode_t::GENIE_DIALOG_SENTENCE_COMPLETE,
                               tokenCallback,
                               nullptr);
    if (GENIE_STATUS_SUCCESS != status) {
      throw std::runtime_error("Failed to query with tokens.");
    }
  }

#if defined(GENIE_LORA_FEATURE)
  void applyLora(const std::string engine, const std::string loraAdapterName) {
    int32_t status = GenieDialog_applyLora(m_handle, engine.c_str(), loraAdapterName.c_str());
    if (GENIE_STATUS_SUCCESS != status) {
      throw std::runtime_error("Failed to apply the LoRA adapter.");
    }
  }

  void setLoraStrength(const std::string engine, const std::string tensorName, const float alpha) {
    int32_t status =
        GenieDialog_setLoraStrength(m_handle, engine.c_str(), tensorName.c_str(), alpha);
    if (GENIE_STATUS_SUCCESS != status) {
      throw std::runtime_error("Failed to set the LoRA alpha strength.");
    }
  }
#endif

 private:
  GenieDialog_Handle_t m_handle         = NULL;
  GenieSampler_Handle_t m_samplerHandle = NULL;
};

// Entry point: parses the command line, creates a dialog from the --config JSON,
// optionally applies LoRA settings and restores a saved state, and runs one
// query. The query uses embeddings (E2T builds only), token IDs or the text
// prompt, in that priority order. Finally it optionally saves the state.
// Returns EXIT_FAILURE on a parse error or any exception thrown by the dialog.
int main(int argc, char** argv) {
  if (!parseCommandLineInput(argc, argv)) {
    return EXIT_FAILURE;
  }

  std::cout << "Using libGenie.so version " << Genie_getApiMajorVersion() << "."
            << Genie_getApiMinorVersion() << "." << Genie_getApiPatchVersion() << "\n"
            << std::endl;

  try {
    Dialog dialog{Dialog::Config(config)};

#if defined(GENIE_LORA_FEATURE)
    if (!loraAdapterName.empty()) {
      dialog.applyLora("primary", loraAdapterName);
    }
    if (!loraAlphaName.empty()) {
      dialog.setLoraStrength("primary", loraAlphaName, loraAlphaValue);
    }
#endif
    if (!restorePath.empty()) {
      dialog.restore(restorePath);
    }

#if defined(GENIE_E2T_FEATURE)
    if (embeddingBufferSize != 0) {
      // Dialog::embeddingQuery takes a uint32_t size; refuse to truncate it silently.
      if (embeddingBufferSize > std::numeric_limits<uint32_t>::max()) {
        throw std::runtime_error("Embedding file is too large: " +
                                 std::to_string(embeddingBufferSize) + " bytes");
      }
      std::cout << "Embedding file size: " << embeddingBufferSize << " bytes" << std::endl;
      std::cout << std::endl;
      dialog.embeddingQuery(embeddingBuffer.get(), static_cast<uint32_t>(embeddingBufferSize));
      std::cout << std::endl;
    } else
#endif
        if (!tokens.empty()) {
      std::cout << "[PROMPT TOKENS]: ";
      for (const uint32_t token : tokens) {
        std::cout << token << " ";
      }
      std::cout << std::endl;
      dialog.tokenQuery(tokens.data(), static_cast<uint32_t>(tokens.size()));
      std::cout << std::endl;
    } else {
      std::cout << "[PROMPT]: " << prompt << std::endl;
      std::cout << std::endl;
      dialog.query(prompt);
      // Terminate the streamed response line, as the other query paths do.
      std::cout << std::endl;
    }
    if (!savePath.empty()) {
      dialog.save(savePath);
    }
  } catch (const std::exception& e) {
    std::cerr << e.what() << std::endl;
    return EXIT_FAILURE;
  }

  return EXIT_SUCCESS;
}
