//==============================================================================
//
//  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
//  All Rights Reserved.
//  Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================
#include <exception>
#include <set>

#include "Exception.hpp"
#include "Macro.hpp"
#include "Sampler.hpp"
#include "qualla/detail/json.hpp"

using namespace genie;

//=============================================================================
// Sampler functions
//=============================================================================

// Registry that maps opaque GenieSampler_Handle_t values to Sampler instances. It is the
// storage behind Sampler::add / Sampler::get / Sampler::remove.
qnn::util::HandleManager<Sampler> Sampler::s_manager;

// Registers `config` in the Sampler handle registry and returns the opaque public handle
// under which it can later be looked up or removed.
GenieSampler_Handle_t Sampler::add(std::shared_ptr<Sampler> config) {
  const auto registered = s_manager.add(config);
  return (GenieSampler_Handle_t)registered;
}

// Looks up the Sampler registered under the public `handle`. The result is whatever the
// handle registry returns for that key.
std::shared_ptr<Sampler> Sampler::get(GenieSampler_Handle_t handle) {
  const qnn::util::Handle_t key = (qnn::util::Handle_t)handle;
  return s_manager.get(key);
}

// Drops the registry entry for the public `handle`.
void Sampler::remove(GenieSampler_Handle_t handle) {
  const qnn::util::Handle_t key = (qnn::util::Handle_t)handle;
  s_manager.remove(key);
}

// Binds this Sampler to a dialog configuration JSON and to the qualla sampler instances it
// controls. applyConfig() later rewrites m_origJson["sampler"] and forwards it to every entry
// of m_quallaSamplers.
// NOTE(review): whether m_origJson / m_quallaSamplers are references or copies is declared in
// Sampler.hpp. applyConfig's "live qualla sampler instances" comment suggests references —
// if so, confirm both referents outlive this object.
Sampler::Sampler(qualla::json& origJson,
                 std::vector<std::reference_wrapper<qualla::Sampler>>& quallaSamplers)
    : m_origJson(origJson), m_quallaSamplers(quallaSamplers) {}

// Merges a user sampler config into the dialog's configuration JSON and pushes the merged
// m_origJson["sampler"] object to every attached qualla sampler.
//
// Merge rules:
//   - seed (int32), temp (float), top-k (size_t), top-p (float) and greedy (bool) overwrite
//     the stored value only when present in samplerConfigJson["sampler"]; absent fields keep
//     their current value.
//   - version is taken from the input, defaulting to 1 when not supplied.
//   - type is always forced to "basic".
//
// @param samplerConfigJson  JSON of the form {"sampler": {...}}. It is taken by value; a
//                           missing "sampler" key results in no field updates.
void Sampler::applyConfig(qualla::json samplerConfigJson) {
  // Non-const operator[] yields null when "sampler" is absent; contains() on null is false, so
  // the checks below then simply skip every field.
  const qualla::json& update = samplerConfigJson["sampler"];
  qualla::json& current = m_origJson["sampler"];

  // Only touch fields the caller supplied. Previously each stored value was passed as the
  // default to qualla::Config::optional<T>, which converts it to T even when the caller did
  // not supply the field — and that conversion fails when the stored field was never set
  // (translateSamplerConfig only copies fields that are present).
  if (update.contains("seed")) {
    current["seed"] = update["seed"].get<int32_t>();
  }
  if (update.contains("temp")) {
    current["temp"] = update["temp"].get<float>();
  }
  if (update.contains("top-k")) {
    current["top-k"] = update["top-k"].get<size_t>();
  }
  if (update.contains("top-p")) {
    current["top-p"] = update["top-p"].get<float>();
  }
  // "greedy" is an accepted sampler field (see validateSamplerConfig) and was previously
  // dropped here, so it never reached the live samplers.
  if (update.contains("greedy")) {
    current["greedy"] = update["greedy"].get<bool>();
  }
  current["version"] = update.contains("version") ? update["version"].get<int32_t>() : int32_t{1};
  current["type"] = "basic";

#if ENABLE_DEBUG_LOGS
  // Read through a const view so logging never inserts null placeholders into `current`,
  // and print raw JSON so unset fields do not throw.
  const qualla::json& logged = current;
  std::cout << "Updated sampler config: " << std::endl;
  for (const char* key : {"temp", "top-k", "top-p", "seed"}) {
    std::cout << key << ": " << (logged.contains(key) ? logged[key].dump() : std::string("unset"))
              << std::endl;
  }
#endif
  // Loop through the live qualla sampler instances and update the parameters
  for (auto& quallaSampler : m_quallaSamplers) {
    quallaSampler.get().applyConfig(current);
  }
}

//=============================================================================
// Sampler::SamplerConfig functions
//=============================================================================

// Registry that maps opaque GenieSamplerConfig_Handle_t values to SamplerConfig instances.
// It is the storage behind SamplerConfig::add / get / remove.
qnn::util::HandleManager<Sampler::SamplerConfig> Sampler::SamplerConfig::s_manager;

