//==============================================================================
//
//  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
//  All Rights Reserved.
//  Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

#include "qualla/detail/timer.hpp"

#include "nsp-model.hpp"
#include "nsp-graph.hpp"

#include <sstream>

#include "fmt/format.h"
#include "fmt/ranges.h"

// sched_yield() shim, copied from threadpool.cpp. POSIX declares sched_yield() in
// <sched.h>; Windows has no equivalent, so one is defined here on top of Sleep(0).
#if defined(_WIN32)
    // NOGDI keeps <windows.h> from pulling in the GDI headers; presumably this is to stop
    // wingdi.h's ERROR macro from clobbering Logger::ERROR used by the logging macros.
    // Guarded so that a build which already defines NOGDI (e.g. -DNOGDI) does not get a
    // macro-redefinition warning.
    #ifndef NOGDI
        #define NOGDI
    #endif
    #include "windows.h"

// Give up the remainder of the current time slice. Always returns 0, matching a
// successful POSIX sched_yield().
static int sched_yield(void) {
    Sleep(0);
    return 0;
}
#else
    #include <sched.h>
#endif

// Logging helpers. Each one expands to `_env.logger().post(...)`, so it can only be used
// where a member or variable named `_env` is in scope. `##__VA_ARGS__` (a GNU/MSVC
// extension) drops the trailing comma when a call passes no format arguments.
// NOTE(review): identifiers beginning with "__" are reserved for the implementation.
//
// INFO / WARN / ERROR: the message is formatted eagerly, before post() is called.
#define __INFO(__fmt, ...)  _env.logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
#define __WARN(__fmt, ...)  _env.logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
#define __ERROR(__fmt, ...) _env.logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
// KPIS / DEBUG / TRACE / KVTRACE: post() receives a lambda that formats on demand.
// Presumably this lets disabled levels skip the formatting cost (confirm in Logger::post).
// The lambda captures by reference, so it assumes post() invokes it before returning.
#define __KPIS(__fmt, ...)                                                                         \
    _env.logger().post(Logger::ENGINE_KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
#define __DEBUG(__fmt, ...)                                                                        \
    _env.logger().post(Logger::ENGINE_DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
#define __TRACE(__fmt, ...)                                                                        \
    _env.logger().post(Logger::ENGINE_TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
#define __KVTRACE(__fmt, ...)                                                                      \
    _env.logger().post(Logger::KVMANAGER_TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
namespace qualla {

// A GraphVariant wraps one QNN graph (one specific compiled model variant). On
// construction it records the graph info and context handle, and it builds name-keyed
// QnnUtils::Tensor specs (shape + quantization parameters) for every graph input
// (input_specs) and output (output_specs). It then derives n_tokens, the number of
// tokens the graph consumes per execution, via determineGraphInputSize().
// Throws std::runtime_error if a tensor shape cannot be read or n_tokens cannot be derived.
GraphVariant::GraphVariant(GraphInfo_t* g_info, Qnn_ContextHandle_t qnn_ctx, int32_t n_ctx, std::map<LayerType, std::string>& layerNames)
    : ctx_size(n_ctx), graph_name(g_info->graphName), graph_info(g_info), context_handle(qnn_ctx), m_layerNames(layerNames) {
    // The first pass handles the inputs and the second pass handles the outputs.
    for (const bool is_input : {true, false}) {
        const uint32_t count    = is_input ? graph_info->numInputTensors : graph_info->numOutputTensors;
        auto           wrappers = is_input ? graph_info->inputTensors : graph_info->outputTensors;
        auto&          specs    = is_input ? input_specs : output_specs;

        for (uint32_t i = 0; i != count; ++i) {
            TensorWrapper&    wrapper = wrappers[i];
            const std::string name    = QnnApi::getTensorName(wrapper);

            std::vector<size_t> shape;
            if (!QnnApi::getTensorShape(shape, wrapper)) {
                throw std::runtime_error("Couldn't get tensor shape : " + name);
            }

            // When no quantization parameters can be read, store a single (0, 0) entry.
            std::vector<QnnUtils::QuantParam> quant;
            const bool has_quant = QnnApi::getTensorQuantParams(&wrapper, quant);
            if (!has_quant) quant.emplace_back(0, 0);

            specs[name] = QnnUtils::Tensor(&wrapper, shape, quant);
        }
    }

    n_tokens = static_cast<int32_t>(determineGraphInputSize());
}

// Attempt to determine input size from purely graph IO and context size
// The easiest way is using input_ids. Else, attention_mask/position_ids can also be used
size_t GraphVariant::determineGraphInputSize() {
    QnnUtils::Tensor* tensor;
    if (m_layerNames[LayerType::INPUT] == "inputs_embeds") {
        if (!!(tensor = getInput(m_layerNames[LayerType::ATTN_MASK]))) return tensor->dims.getNumElements() / ctx_size;
    } else {
        if (!!(tensor = getInput(m_layerNames[LayerType::INPUT]))) return tensor->dims.getNumElements();
        // Use past_key_out tensor to find input size
        // The last dimension of past_key_out tensor will always be the input size
        for (auto& [tname, qtensor] : output_specs) {
            if (!tname.starts_with("past_key")) continue;
            return static_cast<size_t>(qtensor.dims.channel);
        }
    }
    throw std::runtime_error("Unexpected model. Couldn't determine m_num_tokens");
}

bool GraphVariant::refreshTensorQuantParams() {
    for (bool io : {true, false}) {
        uint32_t n_tensors = (io) ? graph_info->numInputTensors : graph_info->numOutputTensors;
        auto     tensor_wrappers = (io) ? graph_info->inputTensors : graph_info->outputTensors;
        auto&    tensor_specs    = (io) ? input_specs : output_specs;
        for (size_t tensor_idx = 0; tensor_idx < n_tensors; tensor_idx++) {

            TensorWrapper& tensor      = tensor_wrappers[tensor_idx];
            std::string    tensor_name = QnnApi::getTensorName(tensor);
            std::vector<QnnUtils::QuantParam> quantParams;
            if (!QnnApi::getTensorQuantParams(&tensor, quantParams)) {
                quantParams.emplace_back(0, 0);
            }
            tensor_specs[tensor_name].quantParam = quantParams;
        }
    }
    return true;
}

// Container for all graph variants of graph `idx` (the index shows up in lock traces).
// `qnnApi` and `ioTensor` are kept as non-owning pointers; the destructor does not free
// them. In threaded mode a mutex and a condition variable are heap-allocated for the
// waitForLock()/releaseLock() protocol.
QnnNspGraph::QnnNspGraph(
        int       idx,
        Env&      env,
        int32_t   n_ctx,
        QnnApi*   qnnApi,
        IOTensor* ioTensor,
        bool      threaded
)
    : _idx(idx), _env(env), ctx_size(n_ctx), g_qnn_api(qnnApi), g_buffer_mgr(ioTensor),
      _threaded(threaded) {
    // These are freed in the destructor, or early in registerKVManager() when the
    // KV manager reports no KV tensors.
    if (threaded) {
        _lock    = new std::mutex{};
        _lock_cv = new std::condition_variable{};
    }

    __DEBUG("qnn-htp: new-NSP-graph : n_ctx {}", n_ctx);
}

// Deletes the KV manager passed to registerKVManager(). If the graph is still in threaded
// mode, it also deletes the mutex and condition variable allocated by the constructor.
QnnNspGraph::~QnnNspGraph() {
    __DEBUG("qnn-htp: del-NSP-graph");
    delete kvmanager; // deleting a null pointer is a no-op
    if (!_threaded) return;
    delete _lock;
    delete _lock_cv;
}

// Register a graph variant, keyed by the number of tokens it consumes per execution
// (graph_spec->n_tokens). Each token count may be registered only once.
// Returns true on success. Returns false (after logging) for a null graph_spec.
// Throws std::runtime_error if a variant with the same n_tokens is already registered.
// NOTE(review): only the raw pointer is stored; presumably the caller keeps it alive.
bool QnnNspGraph::addGraph(GraphVariant* graph_spec) {
    if (graph_spec == nullptr) {
        __ERROR("qnn-htp: addGraph called with a null graph variant");
        return false;
    }

    const int32_t n_tok = graph_spec->n_tokens;
    if (const auto existing = variants.find(n_tok); existing != variants.end()) {
        printAvailableConfigs();
        // The map key *is* n_tokens, so both counts are always equal. Report which
        // graphs collided instead.
        __ERROR("qnn-htp: addGraph detected duplicate n_tokens={} : {} v {}",
                n_tok, graph_spec->graph_name, existing->second->graph_name);
        throw std::runtime_error("qnn-htp: duplicate graph found, likely overflow occured");
    }

    variants[n_tok] = graph_spec;
    return true;
}

// Log, at ENGINE_DEBUG level, the token counts (keys of `variants`) of all registered
// graph variants, e.g. "config = [1, 128]".
void QnnNspGraph::printAvailableConfigs() {
    std::stringstream config_stream;
    // Write the separator before every element except the first, so the list does not
    // end with a dangling ", ".
    const char* separator = "";
    for (const auto& [config, _] : variants) {
        config_stream << separator << config;
        separator = ", ";
    }

    __DEBUG("config = [{}]", config_stream.str());
}

// Debug helper: write the raw buffer of every input tensor (mode == true) or output
// tensor (mode == false) of `variant` to its own file. Only inferences 0..9 are dumped.
// The file name is "<_debug_path>/<graph_name>/<NNN>_<in|out>_<tensor name>", where NNN is
// n_inference zero-padded to three digits.
void QnnNspGraph::dumpTensors(GraphVariant* const variant, bool mode, int n_inference) const {
    // Cap the number of dumped inferences so long runs do not flood the disk.
    constexpr int kMaxDumpedInferences = 10;
    if (n_inference >= kMaxDumpedInferences) return;

    QnnUtils::TensorMap& tensor_specs = (mode) ? variant->input_specs : variant->output_specs;
    const char* const    direction    = (mode) ? "in" : "out";
    const std::string    prefix =
            fmt::format("{}/{}/{:03d}", _debug_path, variant->graph_name, n_inference);

    // Bind by reference so each tensor name and Tensor spec (which includes its
    // quant-param vector) is not copied on every iteration.
    for (auto& [tname, tspec] : tensor_specs) {
        const std::string fname  = fmt::format("{}_{}_{}", prefix, direction, tname);
        const auto        buffer = g_buffer_mgr->getBuffer(tspec.tensor);
        __TRACE("Dumping {} from {:p}", fname, buffer);
        QnnUtils::writeRawData(buffer, tspec.dims.getSize(), fname);
    }
}

// Pointer-shift KV$ update. Re-map the KV-cache tensors of one graph variant onto
// (allocation index, offset) locations inside the fused IO buffers via
// IOTensor::mapFusedBufferOffset (presumably instead of copying KV data).
// registerKVManager() installs this as the KV manager's pointer-offset callback.
//  variant    : key into `variants` (a variant's n_tokens); std::out_of_range if absent.
//  ptr_offset : shift applied to the input KV$ of AR-n variants. It is ignored when
//               variant == ctx_size.
// Returns true without doing anything unless _kv_update_method == POINTER_SHIFT and the
// KV manager has KV tensors. Returns false if mapFusedBufferOffset fails.
// kvmanager must be non-null; it is dereferenced unconditionally.
bool QnnNspGraph::registerPointerShift(int32_t variant, int32_t ptr_offset) {
    __TRACE("Called QnnNspGraph::registerPointerShift");
    if (_kv_update_method != POINTER_SHIFT) return true;
    if (kvmanager->getNumKVTensors() == 0) return true;
    qualla::Timer start;

    // tensor name -> (allocation index, offset, aligned size), passed to mapFusedBufferOffset
    std::map<std::string, std::tuple<int, size_t, size_t>> allocs;

    qualla::GraphVariant* graph_variant = variants.at(variant);
    if (variant == ctx_size) {
        // Re-map AR-c model outputs to initial state
        // Each "past_*" output goes back to its own recorded (alloc_idx, offset), unshifted.
        for (auto& [tname, tspec] : graph_variant->output_specs) {
            if (!tname.starts_with("past_")) continue; // Only process KV$
            auto& [alloc_idx, offset] = tensor_alloc_info->at(tname);
            allocs[tname]             = {alloc_idx, offset, tspec.dims.getAlignedSize()};
        }
    } else {

        // For AR-n models, map input KV$ to appropriate offset
        // Input "past_<...>_<suffix>" reuses the allocation of output "past_<...>_out"
        // (the last "_"-separated component is replaced by "out"), shifted by extra_offset.
        for (auto& [tname, tspec] : graph_variant->input_specs) {
            if (!tname.starts_with("past_")) continue; // Only process KV$
            auto out_name = tname.substr(0, tname.rfind("_")).append("_out");

            auto& [alloc_idx, offset]  = tensor_alloc_info->at(out_name);
            const bool    is_key       = tname.starts_with("past_key");
            // Keys shift by ptr_offset and values by ptr_offset * _n_embed.
            // NOTE(review): this implies keys keep the token axis innermost, values hold
            // _n_embed elements per token, and offsets count 1-byte elements. Confirm
            // against the KV layout and dtype.
            const int32_t extra_offset = ptr_offset * (is_key ? 1 : kvmanager->_n_embed);
            allocs[tname] = {alloc_idx, offset + extra_offset, tspec.dims.getAlignedSize()};
        }
    }

    if (!g_buffer_mgr->mapFusedBufferOffset(
                graph_variant->graph_info, graph_variant->context_handle, allocs
        )) {
        __ERROR("Error mapping tensor to allocation buffers");
        return false;
    }

    __DEBUG("qnn-htp: pointerShift complete : {} usec", start.elapsed_usec());
    return true;
}

// Attach the KV-cache manager for this graph and register registerPointerShift() as its
// pointer-offset callback. The destructor deletes `kvmanager`, so the graph takes
// ownership of `mgr`. If the manager reports no KV tensors, threaded mode is switched
// off (presumably there is nothing to synchronise with) and the lock and condition
// variable are freed.
void QnnNspGraph::registerKVManager(NewNSPKVManager* mgr) {
    kvmanager = mgr;
    if (mgr->getNumKVTensors() == 0 && _threaded) {
        delete _lock;
        delete _lock_cv;
        // Do not leave dangling pointers behind once the primitives are gone.
        _lock     = nullptr;
        _lock_cv  = nullptr;
        _threaded = false;
    }
    // The callback captures `this`. It is stored in `mgr`, which this graph deletes in its
    // destructor, so the callback cannot outlive the graph through that path.
    mgr->registerPointerOffsetFn([this](int32_t variant, int32_t ptr_offset) {
        return this->registerPointerShift(variant, ptr_offset);
    });
}

// Execute the graph variant compiled for `n_tokens` tokens.
//  n_tokens    : selects the variant; must be a key of `variants` (std::out_of_range otherwise).
//  n_inference : only used to name debug tensor dumps.
//  wait_count  : in threaded mode, execution first blocks until _counter >= wait_count
//                (see waitForLock()).
// In threaded mode _lock is held for the whole execution and is released on every
// return path. run_wait_time and run_exec_time are reset, then accumulate the
// Timer::elapsed_usec() spent on locking and on graph execution respectively.
// Returns false if QnnApi::graphExecute fails.
bool QnnNspGraph::execute(int n_tokens, int n_inference, int32_t wait_count) {
    GraphVariant* variant = variants.at(n_tokens); // Assume n_tokens exists in variants
    run_wait_time = run_exec_time = 0;             // Clear out the timer

    qualla::Timer timer;

    waitForLock("QnnNspGraph::execute", wait_count, false);
    run_wait_time += timer.elapsed_usec();

    GraphInfo_t* const graph = variant->graph_info;

    if (_debug_tensors) dumpTensors(variant, true, n_inference); // Dump input tensors

    timer.reset(); // Reset the timer to calculate execution time
    std::map<std::string, std::pair<double, uint16_t>> timeLogs;
    if (!g_qnn_api->graphExecute(
                graph->inputTensors, graph->outputTensors, graph->graphName, timeLogs
        )) {
        __ERROR("qnn-htp: graph-exec failed for {}", graph->graphName);
        // waitForLock() returned with _lock still held. Release it here as well,
        // otherwise every later attempt to take this graph's lock deadlocks.
        releaseLock("QnnNspGraph::execute");
        return false;
    }

    run_exec_time += timer.elapsed_usec();

    if (_debug_tensors) dumpTensors(variant, false, n_inference); // Dump output tensors

    timer.reset();
    releaseLock("QnnNspGraph::execute");
    run_wait_time += timer.elapsed_usec();
    return true;
}

// Acquire _lock for `requester` without any counter check, and return with it held.
// Pair with releaseLock(). Does nothing when the graph is not threaded.
void QnnNspGraph::waitForLock(std::string requester) {
    if (_threaded) {
        __KVTRACE("qnn-lock : graph[{}] requested : {}", _idx, requester);
        _lock->lock();
        __KVTRACE("qnn-lock : graph[{}] locking : {}", _idx, requester);
    }
}

// Acquire _lock for `requester` once _counter has reached `wait_counter`, and return with
// _lock still held. The caller hands it back through releaseLock().
//  poll == true  : spin. _counter is checked under the lock, and the lock is dropped
//                  around sched_yield() between checks.
//  poll == false : sleep on _lock_cv until the predicate holds. releaseLock() and
//                  wakeUpLock() notify it.
// Does nothing when the graph is not threaded.
void QnnNspGraph::waitForLock(std::string requester, int32_t wait_counter, bool poll) {
    if (!_threaded) return;
    __KVTRACE("qnn-lock : graph[{}] requested : {} (count={})", _idx, requester, wait_counter);

    if (!poll) {
        std::unique_lock guard(*_lock);
        _lock_cv->wait(guard, [&] {
            __KVTRACE("qnn-lock : graph[{}] trying ({} >= {})", _idx, _counter, wait_counter);
            return _counter >= wait_counter;
        });
        // Disassociate the mutex from `guard` WITHOUT unlocking it, so _lock stays held
        // after `guard` is destroyed.
        guard.release();
    } else {
        _lock->lock();
        // Busy wait. The lock is released while yielding, presumably so another thread
        // can advance _counter.
        while (_counter < wait_counter) {
            _lock->unlock();
            sched_yield();
            _lock->lock();
        }
    }

    __KVTRACE("qnn-lock : graph[{}] locking : {} (count={})", _idx, requester, wait_counter);
}

// Release _lock taken by one of the waitForLock() overloads, then wake one thread
// waiting on _lock_cv so it can re-check its counter. Does nothing when not threaded.
void QnnNspGraph::releaseLock(std::string requester) {
    if (_threaded) {
        __KVTRACE("qnn-lock : graph[{}] releasing : {} (count={})", _idx, requester, _counter);
        _lock->unlock();
        // Notifying after unlocking lets the woken waiter take _lock right away.
        // NOTE(review): notify_one wakes a single waiter. If several threads can wait on
        // _lock_cv with different counters, notify_all may be required. Confirm.
        _lock_cv->notify_one();
    }
}

// Wake one waiter on _lock_cv without touching _lock. Presumably called after _counter
// has changed so a waitForLock(poll == false) caller re-checks it. Does nothing when not
// threaded.
void QnnNspGraph::wakeUpLock() {
    if (_threaded) _lock_cv->notify_one();
}
} // namespace qualla
