//==============================================================================
//
//  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
//  All Rights Reserved.
//  Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

#ifndef QUALLA_DETAIL_SAMPLER_UTILS_HPP
#define QUALLA_DETAIL_SAMPLER_UTILS_HPP

#ifdef _MSC_VER
    #pragma warning(disable : 4068)
#endif

#include <qualla/detail/preproc.hpp>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <iterator>
#include <limits>
#include <numeric>
#include <random>
#include <span>
#include <string>
#include <type_traits>
#include <vector>

namespace qualla {

// Random engine type shared by every sampling helper in this header.
using rng_t = std::mt19937;

// Various sampling utilities.

// Draws a double in [0, 1) with 53 random bits.
//
// Two consecutive engine outputs are consumed: the first is shifted down to
// 27 bits, the second to 26 bits, and the pair is combined as
// (hi * 2^26 + lo) / 2^53.
static double sampleFromUniform(rng_t& rng) {
    constexpr double kTwoPow26 = 67108864.0;
    constexpr double kTwoPow53 = 9007199254740992.0;

    // The draw order matters: the high part is taken first, then the low part.
    const int hi = static_cast<int>(rng() >> 5);
    const int lo = static_cast<int>(rng() >> 6);

    return (hi * kTwoPow26 + lo) / kTwoPow53;
}

// Draws one sample of Gumbel noise, computed as -log(-log(u)) from a uniform
// draw u.
//
// The uniform draw is first mapped into [tiny, 1 - eps) so that neither
// logarithm is ever evaluated at 0 or at 1.
static double sampleFromGumbel(rng_t& rng) {
    constexpr double kTiny = 1.1754943508222875e-38;
    constexpr double kEps  = 1.1920928955078125e-07;

    const double u         = sampleFromUniform(rng);
    const double shrunk    = kTiny + u * (1. - kEps - kTiny);
    const double inner_log = -std::log(shrunk);

    return -std::log(inner_log);
}

// Picks an index at random, where index i is chosen with a likelihood
// proportional to probs[i]. The weights are handed to
// std::discrete_distribution unchanged.
template <typename T>
static int32_t sampleFromProbs(const std::span<T> probs, rng_t& rng) {
    static_assert(std::is_floating_point_v<T>);

    std::discrete_distribution<> weighted(probs.begin(), probs.end());
    const auto                   chosen = weighted(rng);
    return chosen;
}

// Returns the index of the element chosen by the Gumbel-max algorithm:
// independent Gumbel noise is added to every log-probability and the index of
// the largest perturbed value is returned.
//
// One Gumbel sample is drawn per element, in index order. Ties keep the
// earliest index. Returns 0 for an empty span, or when no perturbed value
// compares greater than -infinity (e.g. all inputs are -infinity or NaN).
template <typename T>
static int32_t sampleUsingGumbelMax(const std::span<T> log_probs, rng_t& rng) {
    static_assert(std::is_floating_point<T>::value);

    // The span may be over const elements; the accumulator must be mutable.
    using value_t = std::remove_cv_t<T>;

    // The running maximum has to be a floating-point value. Storing it in an
    // integer truncates every perturbed value toward zero, which makes a
    // negative value such as -2.7 (stored as -2) beat a later -2.3.
    value_t max_perturbed = -std::numeric_limits<value_t>::infinity();
    int32_t max_idx       = 0;

    for (size_t i = 0; i < log_probs.size(); ++i) {
        const value_t perturbed = static_cast<value_t>(log_probs[i] + sampleFromGumbel(rng));
        if (perturbed > max_perturbed) {
            max_perturbed = perturbed;
            max_idx       = static_cast<int32_t>(i);
        }
    }
    return max_idx;
}

// Perturbs every entry of log_probs in place by adding an independent Gumbel
// sample to it. Samples are drawn in element order, one per entry.
template <typename T>
void addGumbelNoise(std::vector<T>& log_probs, rng_t& rng) {
    static_assert(std::is_floating_point_v<T>);

    for (T& entry : log_probs) {
        const double noise = sampleFromGumbel(rng);
        entry              = static_cast<T>(entry + noise);
    }
}

// Returns the position of the largest element of probs. When several elements
// share the maximum the first one wins; an empty span yields 0.
template <typename T>
static int32_t argmax(const std::span<T> probs) {
    static_assert(std::is_floating_point_v<T>);

    const auto first = probs.begin();
    const auto best  = std::max_element(first, probs.end());

    return static_cast<int32_t>(best - first);
}

// A view over a buffer of logits that also tracks, for every position, the
// index that value had in the original input.
//
// `logits`, `probs` and `indices` are kept the same length. After sort(),
// topK() or topP() the leading entries are stored in descending-logit order and
// `indices[i]` maps position i back to the original index.
//
// NOTE: sort() writes the reordered values back through `logits.data()`, so the
// caller's buffer is modified even though the span is declared const. The
// buffer must therefore be writable, and (the span being non-owning) must
// outlive this object.
struct IndexedLogits {
    std::mt19937&          rng;          // engine used by the sampling methods
    std::span<const float> logits;       // current (possibly truncated) view of the logits
    std::vector<float>     probs;        // softmax() / logSoftmax() output, same layout as logits
    std::vector<int32_t>   indices;      // original index of the value at each position
    bool                   probs_valid;  // set by softmax(), logSoftmax(), sampleGreedyUnsorted()
    bool                   sorted;       // set once sort() has reordered the view

