//==============================================================================
//
//  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
//  All Rights Reserved.
//  Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

#pragma once

#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>

#include "GenieDialog.h"
#include "Util/HandleManager.hpp"
#include "qualla/dialog.hpp"
#include "qualla/DialogCallback.hpp"
#include "Sampler.hpp"

namespace genie {

/// LoRA version identifiers, stored in a single byte.
///
/// NOTE(review): the differences between V1 and V2 are not described in this
/// header; consult the LoRA implementation before relying on them.
enum LORA_VERSION : uint8_t {
  /// Version 1 (0x01).
  GENIE_LORA_VERSION_V1 = 1,
  /// Version 2 (0x02).
  GENIE_LORA_VERSION_V2 = 2,
  /// No version set (0xFF).
  GENIE_LORA_VERSION_UNDEFINED = 255
};

/// Genie-side dialog object that owns a qualla::Dialog and is referred to from
/// the C API through opaque GenieDialog_Handle_t handles.
///
/// Instances are neither copyable nor movable.
///
/// NOTE(review): the member definitions are not part of this header. Where a
/// comment below describes behavior rather than a signature, it is inferred
/// from the declaration and should be confirmed against the implementation.
class Dialog {
 public:
  /// Dialog configuration holding a qualla::json document, referred to from the
  /// C API through opaque GenieDialogConfig_Handle_t handles.
  class Config {
   public:
    /// Handle registry for Config objects (presumably backed by s_manager).
    /// `add` yields the handle for `config`, `get` maps a handle back to its
    /// Config, and `remove` drops the handle.
    static GenieDialogConfig_Handle_t add(std::shared_ptr<Config> config);
    static std::shared_ptr<Config> get(GenieDialogConfig_Handle_t handle);
    static void remove(GenieDialogConfig_Handle_t handle);

    /// Builds a configuration from `configStr` (presumably JSON text).
    /// NOTE(review): this single-argument constructor is not `explicit`;
    /// consider adding it once no caller relies on the implicit conversion.
    Config(const char* configStr);

    /// Returns a mutable reference to a qualla::json document (presumably the
    /// stored configuration).
    qualla::json& getJson();

   private:
    static qnn::util::HandleManager<Config> s_manager;
    qualla::json m_config;
  };

  /// Handle registry for Dialog objects (presumably backed by s_manager).
  /// `add` yields the handle for `dialog`, `get` maps a handle back to its
  /// Dialog, and `remove` drops the handle.
  static GenieDialog_Handle_t add(std::shared_ptr<Dialog> dialog);
  static std::shared_ptr<Dialog> get(GenieDialog_Handle_t handle);
  static void remove(GenieDialog_Handle_t handle);

  /// Returns the sampler handle associated with `dialog`.
  static GenieSampler_Handle_t getSamplerHandle(std::shared_ptr<genie::Dialog> dialog);

  /// Publicly accessible callback object of type qualla::DialogCallback.
  qualla::DialogCallback dialogCallback;

  /// Constructs a dialog from `config`.
  /// NOTE(review): this single-argument constructor is not `explicit`;
  /// consider adding it once no caller relies on the implicit conversion.
  Dialog(std::shared_ptr<Config> config);
  ~Dialog();

  Dialog(const Dialog&)            = delete;
  Dialog& operator=(const Dialog&) = delete;
  Dialog(Dialog&&)                 = delete;
  Dialog& operator=(Dialog&&)      = delete;

  /// Runs a text query.
  ///
  /// @param queryStr     Query text.
  /// @param sentenceCode Sentence code for this query.
  /// @param callback     Query callback.
  /// @param userData     Opaque pointer (presumably forwarded to `callback`).
  /// @return Integer status code (values defined by the implementation).
  int32_t query(const char* queryStr,
                GenieDialog_SentenceCode_t sentenceCode,
                GenieDialog_QueryCallback_t callback,
                const void* userData);

  /// Saves dialog state; the string argument presumably names the destination.
  /// @return Integer status code (values defined by the implementation).
  int32_t save(const std::string&);

  /// Restores dialog state; the string argument presumably names the source.
  /// @return Integer status code (values defined by the implementation).
  int32_t restore(const std::string&);

#if defined(GENIE_E2T_FEATURE)
  /// Runs a query whose input is an embeddings buffer instead of text.
  ///
  /// @param embeddings     Pointer to the embeddings data.
  /// @param embeddingsSize Size of `embeddings` (presumably in bytes — TODO confirm).
  /// @param sentenceCode   Sentence code for this query.
  /// @param t2eCallback    Token-to-embedding callback.
  /// @param callback       Query callback.
  /// @param userData       Opaque pointer (presumably forwarded to the callbacks).
  /// @return Integer status code (values defined by the implementation).
  int32_t embeddingQuery(const void* embeddings,
                         const uint32_t embeddingsSize,
                         GenieDialog_SentenceCode_t sentenceCode,
                         GenieDialog_TokenToEmbeddingCallback_t t2eCallback,
                         GenieDialog_QueryCallback_t callback,
                         const void* userData);
#endif

  /// Runs a query whose input is an array of token ids.
  ///
  /// @param tokens          Pointer to the input tokens.
  /// @param sizeInputTokens Size of `tokens` (presumably the element count — TODO confirm).
  /// @param sentenceCode    Sentence code for this query.
  /// @param callback        Token-query callback.
  /// @param userData        Opaque pointer (presumably forwarded to `callback`).
  /// @return Integer status code (values defined by the implementation).
  int32_t tokenQuery(const uint32_t* tokens,
                     const uint32_t sizeInputTokens,
                     GenieDialog_SentenceCode_t sentenceCode,
                     GenieDialog_TokenQueryCallback_t callback,
                     const void* userData);

  /// Resets the dialog.
  void reset();

#if defined(GENIE_LORA_FEATURE)
  /// Applies the LoRA adapter `loraAdapterName` for `engine`.
  /// @return Integer status code (values defined by the implementation).
  int32_t applyLora(std::string loraAdapterName, std::string engine);

  /// Sets the LoRA strength `alpha` for `tensorName` on `engine`.
  /// @return Integer status code (values defined by the implementation).
  int32_t applyLoraStrength(std::string tensorName, std::string engine, float alpha);
#endif

 private:
  std::unique_ptr<qualla::Dialog> m_quallaDialog;
  // Defaults to the largest uint32_t value (presumably meaning "no limit").
  uint32_t m_tokenLimit{UINT32_MAX};
  static qnn::util::HandleManager<Dialog> s_manager;
  static std::atomic<std::uint32_t> s_nameCounter;
  // Value-initialized so the handle has a determinate value even on a
  // constructor path that does not assign it.
  GenieSampler_Handle_t m_samplerHandle{};
};
}  // namespace genie