// Registers `config` in the SamplerConfig handle registry and returns the opaque public
// handle under which it can later be looked up or removed.
GenieSamplerConfig_Handle_t Sampler::SamplerConfig::add(
    std::shared_ptr<Sampler::SamplerConfig> config) {
  const auto registered = s_manager.add(config);
  return (GenieSamplerConfig_Handle_t)registered;
}

// Looks up the SamplerConfig registered under the public `handle`. The result is whatever
// the handle registry returns for that key.
std::shared_ptr<Sampler::SamplerConfig> Sampler::SamplerConfig::get(
    GenieSamplerConfig_Handle_t handle) {
  const qnn::util::Handle_t key = (qnn::util::Handle_t)handle;
  return s_manager.get(key);
}

// Drops the registry entry for the public `handle`.
void Sampler::SamplerConfig::remove(GenieSamplerConfig_Handle_t handle) {
  const qnn::util::Handle_t key = (qnn::util::Handle_t)handle;
  s_manager.remove(key);
}

// Builds a SamplerConfig from user-supplied JSON text.
//
// The document must be an object whose only top-level key is "sampler", and that value must
// pass validateSamplerConfig(). The resulting m_config has the shape {"sampler": {...}}:
//   - seed, temp, top-k, top-p and greedy are copied verbatim when present;
//   - version is copied, defaulting to 1;
//   - type is always "basic".
//
// @param configStr  JSON text handed straight to qualla::json::parse.
//                   NOTE(review): not null-checked here — presumably the C API caller validates
//                   it; confirm.
// @throws Exception(GENIE_STATUS_ERROR_JSON_SCHEMA) on a repeated top-level key, a non-object
//         document, a missing or unknown top-level key, or an invalid sampler object. Errors
//         raised by qualla::json::parse itself propagate unchanged.
Sampler::SamplerConfig::SamplerConfig(const char* configStr) {
  qualla::json config;
  {
    // Parser callback that rejects a top-level key seen twice, instead of letting the parser
    // decide which duplicate wins.
    std::set<qualla::json> seenKeys;
    auto rejectDuplicateKeys = [&seenKeys](int depth, qualla::json::parse_event_t event,
                                           qualla::json& parsed) {
      const bool topLevelKey = (depth == 1) && (event == qualla::json::parse_event_t::key);
      if (topLevelKey && !seenKeys.insert(parsed).second) {
        throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                        "Multiple sampler config key: " + parsed.dump());
      }
      return true;
    };

    config = qualla::json::parse(configStr, rejectDuplicateKeys);
  }

  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Sampler config is not an object");
  }
  if (!config.contains("sampler")) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing field: sampler");
  }

  // component is used in the "ENFORCE" macros
  const std::string component = "sampler";
  for (auto& item : config.items()) {
    if (item.key() != "sampler") {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown sampler config key: " + item.key());
    }
    JSON_ENFORCE_OBJECT();
    validateSamplerConfig(item.value());
  }

  // Assemble the qualla-side sampler object in the same key order as the user fields are
  // listed here, followed by version and type.
  const qualla::json& userSampler = config["sampler"];
  qualla::json translated;
  for (const char* field : {"seed", "temp", "top-k", "top-p", "greedy"}) {
    if (userSampler.contains(field)) {
      translated[field] = userSampler[field];
    }
  }
  translated["version"] =
      userSampler.contains("version") ? userSampler["version"] : qualla::json(1);
  translated["type"] = "basic";

  qualla::json quallaConfig;
  quallaConfig["sampler"] = translated;
  m_config = quallaConfig;
}

void Sampler::SamplerConfig::setParam(const std::string& keyStr, const std::string& valueStr) {
  if (!keyStr.empty()) {
    // Case 1: Only the parameter mentioned in keyStr is to be updated by valueStr
    std::set<std::string> validParams = {"seed", "top-p", "top-k", "temp"};
    if (std::find(validParams.begin(), validParams.end(), keyStr) == validParams.end()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Invalid key obtained: " + keyStr);
    }
    try {
      if (keyStr == "seed")
        m_config["sampler"]["seed"] = std::stoi(valueStr);
      else if (keyStr == "top-p")
        m_config["sampler"]["top-p"] = std::stof(valueStr);
      else if (keyStr == "top-k")
        m_config["sampler"]["top-k"] = std::stof(valueStr);
      else if (keyStr == "temp")
        m_config["sampler"]["temp"] = std::stof(valueStr);
    } catch (const std::invalid_argument& e) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "Invalid value obtained: " + valueStr + " for key: " + keyStr);
    }
  } else {
    // Case 2: User has passed entire json as a string in valueStr

    if (valueStr.empty())
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Both keyStr and valueStr cannot be empty");

    qualla::json config = qualla::json::parse(valueStr);
    if (!config.contains("sampler")) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing field: sampler");
    }

    // component is used in the "ENFORCE" macros
    const std::string component = "sampler";
    for (auto& item : config.items()) {
      if (item.key() == "sampler") {
        JSON_ENFORCE_OBJECT();
        validateSamplerConfig(item.value());
      } else {
        throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                        "Unknown sampler config key: " + item.key());
      }
    }

    m_config["sampler"]["seed"] =
        qualla::Config::optional<int32_t>(config["sampler"], "seed", m_config["sampler"]["seed"]);
    m_config["sampler"]["temp"] =
        qualla::Config::optional<float>(config["sampler"], "temp", m_config["sampler"]["temp"]);
    m_config["sampler"]["top-k"] =
        qualla::Config::optional<size_t>(config["sampler"], "top-k", m_config["sampler"]["top-k"]);
    m_config["sampler"]["top-p"] =
        qualla::Config::optional<float>(config["sampler"], "top-p", m_config["sampler"]["top-p"]);
    m_config["sampler"]["version"] = qualla::Config::optional<int32_t>(
        config["sampler"], "version", m_config["sampler"]["version"]);

    m_config["sampler"]["type"] = "basic";
  }
}

