//==============================================================================
//
//  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
//  All Rights Reserved.
//  Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

#include <qualla/sampler.hpp>
#include <qualla/detail/config.hpp>

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <span>

#include <fmt/format.h>
#include <fmt/ranges.h>

namespace qualla {

// Constructs the shared sampler state.
//
// Stores the sampler type and the owning context (plus the context's
// environment), emits a debug line describing the new sampler, then reads the
// optional settings below from `conf` via Config::optional, each with the
// fallback shown:
//   "role"  -> "primary"      "seed"  -> -1
//   "temp"  -> 0.1            "top-k" -> 0        "top-p" -> 0.8
//   "greedy"                  -> (temp <= 0 or top-k == 1)
//   "use-gumbel" / "gumbel"   -> false
Sampler::Sampler(Context& ctx, const std::string& type, const qualla::json& conf)
    : _type(type), _ctx(ctx), _env(ctx.env()) {
    auto&& log = _env.logger();
    log.debug(fmt::format("sampler-new: {} ctx {} config {}", type, ctx.name(), conf.dump()));

    using Cfg = qualla::Config;

    // Plain scalar settings.
    _role  = Cfg::optional<std::string>(conf, "role", "primary");
    _seed  = Cfg::optional<int32_t>(conf, "seed", -1);
    _temp  = Cfg::optional<float>(conf, "temp", 0.1);
    _top_k = Cfg::optional<size_t>(conf, "top-k", 0);
    _top_p = Cfg::optional<float>(conf, "top-p", 0.8);

    // Greedy is implied by a non-positive temperature or a top-k of one, but
    // an explicit "greedy" entry in the config takes precedence.
    const bool implied_greedy = (_top_k == 1) || (_temp <= 0.f);
    _greedy                   = Cfg::optional<bool>(conf, "greedy", implied_greedy);

    // "use-gumbel" is read first; its result then serves as the fallback for
    // "gumbel", so "gumbel" wins when both keys are present.
    _gumbel = Cfg::optional(conf, "use-gumbel", false);
    _gumbel = Cfg::optional(conf, "gumbel", _gumbel);
}

// Nothing to release explicitly; members clean up after themselves.
Sampler::~Sampler() = default;

// Restore is not implemented here: a warning naming the sampler type is
// logged and false is returned. The argument is not used.
bool Sampler::restore(const std::string& /*name*/) {
    auto&& log = _env.logger();
    log.warn(fmt::format("{}-sampler does not support restore", _type));
    return false;
}

// Save is not implemented here: a warning naming the sampler type is logged
// and false is returned. The argument is not used.
bool Sampler::save(const std::string& /*name*/) {
    auto&& log = _env.logger();
    log.warn(fmt::format("{}-sampler does not support save", _type));
    return false;
}

// Reset is not implemented here: only a warning naming the sampler type is
// logged; no state is modified.
void Sampler::reset() {
    auto&& log = _env.logger();
    log.warn(fmt::format("{}-sampler does not support reset", _type));
}

int32_t Sampler::process(const std::vector<float>& logits) {
    return process(std::span{logits.data(),logits.size()});
}

int32_t Sampler::process(std::span<const float> logits, std::vector<float>& probs, bool out_tok) {
    _env.logger().warn(fmt::format("{}-sampler does not support probs output", _type));
    return -1;
}

int32_t Sampler::process(
        const std::vector<float>& logits,
        std::vector<float>&       probs,
        bool                      out_tok
) {
    return process(std::span{logits.data(),logits.size()}, probs, out_tok);
}

std::vector<int32_t> Sampler::process_multiple(
        std::span<const float>& logits,
        std::vector<float>&     probs,
        int32_t                 num_return
) {
    _env.logger().warn("sampler does not support num_return");
    return {-1};
}

void Sampler::applyConfig(const qualla::json& conf) {
  _env.logger().warn(fmt::format("Basic sampler supports this for now"));
}

// Sampler registry

namespace {

// Maps a sampler type name to the factory callable registered for it.
using Registry = std::unordered_map<std::string, Sampler::Creator>;

// File-local table of registered samplers. It stays null until the first
// Sampler::__register() call allocates it; create() and list() treat a null
// table as "nothing registered".
std::unique_ptr<Registry> registry;

} // namespace

// Associates `type` with the factory `func`, allocating the registry table on
// first use. Registering an already-known type replaces its previous factory.
void Sampler::__register(const std::string& type, Creator func) {
    if (registry == nullptr) registry = std::make_unique<Registry>();

    (*registry)[type] = func;
}

// Instantiates the sampler selected by the optional "type" key of `conf`
// (fallback: "basic") by invoking the factory registered under that name.
//
// @param ctx   Context handed through to the factory.
// @param conf  Sampler configuration; also handed through to the factory.
// @return      Owning pointer wrapping whatever the factory returned.
// @throws std::runtime_error ("<type>: sampler not found") when nothing has
//         been registered at all or no factory is registered for the type.
std::unique_ptr<Sampler> Sampler::create(Context& ctx, const qualla::json& conf) {
    using qc               = qualla::Config;
    const std::string type = qc::optional<std::string>(conf, "type", "basic");

    if (registry) {
        // One lookup: find() both tests for presence and yields the factory,
        // instead of contains() followed by a second hash via operator[].
        if (const auto it = registry->find(type); it != registry->end()) {
            return std::unique_ptr<Sampler>(it->second(ctx, conf));
        }
    }

    throw std::runtime_error(type + ": sampler not found");
}

// Parses a JSON document from `json_stream` and builds the sampler it
// describes via the json-based create() overload.
std::unique_ptr<Sampler> Sampler::create(Context& ctx, std::istream& json_stream) {
    const auto parsed_conf = json::parse(json_stream);
    return create(ctx, parsed_conf);
}

// Parses the JSON text in `json_str` and builds the sampler it describes via
// the json-based create() overload.
std::unique_ptr<Sampler> Sampler::create(Context& ctx, const std::string& json_str) {
    const auto parsed_conf = json::parse(json_str);
    return create(ctx, parsed_conf);
}

// Returns the type names of all registered samplers, in the registry's
// (unordered) iteration order. The result is empty when nothing has been
// registered yet.
std::vector<std::string> Sampler::list() {
    std::vector<std::string> names;

    if (!registry) return names;

    names.reserve(registry->size());
    // Iterate by const reference: taking each entry by value would copy the
    // key string and the factory callable on every step.
    for (const auto& entry : *registry) {
        names.push_back(entry.first);
    }
    return names;
}

} // namespace qualla
