//==============================================================================
//
//  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
//  All Rights Reserved.
//  Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

#include <vector>
#include <string>

#include <qualla/engine.hpp>
#include <qualla/detail/config.hpp>
#include <qualla/detail/timer.hpp>
#include <qualla/detail/onload.hpp>

#include <fmt/format.h>

#include "cpu-model.hpp"

// Logging helpers. Every macro posts through `_env.logger()`, so they are only usable where
// `_env` is in scope (the QnnCpuEngine member functions in this file).
//
// NOTE(review): identifiers containing a double underscore are reserved for the implementation
// in C++; consider renaming these macros in a file-wide change.
//
// __INFO / __WARN / __ERROR build the message with fmt::format at the call site and post the
// resulting string.
#define __INFO(__fmt, ...)  _env.logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
#define __WARN(__fmt, ...)  _env.logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
#define __ERROR(__fmt, ...) _env.logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
// __KPIS / __DEBUG / __TRACE post a lambda that does the formatting when it is invoked. The
// lambda captures by reference ([&]), so its arguments are read at the time the logger calls it
// — presumably so formatting can be skipped when the level is disabled (confirm against
// Logger::post).
#define __KPIS(__fmt, ...)                                                                         \
    _env.logger().post(Logger::ENGINE_KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
#define __DEBUG(__fmt, ...)                                                                        \
    _env.logger().post(Logger::ENGINE_DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
#define __TRACE(__fmt, ...)                                                                        \
    _env.logger().post(Logger::ENGINE_TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })

namespace qualla {

class QnnCpuEngine : public Engine {
  private:
    // Model parameters
    std::unique_ptr<QnnCpuModel> _model;

  public:
    QnnCpuEngine(Context& ctx, const qualla::json& json);
    ~QnnCpuEngine();

    virtual size_t process(
            const std::vector<int32_t>& tokens,
            std::vector<float>&         logits,
            bool                        logits_all
    ) override;

    virtual size_t process(
            const std::vector<int32_t>& tokens,
            const std::vector<int32_t>& attention_map,
            std::vector<float>&         logits,
            bool                        logits_all
    ) override;

    virtual bool   updateKV(size_t n_past) override;
    virtual bool   updateKV(size_t n_past, const std::vector<bool>& selected) override;
    virtual bool   save(const std::string& name) override;
    virtual size_t restore(const std::string& name, bool chooseHigherVariant) override;
    virtual void   reset() override;
    virtual bool applyLoraAdapter(std::string lora_adapter_name) override;
    virtual bool applyLoraStrength(std::string tensor_name, float tensor_val) override;
};

// Short alias for the standard filesystem namespace, used for the path handling below.
// NOTE(review): <filesystem> is not included directly by this file; it is presumably pulled in
// through one of the qualla headers — confirm, or include it explicitly.
namespace fs = std::filesystem;

/// Builds the "qnn-cpu" engine.
///
/// Steps, in order: advertise the engine features, read the engine configuration into a
/// QnnCpuModel::Params, create the QnnCpuModel, then run its four bring-up stages
/// (initializeModel, initializeIOTensors, validateModel, initializeTensorPointers). The total
/// elapsed time is recorded in `_kpis.load`.
///
/// @param ctx   Context forwarded to the Engine base; its size() and n_vocab() are copied into
///              the model parameters.
/// @param json  Engine configuration, read through qualla::Config.
///
/// @throws std::runtime_error if "model-output" is neither "logits" nor "embeddings", if a LoRA
///         bin-section path is not a regular file, or if any model bring-up stage fails.
///         Config lookups and JSON conversions may raise their own exceptions as well.
QnnCpuEngine::QnnCpuEngine(Context& ctx, const qualla::json& json) : Engine(ctx, "qnn-cpu", json) {
    qualla::Timer start;

    using FF  = Feature::Flags;
    _features = FF::OUTPUT_LOGITS | FF::SAVE_RESTORE | FF::OUTPUT_EMBEDDINGS;

    __DEBUG("qnn-cpu: init start");

    qualla::Config conf(json, _type + "-engine:");

    // Parse config
    QnnCpuModel::Params p;

    const std::string model_output = conf.optional<std::string>("model-output", "logits");
    if (model_output == "logits") {
        p.model_output = QnnCpuModel::ModelOutput::LOGITS;
    } else if (model_output == "embeddings") {
        p.model_output = QnnCpuModel::ModelOutput::EMBEDDINGS;
    } else {
        throw std::runtime_error(
                "Only logits and embeddings outputs are supported. Invalid output supplied : " +
                model_output
        );
    }

    // NOTE: the key spelling is mixed ("n-threads" vs "n_logits"); it is kept as-is because
    // existing configuration files depend on these exact names.
    p.model_basedir  = _env.path().models / conf.optional<std::string>("model-basedir", "");
    p.model_bin_path = conf.mandatory<std::string>("model-bin-path");
    p.model          = conf.mandatory<std::string>("model");
    p.op_package     = conf.mandatory<std::string>("op-package");
    p.backend_lib    = conf.mandatory<std::string>("backend-lib");
    p.n_threads      = conf.optional<uint32_t>("n-threads", 6);
    p.n_logits       = conf.optional<uint32_t>("n_logits", 1);
    p.n_layer        = conf.optional<uint32_t>("n_layer", 32);
    p.n_embd         = conf.optional<uint32_t>("n_embd", 4096);
    p.n_heads        = conf.optional<uint32_t>("n_heads", 32);
    p.use_mmap       = conf.optional<bool>("use-mmap", false);
    p.ctx_size       = _ctx.size();
    p.n_vocab_size   = _ctx.n_vocab();

    // LoRA is disabled unless a non-empty "lora" entry is present. Adapters are only read when
    // that entry is an array; any other non-empty value still switches the mode on.
    p.lora_config_type     = LoraConfigType::LORA_DISABLE;
    qualla::json lora_conf = conf.optional<qualla::json>("lora", {});
    if (lora_conf.size() != 0) {
        p.lora_config_type = LoraConfigType::LORA_ADAPTER_WEIGHT_ENABLE;
        if (lora_conf.is_array()) {
            // Iterate by (non-const) reference: this avoids deep-copying every adapter object
            // while keeping the same operator[] semantics as before.
            for (auto& lc : lora_conf) {
                std::string lnm = lc["adapter-name"];

                // Single lookup of the per-adapter entry instead of one per field.
                auto& adapter             = p.lora_config[lnm];
                adapter.lora_name         = lnm;
                adapter.alpha_tensor_name = lc["alpha-tensor-name"];
                adapter.alpha_tensor_val  = 0.0f;
                if (lc.contains("alpha-tensor-value")) {
                    adapter.alpha_tensor_val = lc["alpha-tensor-value"];
                }

                std::string basedir = "";
                if (lc.contains("binsection-basedir")) {
                    basedir = lc["binsection-basedir"];
                }

                // Deserialize the section list once, not once per loop iteration. The size
                // guard keeps the previous behaviour for a missing or empty "bin-sections"
                // entry, which was silently treated as "no sections".
                auto&                    sections_json = lc["bin-sections"];
                std::vector<std::string> bin_sections;
                if (sections_json.size() != 0) {
                    bin_sections = sections_json.get<std::vector<std::string>>();
                }

                for (const std::string& section : bin_sections) {
                    // Relative entries are resolved against "binsection-basedir".
                    fs::path binsection_path = fs::path(section);
                    if (binsection_path.is_relative()) {
                        binsection_path = basedir / fs::path(section);
                    }
                    if (!fs::is_regular_file(binsection_path)) {
                        __ERROR("qnn-cpu: Can't access Lora binsection adapter : {}",
                                binsection_path.string());
                        throw std::runtime_error(
                                "qnn-cpu: Can't open adapter file : " + binsection_path.string()
                        );
                    }
                    adapter.binsection_list.push_back(binsection_path.string());
                }
            }
        }
    }
    _model = std::make_unique<QnnCpuModel>(_env, p);

    // Load model
    if (true != _model->initializeModel()) {
        throw std::runtime_error("Failure to initialize model");
    }

    // Initialize IO Tensor buffers
    if (true != _model->initializeIOTensors()) {
        throw std::runtime_error("Error in setting up IO Tensors");
    }

    if (true != _model->validateModel()) {
        throw std::runtime_error("Error validating model. Please check your I/O");
    }

    __DEBUG("qnn-cpu: model has been validated!");

    if (true != _model->initializeTensorPointers()) {
        throw std::runtime_error("Error : Could not find I/O tensors in loaded graphs");
    }

    _kpis.load.update(start.elapsed_usec());
}

// Emits a debug trace on teardown. No explicit cleanup is written here; `_model` is a
// std::unique_ptr and is released by its own destructor.
QnnCpuEngine::~QnnCpuEngine() {
    __DEBUG("qnn-cpu: fini");
}

bool QnnCpuEngine::updateKV(size_t n_past) {
    qualla::Timer start;

    if (n_past > _ctx.size()) {
        __ERROR("qnn-cpu: context size exceeded : n_past {}", n_past);
        State::error("context size exceeded");
        return false;
    }

    __DEBUG("qnn-cpu: update-kv start : n_past {}", n_past);

    _model->setKVCacheNPast(n_past);

    __DEBUG("qnn-cpu: update-kv complete : {} usec", start.elapsed_usec());

    _kpis.update_kv.update(start.elapsed_usec());

    return true;
}

/// Overload taking a selection mask.
///
/// This engine does not consult `selected`: the call behaves exactly like the single-argument
/// overload and simply sets the cache position to `n_past`.
///
/// @param n_past    New cache position; must not exceed the context size.
/// @param selected  Not used.
/// @return true on success, false when `n_past` exceeds the context size.
bool QnnCpuEngine::updateKV(size_t n_past, const std::vector<bool>& selected) {
    static_cast<void>(selected);

    // Qualified call so exactly this class's single-argument implementation runs.
    return QnnCpuEngine::updateKV(n_past);
}

size_t QnnCpuEngine::process(
        const std::vector<int32_t>& tokens,
        std::vector<float>&         logits,
        bool                        logits_all = false
) {
    qualla::Timer start;

    __DEBUG("qnn-cpu: inference start: n_tokens {}", tokens.size());

    _model->runInference(tokens, logits_all);

    __DEBUG("qnn-cpu: inference complete : {} usec", start.elapsed_usec());

    size_t n_tok;

    {
        qualla::Timer t;

        __DEBUG("qnn-cpu: get-logits start: all {}", logits_all);

        n_tok = _model->getDequantLogits(logits, logits_all);

        __DEBUG("qnn-cpu: get-logits complete : {} usec", t.elapsed_usec());
    }

    _kpis.process.update(start.elapsed_usec());

    return n_tok;
}

/// Overload accepting an attention map.
///
/// `attention_map` is not used by this engine; the call is forwarded unchanged to the
/// three-argument overload.
///
/// @return Whatever the three-argument overload returns.
size_t QnnCpuEngine::process(
    const std::vector<int32_t>& tokens,
    const std::vector<int32_t>& attention_map,
    std::vector<float>&         logits,
    bool                        logits_all = false
) {
    return process(tokens, logits, logits_all);
}

size_t QnnCpuEngine::restore(const std::string& name, bool chooseHigherVariant) {
    fs::path cache_path = std::filesystem::path(name) / fmt::format("kv-cache.{}.qnn-cpu", _role);
    return _model->loadKVCache(cache_path.string());
}

bool QnnCpuEngine::save(const std::string& name) {
    fs::path cache_path = std::filesystem::path(name) / fmt::format("kv-cache.{}.qnn-cpu", _role);
    return _model->saveKVCache(cache_path.string());
}

/// Resets the engine state.
///
/// Dropping the KV cache is sufficient, so this just sets the cache position to zero. The
/// boolean result of updateKV is not inspected.
void QnnCpuEngine::reset() {
    constexpr size_t kEmptyCache = 0;
    updateKV(kEmptyCache);
}

// For Lora
/// Applies the named LoRA adapter by forwarding to the model.
///
/// @param lora_adapter_name  Adapter name passed through to the model.
/// @return The model's result, or false (with an error log) when no model is present.
bool QnnCpuEngine::applyLoraAdapter(std::string lora_adapter_name) {
    if (_model) {
        return _model->applyLoraAdapter(lora_adapter_name);
    }

    __ERROR("qnn-cpu: applyLoraAdapter failed, model not initialized");
    return false;
}

/// Sets a LoRA strength tensor value by forwarding to the model.
///
/// @param tensor_name  Tensor name passed through to the model.
/// @param tensor_val   Value passed through to the model.
/// @return The model's result, or false (with an error log) when no model is present.
bool QnnCpuEngine::applyLoraStrength(std::string tensor_name, float tensor_val) {
    if (_model) {
        return _model->applyLoraStrength(tensor_name, tensor_val);
    }

    __ERROR("qnn-cpu: applyLoraStrength failed, model not initialized");
    return false;
}

// Registrator instance.
//
// Constructing this file-local OnLoad object registers a factory for the "qnn-cpu" engine
// type with Engine::__register. The factory returns a raw, heap-allocated Engine*; ownership
// is presumably taken by the registry's caller — confirm against Engine::__register.
static OnLoad regy([]() {
    Engine::__register("qnn-cpu", [](Context& ctx, const json& conf) {
        // Named upcast instead of a C-style cast: it can only perform the intended
        // derived-to-base conversion.
        return static_cast<Engine*>(new QnnCpuEngine(ctx, conf));
    });
});

// Intentionally empty. NOTE(review): presumably referenced from elsewhere so the linker keeps
// this translation unit (and with it the registrar above) — confirm against the callers.
void needQnnCpuEngine() {}

} // namespace qualla
