//==============================================================================
//
//  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
//  All rights reserved.
//  Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

#include <qualla/dialog.hpp>
#include <qualla/logger.hpp>
#include <qualla/detail/config.hpp>
#include <qualla/detail/timer.hpp>
#include <qualla/detail/sampler-utils.hpp>

#include <algorithm>
#include <exception>
#include <filesystem>
#include <fstream>
#include <functional>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#include <fmt/format.h>
#include <fmt/ranges.h>

// Logging helpers used throughout this file. Each expands to a call on
// `_env->logger()`, so they can only be used where a `_env` member/variable is
// in scope.
// NOTE(review): identifiers beginning with a double underscore are reserved for
// the implementation in C++; consider renaming if this ever moves to a header.
//
// Eagerly formatted: the message string is built with fmt::format at the call
// site, before it is handed to the logger.
#define __INFO(__fmt, ...)  _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
#define __WARN(__fmt, ...)  _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
#define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
// Lazily formatted: the logger receives a by-reference lambda that builds the
// message, presumably so formatting is skipped when the level is disabled
// (confirm in Logger::post). Because of the [&] capture, arguments must still be
// alive when the logger invokes the lambda.
#define __KPIS(__fmt, ...)                                                                         \
    _env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
#define __DEBUG(__fmt, ...)                                                                        \
    _env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
#define __TRACE(__fmt, ...)                                                                        \
    _env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })

namespace fs = std::filesystem;