// Validates the value of a "sampler" object supplied by the user.
//
// Requirements enforced:
//   - the value is a JSON object;
//   - "version" is present, numeric, and equal to 1 (compared after get<int>());
//   - "seed", "temp", "top-k" and "top-p", when present, are numeric;
//   - "greedy", when present, is boolean;
//   - no other keys appear.
// Numeric/boolean type checks are delegated to the JSON_ENFORCE_* macros, which rely on the
// local names `item` and `component`.
//
// @throws Exception(GENIE_STATUS_ERROR_JSON_SCHEMA) for structural problems, and
//         Exception(GENIE_STATUS_ERROR_JSON_VALUE) for an unsupported version.
void Sampler::SamplerConfig::validateSamplerConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "sampler config is not an object");
  }

  static const char* const kMandatoryFields[] = {"version"};
  for (const char* field : kMandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      std::string("Missing sampler field: ") + field);
    }
  }

  // Keys whose only requirement is being numeric.
  static const std::set<std::string> kPlainNumericKeys{"seed", "temp", "top-k", "top-p"};

  // component is used in the "ENFORCE" macros
  const std::string component = "sampler";

  for (auto& item : config.items()) {
    if (kPlainNumericKeys.count(item.key()) != 0) {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid sampler config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "greedy") {
      JSON_ENFORCE_BOOLEAN();
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown sampler config key: " + item.key());
    }
  }
}

// Translates the "sampler" section of a Genie dialog config into qualla's sampler config.
//
// When genieConfig["dialog"]["sampler"] exists, quallaConfig["sampler"] receives:
//   type = "basic";
//   role = "primary", or "target" for a dialog of type "spd" when GENIE_SPD_FEATURE is defined;
//   seed, temp, top-k, top-p, greedy — copied verbatim when present.
// Other keys already in quallaConfig["sampler"] are left untouched. If genieConfig has no
// "dialog" entry, or the dialog has no "sampler", quallaConfig is not modified.
//
// @param genieConfig   Full Genie configuration (read-only).
// @param quallaConfig  Qualla configuration being built; its "sampler" entry is written.
void Sampler::SamplerConfig::translateSamplerConfig(const qualla::json& genieConfig,
                                                    qualla::json& quallaConfig) {
  // genieConfig is const. For nlohmann-style json (which qualla::json appears to be), const
  // operator[] on a missing key is undefined behaviour, so check presence before every lookup.
  if (!genieConfig.contains("dialog")) {
    return;
  }
  const qualla::json& dialog = genieConfig["dialog"];
  if (!dialog.contains("sampler")) {
    return;
  }
  const qualla::json& genieSampler = dialog["sampler"];
  qualla::json& quallaSampler = quallaConfig["sampler"];

  // Copies `key` from the Genie sampler section when the user supplied it.
  const auto copyIfPresent = [&genieSampler, &quallaSampler](const char* key) {
    if (genieSampler.contains(key)) {
      quallaSampler[key] = genieSampler[key];
    }
  };

  // Keys are written in the same order as before, in case qualla::json preserves insertion
  // order. The old code also copied "seed" a second time at the end; that was redundant.
  quallaSampler["type"] = "basic";
  copyIfPresent("seed");
  copyIfPresent("temp");

  quallaSampler["role"] = "primary";
#if defined(GENIE_SPD_FEATURE)
  if (dialog.contains("type") && dialog["type"] == "spd") {
    quallaSampler["role"] = "target";
  }
#endif

  copyIfPresent("top-k");
  copyIfPresent("top-p");
  copyIfPresent("greedy");
}

// Returns a copy of the translated sampler configuration ({"sampler": {...}}). Because the
// result is returned by value, callers cannot modify m_config through it.
qualla::json Sampler::SamplerConfig::getJson() const {
  qualla::json snapshot = m_config;
  return snapshot;
}