    IndexedLogits(std::span<const float> logits, std::mt19937& r)
        : rng(r), logits(logits), probs(logits.size(), 0.f), indices(logits.size()),
          probs_valid(false), sorted(false) {
        std::iota(indices.begin(), indices.end(), 0);
    }

    size_t size(void) const { return logits.size(); }

    // Largest logit in the current view. Requires a non-empty view.
    // Once sorted, the maximum is the first entry, so no scan is needed.
    float maxLogit() const {
        return sorted ? logits[0] : *std::max_element(logits.begin(), logits.end());
    }

    // Moves the k largest logits to the front in descending order; k == 0 (or
    // k larger than the view) requests a full sort. Returns the clamped k.
    //
    // Only the first k positions of logits/probs/indices are meaningful after a
    // partial sort, so a partial sort is expected to be followed by truncation
    // to k entries (as topK() does).
    //
    // Calling sort() again on an already sorted view is a no-op. After the
    // first sort the values are stored by rank while `indices` holds original
    // positions, so re-running the comparator would look up the wrong values
    // (and, after truncation, positions outside the view).
    size_t sort(size_t k = 0) {
        const size_t logits_size = logits.size();
        k                        = (k == 0) ? logits_size : std::min(k, logits_size);

        if (sorted) return k;

        std::partial_sort(
                indices.begin(),
                indices.begin() + k,
                indices.end(),
                [this](int32_t a, int32_t b) { return logits[a] > logits[b]; }
        );

        // FIXME: avoid overwriting input logits
        // Gather into a scratch buffer first: the destination overlaps the
        // source, so the values cannot be permuted in place.
        std::vector<float> gathered(k);

        for (size_t i = 0; i < k; ++i) {
            gathered[i] = logits[indices[i]];
        }
        std::copy(gathered.begin(), gathered.end(), const_cast<float*>(logits.data()));

        if (probs_valid) {
            // Keep already computed probabilities aligned with their logits.
            for (size_t i = 0; i < k; ++i) {
                gathered[i] = probs[indices[i]];
            }
            std::copy(gathered.begin(), gathered.end(), probs.begin());
        }

        sorted = true;
        return k;
    }

    // Writes softmax(logits / temp) into probs and marks probs as valid.
    // The maximum scaled logit is subtracted before exponentiation.
    // An empty view is left untouched.
    void softmax(float temp = 1.f) {
        QUALLA_ASSERT(temp > 0.f);
        if (logits.empty()) return;

        const float max_scaled = maxLogit() / temp;
        float       sum_exp    = 0.0f;

#pragma clang loop vectorize(enable)
        for (size_t i = 0; i < logits.size(); i++) {
            float p  = std::exp((logits[i] / temp) - max_scaled);
            probs[i] = p;
            sum_exp += p;
        }

#pragma clang loop vectorize(enable)
        for (size_t i = 0; i < logits.size(); i++) {
            probs[i] /= sum_exp;
        }

        probs_valid = true;
    }

