//==============================================================================
//
//  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
//  All Rights Reserved.
//  Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

#ifndef QUALLA_DETAIL_BASIC_SAMPLER_HPP
#define QUALLA_DETAIL_BASIC_SAMPLER_HPP

#include <cstdint>
#include <memory>
#include <random>
#include <span>
#include <string>
#include <vector>

#include <qualla/detail/json.hpp>
#include <qualla/context.hpp>

namespace qualla {

class BasicSampler : public Sampler {
  public:
    BasicSampler(Context& ctx, const json& conf);

    virtual int32_t process(std::span<const float> logits) override;
    virtual int32_t process(
            std::span<const float> logits,
            std::vector<float>&    probs_out,
            bool                   tok_out
    ) override;

    virtual std::vector<int32_t> process_multiple(
            std::span<const float>& logits,
            std::vector<float>&     probs,
            int32_t                 num_return
    ) override;

    virtual bool save(const std::string& name) override;
    virtual bool restore(const std::string& name) override;
    virtual void reset() override;
    virtual void applyConfig(const qualla::json& conf) override;

  protected:
    int32_t _process(std::span<const float> logits, std::vector<float>* probs_out, bool samp_tok);
};

} // namespace qualla

#endif // QUALLA_DETAIL_BASIC_SAMPLER_HPP