namespace qualla {

// Build a dialog from its JSON configuration: GPIO marker, context, prompt
// template, tokenizer, sampler(s) and engine(s). "sampler" and "engine" may each
// be a single object or an array of objects, keyed by their optional "role"
// (default "primary").
// Throws std::runtime_error if no engine is configured or an engine cannot
// output logits.
Dialog::Dialog(std::shared_ptr<Env> env, const std::string& name, const qualla::json& json)
    : _env(env) {
    Timer start;

    __DEBUG("dialog-new: {} config {}", name, json.dump());

    using qc = qualla::Config;

    // Optional config sections are held by value. This way correctness does not
    // depend on whether qc::optional returns by value or by reference; a
    // reference to the `{}` default argument would dangle.

    // Create Gpiomarker and reset the gpio status to low
    const qualla::json gpio_conf = qc::optional<qualla::json>(json, "gpio", {});
    _gpio_marker                 = GpioMarker::create(gpio_conf);
    _gpio_marker->set();

    // Create the context first
    _ctx = Context::create(*_env, name, qc::mandatory<qualla::json>(json, "context"));

    // Parse prompt config
    const qualla::json pmt_conf = qc::optional<qualla::json>(json, "prompt", {});
    _prompt_type = qc::optional<std::string>(pmt_conf, "type", "llama2");
    _sys_tags    = qc::optional<std::vector<std::string>>(pmt_conf, "sys-tags", {"", ""});
    _inst_tags   = qc::optional<std::vector<std::string>>(pmt_conf, "inst-tags", {"", ""});
    _role_tags   = qc::optional<std::vector<std::string>>(pmt_conf, "role-tags", {"", ""});
    _sys_prompt  = qc::optional<std::string>(pmt_conf, "sys-prompt", "");

    const std::vector<std::string> stop_sequence =
            qc::optional<std::vector<std::string>>(pmt_conf, "stop-sequence", {});
    _stop_sequence = SequenceMatchTrie(stop_sequence);

    // Create Tokenizer
    // TODO: auto-detect / validate n_vocab with tokenizer vocab
    fs::path tok_path = _env->path().models / qc::mandatory<std::string>(json, "tokenizer");
    _tokenizer        = Tokenizer::create(*_ctx, tok_path);

    // Create Sampler(s); a later entry with the same role replaces an earlier one.
    auto add_sampler = [&](const qualla::json& j) {
        std::string role = qc::optional<std::string>(j, "role", "primary");
        _sampler[role]   = Sampler::create(*_ctx, j);
    };

    const qualla::json& sam_conf = qc::mandatory<qualla::json>(json, "sampler");
    if (sam_conf.is_array()) {
        for (const auto& sc : sam_conf) {
            add_sampler(sc);
        }
    } else {
        add_sampler(sam_conf);
    }

    // Create Engine(s); each one must be able to produce logits.
    auto add_engine = [&](const qualla::json& j) {
        std::string role = qc::optional<std::string>(j, "role", "primary");
        _engine[role]    = Engine::create(*_ctx, j);

        using FF = Engine::Feature::Flags;
        if (!_engine[role]->supports(FF::OUTPUT_LOGITS))
            throw std::runtime_error("the engine must output Logits");
    };

    const qualla::json& eng_conf = qc::mandatory<qualla::json>(json, "engine");
    if (eng_conf.is_array()) {
        for (const auto& ec : eng_conf) {
            add_engine(ec);
        }
    } else {
        add_engine(eng_conf);
    }

    // An empty "engine" array would leave nothing to read the input type from
    // (and nothing to run); dereferencing begin() would be undefined behavior.
    if (_engine.empty())
        throw std::runtime_error("dialog: at least one engine must be configured");

    // Store input type (token, embedding, etc) from the engine.
    // This assumes multi-engine usecases use matching input types.
    m_inputType = _engine.begin()->second->getInputType();

    _kpis.init.update(start.elapsed_usec());
}

// Members (engines, samplers, tokenizer, context, ...) release themselves.
Dialog::~Dialog() = default;

// No-op callbacks used when a query's output should be discarded. Each one
// matches a different callback signature, ignores its arguments, and always
// returns false.
[[maybe_unused]] static bool __no_response_query(const std::string& /*text*/, Sentence::Code /*code*/) {
    return false;
}

static bool __no_response_token(const int32_t* /*tokens*/, const uint32_t /*count*/, Sentence::Code /*code*/) {
    return false;
}

static bool __no_response(const std::string& /*text*/, Sentence::Code /*code*/) {
    return false;
}

void Dialog::getTopK(std::vector<float>& logits, std::vector<std::vector<int32_t>>& tokens, size_t topK, float pThreshold, Dialog::Callback callback) {

    auto& sampler = *_sampler["primary"];

    // Sample top-k logits but with a minimum probability threshold
#if defined(__GNUC__) && !defined(__clang__)
    std::span<float> indexed_logits_span(logits);
    IndexedLogits indexed_logits(indexed_logits_span, sampler.rng());
#else
    IndexedLogits indexed_logits(std::span{logits.data(),logits.size()}, sampler.rng());
#endif
    indexed_logits.softmax();
    indexed_logits.topK(topK);

    for (int i = 0; i < topK; i++) {

        _last_tok = indexed_logits.indices[i];

        // Only sample tokens above some probability threshold
        // TODO: Modify sampling algorithm as necessary
        if (indexed_logits.probs[i] < pThreshold) {
            break;
        } else if (_ctx->is_eos(_last_tok)) {
            callback("", Sentence::CONTINUE);
        } else {
            tokens.push_back({_last_tok});
        }
    }
}

bool Dialog::query(const std::string& str, Sentence::Code scode, Dialog::Callback callback) {
    std::vector<int32_t> p_vec; // prompt tokens
    std::string          p_str; // prompt string

    p_vec.reserve(1024);

    if (scode == Sentence::COMPLETE || scode == Sentence::BEGIN) {
        // Reset prompt/gen counts for new query
        _n_prompt    = 0;
        _n_generated = 0;
        _n_previous_prompt    = 0;
        _n_previous_generated = 0;


        if (_last_tok >= 0 && !_ctx->is_eos(_last_tok)) p_vec.push_back(_last_tok);

        p_str = _inst_tags[0];

        if (!_n_queries) {
            // First query. Prepend sys-prompt.
            p_str += _sys_tags[0] + _sys_prompt + _sys_tags[1];
        } else {
            // Add EOS explicitly if the last query was aborted prematurely.
            if (_ctx->eos_tok() >= 0) p_vec.push_back(_ctx->eos_tok());
        }

        // Add BOS
        if (_ctx->bos_tok() >= 0) {
            p_vec.push_back(_ctx->bos_tok());
        }
    }

    // FIXME: make this more generic
    if (_prompt_type == "llama3") {
        p_str += _sys_tags[0] + _role_tags[1] + _sys_tags[1] + str + _inst_tags[2];
    } else {
        p_str += str;
    }

    if (scode == Sentence::COMPLETE || scode == Sentence::END) {
        if (_prompt_type == "llama3") {
            p_str += _sys_tags[0] + _role_tags[2] + _sys_tags[1];
        } else {
            p_str += _inst_tags[1];
        }
    }

    _env->logger().post(Logger::DEBUG, [&]() {
      qualla::json j{{"string", str}, {"prompt", p_str}};
      return fmt::format("dialog-query: {} {}", _ctx->name(), j.dump());
    });

    _n_queries++;

    _tokenizer->encode(p_str, p_vec);

    __DEBUG("dialog-tokens: {} {}", _ctx->name(), p_vec);
    __DEBUG("dialog-text: \"{}\"", p_str);

    if (scode == Sentence::COMPLETE || scode == Sentence::END) {
        // Detect stop sequences here
        if (!_stop_sequence.empty()) {
            _stop_sequence.reset();
            return process(p_vec, [&](const std::string& str, Sentence::Code c) {
              // Check for stop sequence and end inference when stop sequence is found
              if (_stop_sequence.process_next_string(str)) {
                callback(str, c); // Emit sequences until match is complete
                return false;
              }

              // Else, return normal callback function
              return callback(str, c);
            });
        }

        return process(p_vec, callback);
    }

    return process(p_vec, __no_response);
}

// Submit (part of) a query given as pre-tokenized input.
// On BEGIN/COMPLETE the counters are reset and the prompt is prefixed with:
// - the previous last token, if any;
// - an EOS if the previous query did not end on one;
// - BOS.
// On END/COMPLETE the tokens are processed with `callback`; otherwise they are
// processed with a token callback that discards output.
bool Dialog::query(const std::vector<uint32_t>& input, Sentence::Code scode, qualla::DialogCallback& callback) {
    std::vector<int32_t> p_vec; // prompt tokens
    p_vec.reserve(1024);

    if (scode == Sentence::COMPLETE || scode == Sentence::BEGIN) {
        // Reset prompt/gen counts for new query
        _n_prompt             = 0;
        _n_generated          = 0;
        _n_previous_prompt    = 0;
        _n_previous_generated = 0;

        if (_last_tok >= 0)
            p_vec.push_back(_last_tok);

        // Add EOS explicitly if the last query was aborted prematurely.
        // Like the text overload, only do so when the context actually defines an
        // EOS token; otherwise a negative id would be fed to the engine.
        const auto eos = _ctx->eos_tok();
        if (_n_queries && eos >= 0 && _last_tok != eos) {
            p_vec.push_back(eos);
        }
        // Add BOS
        if (_ctx->bos_tok() >= 0) {
            p_vec.push_back(_ctx->bos_tok());
        }
    }

    p_vec.insert(p_vec.end(), input.begin(), input.end());
    __DEBUG("dialog-tokens: {} {}", _ctx->name(), p_vec);

    _n_queries++;

    if (scode == Sentence::COMPLETE || scode == Sentence::END) {
        return process(p_vec, callback);
    }

    // Intermediate chunk: process the tokens but swallow any output.
    DialogCallback callback_return_token(QUALLA_CALLBACK_TYPE_TOKEN);
    *(callback_return_token.getTokenCbFunc()) = __no_response_token;
    return process(p_vec, callback_return_token);
}

// Submit (part of) a query given as raw embedding bytes.
// Only the closing chunk (COMPLETE or END) reports output through `callback`.
// Earlier chunks are processed with a callback that ignores its input and
// always returns false.
bool Dialog::query(
        std::vector<uint8_t>& embedding_vectors,
        Sentence::Code        scode,
        T2ECallback           t2eCallback,
        Dialog::Callback      callback
) {
    ++_n_queries;

    const bool emit_output = scode == Sentence::END || scode == Sentence::COMPLETE;
    if (!emit_output) {
        // Only process, no output
        return process(embedding_vectors, t2eCallback, [](const std::string&, Sentence::Code) {
            return false;
        });
    }
    return process(embedding_vectors, t2eCallback, callback);
}

// Run `str` through the dialog as one complete query, discarding its output.
// Afterwards the last token is forced to EOS so the primer is self-contained.
// Returns the result of the underlying query.
bool Dialog::prime(const std::string& str) {
    const bool primed = query(str, Sentence::COMPLETE, __no_response);
    _last_tok         = _ctx->eos_tok();
    return primed;
}

bool Dialog::save(const std::string& o_name) {
    Timer start;

    // Save using session name unless override is provided
    std::string name      = o_name.empty() ? _ctx->name() : o_name;
    fs::path    save_path = name;

    if (!_n_past) {
        __ERROR("dialog-save: {} : nothing to save yet", name);
        return false;
    }

    __INFO("dialog-save: saving as {} {}", name, save_path.string());

    if (!fs::exists(save_path) && !fs::create_directories(save_path)) {
        __ERROR("dialog-save: {} : failed to create cache directory", name);
        return false;
    }

    // Save Dialog state
    qualla::json j{
            {"n-past", _n_past},
            {"n-prompt", _n_prompt},
            {"n-generated", _n_generated},
            {"n-queries", _n_queries},
            {"last-tok", _last_tok}
    };
    {
        fs::path      p = save_path / "dialog.json";
        std::ofstream f(p);
        f << j;
    }

    // Save Engines (mandatory)
    for (auto& e : _engine) {
        if (!e.second->save(name)) {
            __ERROR("dialog-save: {} : unable to save {} engine", name, e.first);
            return false;
        }
    }

    // Save Samplers (optional)
    for (auto& s : _sampler) {
        if (!s.second->save(name)) {
            __WARN("dialog-save: {} : unable to save {} sampler", name, s.first);
        }
    }

    _kpis.save.update(start.elapsed_usec());

    return true;
}

bool Dialog::restore(const std::string& o_name) {
    Timer start;

    // Restore using session name unless override is provided
    std::string name      = o_name.empty() ? _ctx->name() : o_name;
    fs::path    restore_path = name;

    __INFO("dialog-restore: restoring from {} {}", name, restore_path.string());

    // Try to restore the Dialog state (optional)
    // If this fails we reset everything and try to restore the engine.
    qualla::json j{};
    {
        fs::path p = restore_path / "dialog.json";
        if (fs::exists(p)) {
            std::ifstream f(p);
            j = qualla::json::parse(f);
        } else {
            __DEBUG("dialog-restore: {} : internal state not restored", name);
        }
    }

    using qc     = qualla::Config;
    _n_past      = qc::optional<uint32_t>(j, "n-past", 0);
    _n_prompt    = qc::optional<uint32_t>(j, "n-prompt", 0);
    _n_generated = qc::optional<uint32_t>(j, "n-generated", 0);
    _n_queries   = qc::optional<uint32_t>(j, "n-queries", 1);
    _last_tok    = qc::optional<int32_t>(j, "last-tok", _ctx->eos_tok());

    // Restore Engines (mandatory)
    for (auto& e : _engine) {
        uint32_t n = e.second->restore(name);
        if (!n) {
            __ERROR("dialog-restore: {} : unable to restore {} engine", name, e.first);
            return false;
        }

        // Restore n_past from the engine state
        if (_n_past && n != _n_past) {
            __WARN("dialog-restore: {} : n-past mismatch : {} engine {} intern {}",
                   name,
                   e.first,
                   _n_past,
                   n);
            // Keep the smaller number
            _n_past = std::min(n, _n_past);
        } else
            _n_past = n;
    }

    // Restore Samplers (optional)
    for (auto& s : _sampler) {
        if (!s.second->restore(name)) {
            __WARN("dialog-restore: {} : unable to restore {} sampler", name, s.first);
        }
    }

    _kpis.reset();
    _kpis.restore.update(start.elapsed_usec());

    return true;
}

void Dialog::reset() {
    __INFO("dialog-reset: {}", _ctx->name());

    _n_past      = 0;
    _n_prompt    = 0;
    _n_generated = 0;
    _n_queries   = 0;
    _last_tok    = -1;
    _n_previous_prompt    = 0;
    _n_previous_generated = 0;

    _kpis.reset();

    // Reset Engines and Samplers
    for (auto& e : _engine)
        e.second->reset();
    for (auto& s : _sampler)
        s.second->reset();

    State::clear();
}

// Dialog KPIs helpers

// Get latest KPIs
// Refresh the tokens-per-second figures from the latest prompt/generate timings
// and return the KPI record. Rates are only updated when the corresponding
// token count is non-zero.
Dialog::KPIs& Dialog::kpis() {
    // Per-token time is computed in floating point. Dividing the raw counters
    // first truncated the result; it could even reach 0 and collapse the rate to
    // the fallback below. A zero time still falls back to 1e6 us/token (1 tok/s)
    // to avoid dividing by zero.
    auto tokens_per_sec = [](double usec, double n_tokens) {
        const double per_token = usec / n_tokens;
        return 1000000.0 / (per_token ? per_token : 1000000.0);
    };

    // Update TPS
    if (_n_prompt) {
        _kpis.tps.n_prompt = _n_prompt;
        _kpis.tps.prompt   = tokens_per_sec(_kpis.prompt.last_usec, _n_prompt);
    }

    if (_n_generated) {
        _kpis.tps.n_generate = _n_generated;
        _kpis.tps.generate   = tokens_per_sec(_kpis.generate.last_usec, _n_generated);
    }

    // We could synthesize more KPIs from from other layers (engine, sampler, etc)
    return _kpis;
}

std::string Dialog::KPIs::dump(std::string_view sep) const {
    return fmt::format(
            "init:[{}]{}prompt:[{}]{}generate:[{}]{}save:[{}]{}restore:[{}]{} tps-prompt:{:.2f} tps-generate:{:.2f}",
            init.dump(),
            sep,
            prompt.dump(),
            sep,
            generate.dump(),
            sep,
            save.dump(),
            sep,
            restore.dump(),
            sep,
            tps.prompt,
            tps.generate
    );
}

// Clear every timing section and the derived tokens-per-second figures.
void Dialog::KPIs::reset() {
    init.reset();
    prompt.reset();
    generate.reset();
    save.reset();
    restore.reset();
    tps.prompt   = 0.0;
    tps.generate = 0.0;
    // Token counts are only refreshed by Dialog::kpis() when non-zero, so clear
    // them here too; otherwise stale counts sit next to zeroed rates.
    tps.n_prompt   = 0;
    tps.n_generate = 0;
}

// Create API

// Dialog registry : type string + creator function
using Registry = std::unordered_map<std::string, Dialog::Creator>;
static std::unique_ptr<Registry> registry;

void Dialog::__register(const std::string& type, Creator func) {
    if (!registry) registry = std::make_unique<Registry>();

    Registry& r = *registry;


    r[type]     = func;
}

// Instantiate the dialog type named by conf["type"] (default "basic") via its
// registered creator.
// Throws std::runtime_error("<type>: dialog not found") if no such type is
// registered.
std::unique_ptr<Dialog> Dialog::create(
        std::shared_ptr<Env> env,
        const std::string&   name,
        const qualla::json&  conf
) {
    using qc         = qualla::Config;
    std::string type = qc::optional<std::string>(conf, "type", "basic");

    if (!registry) throw std::runtime_error(type + ": dialog not found");

    // One lookup serves both the existence check and the call.
    const auto it = registry->find(type);
    if (it == registry->end()) throw std::runtime_error(type + ": dialog not found");

    return std::unique_ptr<Dialog>(it->second(env, name, conf));
}

std::unique_ptr<Dialog> Dialog::create(
        std::shared_ptr<Env> env,
        const std::string&   name,
        std::istream&        json_stream
) {

    return create(env, name, json::parse(json_stream));
}

// Load a JSON configuration file and delegate to the stream overload of create().
// Throws std::runtime_error if the file does not exist or cannot be opened.
std::unique_ptr<Dialog> Dialog::create(
        std::shared_ptr<Env> env,
        const std::string&   name,
        const fs::path&      json_path
) {
    if (!fs::exists(json_path))
        throw std::runtime_error(json_path.string() + ": file does not exist");

    std::ifstream ifs(json_path);
    // An existing but unreadable path (permissions, a directory, ...) would
    // otherwise surface as a confusing JSON parse error.
    if (!ifs)
        throw std::runtime_error(json_path.string() + ": unable to open file");

    return create(env, name, ifs);
}

std::vector<std::string> Dialog::list() {
    std::vector<std::string> v;
    if (!registry) return v;

    Registry& r = *registry;

    for (auto k : r)
        v.push_back(k.first);
    v.push_back("basic"); // default type, always registered
    return v;
}

bool Dialog::applyLoraAdapter(std::string lora_adapter_name, std::string engine_role) {
    auto& engine = *_engine[engine_role];
    if (!engine.applyLoraAdapter(lora_adapter_name)) {
        __WARN("dialog-applyLoraAdapter: failed for {}", lora_adapter_name);
        return false;
    }
    return true;
}
bool Dialog::applyLoraStrength(std::string tensor_name, float tensor_val, std::string engine_role) {
    auto& engine = *_engine[engine_role];
    if (!engine.applyLoraStrength(tensor_name, tensor_val)) {
        __WARN("dialog-applyLoraStrength: failed for {}", tensor_name);
        return false;
    }
    return true;
}

} // namespace qualla