    // Writes log_softmax(logits / temp) into probs and marks probs as valid.
    // An empty view is left untouched.
    //
    // log(e^x / sum(e^x)) -> log(e^x) - log(sum(e^x))
    // The probs vector is reused even though it then holds log-probabilities;
    // topP() and sampleFromProbs() read probs as plain probabilities, so they
    // should not be combined with this method.
    void logSoftmax(float temp = 1.f) {
        QUALLA_ASSERT(temp > 0.f);
        if (logits.empty()) return;

        const float max_scaled = maxLogit() / temp;
        float       sum_exp    = 0.0f;

#pragma clang loop vectorize(enable)
        for (size_t i = 0; i < logits.size(); i++) {
            float p  = (logits[i] / temp) - max_scaled;
            probs[i] = p;
            sum_exp += std::exp(p);
        }

        const float log_sum_exp = std::log(sum_exp);
#pragma clang loop vectorize(enable)
        for (size_t i = 0; i < logits.size(); i++) {
            probs[i] -= log_sum_exp;
        }

        probs_valid = true;
    }

    // Keeps only the k largest logits (all of them when k exceeds the view).
    void topK(int32_t k) {
        QUALLA_ASSERT(k > 0);
        const size_t n_keep = this->sort(static_cast<size_t>(k));

        logits = logits.first(n_keep);
        probs.resize(n_keep);
        indices.resize(n_keep);
    }

    // Performs top-p (nucleus) truncation in place: drops the lowest-ranked
    // entries for as long as their accumulated probability stays <= 1 - p.
    // Sorts and/or runs softmax() (temperature 1) first if that has not
    // happened yet. At least one entry, and at least min_keep entries where the
    // view has that many, are kept.
    //
    // probs_valid is cleared afterwards because the surviving probabilities no
    // longer sum to 1; call softmax() again before sampling from them.
    void topP(float p, int32_t min_keep = 1) {
        if (p >= 1) return;
        if (logits.empty()) return;

        if (!sorted) this->sort();
        if (!probs_valid) this->softmax();

        // Accumulate probability mass from the tail. Index 0 is never
        // examined, so the top entry always survives.
        float  cum_sum   = 0.0f;
        size_t n_to_trim = 0;

        for (size_t i = logits.size() - 1; i > 0; --i) {
            cum_sum += probs[i];
            if (cum_sum <= 1.0 - p) {
                n_to_trim++;
            } else {
                break;
            }
        }

        // Clamp min_keep into [0, size()]: a negative value or one larger than
        // the view would otherwise make first() reach outside the span.
        const size_t min_remain =
                std::min(static_cast<size_t>(std::max<int32_t>(min_keep, 0)), logits.size());
        const size_t n_remain = std::max(logits.size() - n_to_trim, min_remain);

        logits = logits.first(n_remain);
        probs.resize(n_remain);
        indices.resize(n_remain);

        // The probabilities no longer add up to 1.
        probs_valid = false;
    }

    // Greedy sampling: returns the position of the largest logit and turns
    // probs into a one-hot vector at that position.
    //
    // The returned value is a position in the current view and is not mapped
    // through `indices`, so it is only the original index while the view is
    // still unsorted. Requires a non-empty view.
    int32_t sampleGreedyUnsorted() {
        const auto   best = std::max_element(logits.begin(), logits.end());
        const size_t id   = std::distance(logits.begin(), best);

        std::fill(probs.begin(), probs.end(), 0.f);
        probs[id] = 1.f;

        probs_valid = true;
        return int32_t(id);
    }

    // Samples a position according to probs and returns its original index.
    int32_t sampleFromProbs() {
        QUALLA_ASSERT(probs_valid);
        int32_t idx = qualla::sampleFromProbs<float>(std::span{probs.data(), probs.size()}, rng);
        return int32_t(indices[idx]);
    }

    // Samples a position with the Gumbel-max algorithm and returns its
    // original index.
    int32_t sampleUsingGumbelMax() {
        QUALLA_ASSERT(probs_valid);
        // probs here must be log-probabilities
        int32_t idx =
                qualla::sampleUsingGumbelMax<float>(std::span{probs.data(), probs.size()}, rng);
        return int32_t(indices[idx]);
    }

    // Adds Gumbel noise to every entry of probs in place. Always returns true.
    bool addGumbelNoise() {
        // probs here must be log-probabilities
        qualla::addGumbelNoise<float>(probs, rng);
        return true;
    }
};

} // namespace qualla

#endif // QUALLA_DETAIL_SAMPLER_UTILS_HPP
