//==============================================================================
//
//  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
//  All rights reserved.
//  Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

#include <qualla/dialog.hpp>
#include <qualla/sampler.hpp>
#include <qualla/logger.hpp>
#include <qualla/detail/config.hpp>
#include <qualla/detail/timer.hpp>
#include <qualla/detail/onload.hpp>
#include <qualla/detail/sampler-utils.hpp>
#include <qualla/detail/basic-sampler.hpp>
#include <qualla/detail/cache-file.hpp>

#include <cstddef>
#include <cstdint>
#include <filesystem>
#include <fstream>
#include <functional>
#include <random>
#include <string>
#include <system_error>
#include <unordered_map>
#include <vector>

#include <fmt/format.h>
#include <fmt/ranges.h>

namespace fs = std::filesystem;

// Logging helpers. Each expands to a call on `_env->logger()`, so they are only
// usable where `_env` is in scope (the KvShareDialog member functions below).
// __INFO / __WARN / __ERROR build the message eagerly with fmt::format and post
// the resulting string.
// NOTE(review): identifiers beginning with "__" are reserved for the implementation
// in C++, and `##__VA_ARGS__` is a GNU extension; consider renaming/portability if needed.
#define __INFO(__fmt, ...)  _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
#define __WARN(__fmt, ...)  _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
#define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
// __KPIS / __DEBUG / __TRACE pass a by-reference-capturing lambda instead of a string,
// presumably so the message is only formatted when that level is enabled — confirm in Logger.
#define __KPIS(__fmt, ...)                                                                         \
    _env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
#define __DEBUG(__fmt, ...)                                                                        \
    _env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
#define __TRACE(__fmt, ...)                                                                        \
    _env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })

namespace qualla {

    using qc = qualla::Config;

    class KvShareDialog : public Dialog {
    public:
        KvShareDialog(std::shared_ptr<Env> env, const std::string& name, const json& conf)
                : Dialog(env, name, conf) {}

        virtual bool process(std::vector<int32_t>& tokens, Dialog::Callback callback) override;

        virtual bool process(std::vector<int32_t>& tokens, DialogCallback callback) override {
            return false;
        }

        virtual void reset() override;

        bool convertKV(const fs::path& cache_dir);

    };

    // Returns the dialog to its initial state: zeroes the token counters, forgets
    // the last token, clears KPIs, resets every sampler, resets *and* unloads every
    // engine, and finally clears the dialog State.
    void KvShareDialog::reset() {
      __INFO("dialog-reset: {}", _ctx->name());

      // Token bookkeeping back to an empty conversation.
      _n_past = _n_prompt = _n_generated = _n_queries = 0;
      _last_tok = -1;

      _kpis.reset();

      for (auto& sampler_entry : _sampler) {
        sampler_entry.second->reset();
      }

      for (auto& engine_entry : _engine) {
        auto& engine = engine_entry.second;
        engine->reset();
        engine->unload();
      }

      State::clear();
    }

    bool KvShareDialog::process(std::vector<int32_t>& tokens, Dialog::Callback callback) {

      // Check for prev failures and bail out early
      if (State::failed()) return false;

      Timer start;

      // Vector for storing logits.
      // Allocated & filled by the engine.
      std::vector<float> logits;

      State::clear();

      auto& sampler = *_sampler["primary"];

      auto& p_engine = *_engine["primary"];   // prompt
      auto& s_engine = *_engine["secondary"]; // generation

      if (_n_past + tokens.size() > _ctx->size()) {
        __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
        callback("", Sentence::END);
        return true;
      }

      if (!p_engine.process(tokens, logits))
        return Dialog::abort("engine prompt processing failed", callback);

      _n_prompt += tokens.size();
      _n_past += tokens.size();

      if (!p_engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback);

      tokens[0] = _last_tok = sampler.process(logits);
      tokens.resize(1);

      _n_generated++;

      _kpis.prompt.update(start.elapsed_usec());
      // Log latest KPIs
      _env->logger().post(Logger::KPIS, kpis().dump(" "));

      if (_ctx->is_eos(_last_tok)) {
        callback("", Sentence::END);
        return true;
      }

      if (!callback(_tokenizer->decode(tokens), Sentence::BEGIN)) return true;

      __DEBUG("dialog: {} : switching engines", _ctx->name());
      {
        // Setup cache dir for saving the engine state
        std::string cache_name = _ctx->name() + "-kv-share";
        fs::path    cache_dir  = _env->path().cache / cache_name;

        if (!fs::exists(cache_dir) && !fs::create_directories(cache_dir)) {
          __ERROR("dialog: {} : failed to create cache directory {}",
                                    _ctx->name(),
                                    cache_dir.string());
          return Dialog::abort("engine switch failed", callback);
        }

        // Save and unload the primary engine
        p_engine.save(cache_name);
        p_engine.unload();

        // The purpose is to save the hyperparams
        s_engine.save(cache_name);

        convertKV(cache_dir);

        size_t n = s_engine.restore(cache_name);

        if(!fs::remove_all(cache_dir)) {
          __WARN("dialog: {} : cache files not closed/dir not found", _ctx->name());
        }

        if (n != _n_past) {
          __WARN("dialog: {} : kv size mismatch {} expected {}", _ctx->name(), n, _n_past);
          _n_past = n;
        }

        s_engine.updateKV(_n_past);
      }

      start.reset();

      State::busy(true);

      while (true) {
        if (State::canceled()) {
          callback("", Sentence::END);
          break;
        }

        if (_n_past + tokens.size() > _ctx->size()) {
          __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
          callback("", Sentence::END);
          break;
        }
        if (!s_engine.process(tokens, logits))
          return Dialog::abort("secondary engine processing failed", callback);

        tokens[0] = _last_tok = sampler.process(logits);

        _n_past++;
        _n_generated++;

        if (!s_engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback);

        if (_ctx->is_eos(_last_tok)) {
          callback("", Sentence::END);
          break;
        }

        if (!callback(_tokenizer->decode(tokens), Sentence::CONTINUE)) break;
      }

      State::busy(false);

      _kpis.generate.update(start.elapsed_usec());

      // Log latest KPIs in a single line
      _env->logger().post(Logger::KPIS, kpis().dump(" "));

      return true;
    }

    // Converts the primary engine's saved KV cache into the secondary engine's cache file.
    //
    // Reads `<cache_dir>/kv-cache.primary.qnn-htp`, expected to contain, in order:
    //   - a CacheFileSpec header (magic 0xC0DE);
    //   - uint8 keys laid out [layer][head][kv_dim][n_tok];
    //   - uint8 values laid out [layer][head][n_tok][kv_dim];
    //   - n_layer double key scales, then n_layer double value scales.
    // Here n_layer = num_tensors / 2, n_head = n_heads, kv_dim = embed_dim and
    // n_tok = update_size, all taken from the header.
    //
    // `<cache_dir>/kv-cache.secondary.qnn-cpu` must already exist with a valid
    // header. That header's update_size is overwritten with the primary's. The
    // following are then written as float directly after the header:
    //   - the dequantized keys, transposed to [layer][head][n_tok][kv_dim] and
    //     de-interleaved along kv_dim;
    //   - the dequantized values.
    // Dequantization is (q - 128) * scale, with one scale per layer.
    //
    // Returns true on success. On failure it logs the cause, sets State::error and
    // returns false; the secondary file may then be partially written.
    bool KvShareDialog::convertKV(const fs::path& cache_dir) {
      Timer start;

      const fs::path nsp_cache_path = cache_dir / "kv-cache.primary.qnn-htp";
      const fs::path cpu_cache_path = cache_dir / "kv-cache.secondary.qnn-cpu";

      __DEBUG("kv-convert: begin converting {} to {}", nsp_cache_path.string(), cpu_cache_path.string());

      std::ifstream nsp_fs(nsp_cache_path, std::ios::in | std::ios::binary);

      if (nsp_fs.fail()) {
        __ERROR("kv-convert: error reading file {}", nsp_cache_path.string());
        State::error("failed to read primary kv-cache");
        return false;
      }

      // Read spec from nsp file
      CacheFileSpec nsp_spec{};
      if (!nsp_fs.read(reinterpret_cast<char*>(&nsp_spec), sizeof(nsp_spec))) {
        __ERROR("kv-convert: truncated header in {}", nsp_cache_path.string());
        State::error("invalid format of primary kv-cache");
        return false;
      }
      if (nsp_spec.magic != 0xC0DE) {
        __ERROR("kv-convert: expected 0xC0DE found {:#x}", nsp_spec.magic);
        State::error("invalid format of primary kv-cache");
        return false;
      }

      // clang-format off
      __DEBUG("kv-convert: load {{ num_tensors {}, magic {}, dtype {}, n_heads {}, embed_dim {} update_size {} }}",
              nsp_spec.num_tensors, nsp_spec.magic, int(nsp_spec.dtype), nsp_spec.n_heads, nsp_spec.embed_dim, nsp_spec.update_size);
      // clang-format on

      const uint32_t n_layer = nsp_spec.num_tensors / 2;
      const uint32_t n_head  = nsp_spec.n_heads;
      const uint32_t kv_dim  = nsp_spec.embed_dim;
      const uint32_t n_tok   = nsp_spec.update_size;

      // Sizes and offsets are computed in size_t: these products can exceed
      // 32 bits for large models and long prompts.
      const size_t head_size  = static_cast<size_t>(kv_dim) * n_tok;
      const size_t layer_size = static_cast<size_t>(n_head) * head_size;
      const size_t cache_size = static_cast<size_t>(n_layer) * layer_size;

      // Read Key/Value Cache
      std::vector<uint8_t> key_cache(cache_size);
      std::vector<uint8_t> value_cache(cache_size);
      nsp_fs.read(reinterpret_cast<char*>(key_cache.data()), static_cast<std::streamsize>(cache_size));
      nsp_fs.read(reinterpret_cast<char*>(value_cache.data()), static_cast<std::streamsize>(cache_size));

      // Read Quantization parameters (one scale per layer)
      std::vector<double> key_scales(n_layer);
      std::vector<double> value_scales(n_layer);
      const auto scales_bytes = static_cast<std::streamsize>(n_layer * sizeof(double));
      nsp_fs.read(reinterpret_cast<char*>(key_scales.data()), scales_bytes);
      nsp_fs.read(reinterpret_cast<char*>(value_scales.data()), scales_bytes);

      if (!nsp_fs) {
        __ERROR("kv-convert: {} is shorter than its header describes", nsp_cache_path.string());
        State::error("invalid format of primary kv-cache");
        return false;
      }
      nsp_fs.close();

      std::fstream cpu_fs(cpu_cache_path, std::ios::in | std::ios::out | std::ios::binary);

      if (cpu_fs.fail()) {
        __ERROR("kv-convert: failed to write {}", cpu_cache_path.string());
        State::error("failed to save secondary kv-cache");
        return false;
      }

      CacheFileSpec cpu_spec{};
      if (!cpu_fs.read(reinterpret_cast<char*>(&cpu_spec), sizeof(cpu_spec))) {
        __ERROR("kv-convert: truncated header in {}", cpu_cache_path.string());
        State::error("invalid format of secondary kv-cache");
        return false;
      }
      if (cpu_spec.magic != 0xC0DE) {
        __ERROR("kv-convert: expected 0xC0DE found {:#x}", cpu_spec.magic);
        State::error("invalid format of secondary kv-cache");
        return false;
      }

      // Keys: transpose kv_dim x n_tok (QNN-HTP K$) -> n_tok x kv_dim (QNN-CPU K$)
      // (ScopGPT KV$ format). On HTP the kv_dim channels are also interleaved:
      // all even channels first, then all odd ones, e.g. for kv_dim = 128
      //   QNN HTP: [0 2 4 ... 126 1 3 5 ... 127]
      //   QNN CPU: [0 1 2 ... 63  64 65 ... 127]
      // `n_even` is the number of even channels; using it (rather than kv_dim / 2)
      // keeps the mapping a valid permutation even for an odd kv_dim.
      const size_t n_even       = kv_dim / 2 + kv_dim % 2;
      auto         deinterleave = [n_even](size_t k) -> size_t {
        return (k < n_even) ? 2 * k : 2 * (k - n_even) + 1;
      };

      __DEBUG("kv-convert: dequantizing keys");
      std::vector<float> dequant_keys(cache_size);
      for (size_t i = 0; i < n_layer; i++) {
        for (size_t j = 0; j < n_head; j++) {
          const size_t head_base = i * layer_size + j * head_size;
          for (size_t k = 0; k < kv_dim; k++) {
            const size_t out_k = deinterleave(k); // invariant over the token loop
            for (size_t l = 0; l < n_tok; l++) {
              const size_t read_loc  = head_base + k * n_tok + l;
              const size_t write_loc = head_base + l * kv_dim + out_k;

              dequant_keys[write_loc] =
                      (static_cast<float>(key_cache[read_loc]) - 128) * key_scales[i];
            }
          }
        }
      }

      // Values use the same [layer][head][n_tok][kv_dim] layout on both sides, so
      // only the per-layer dequantization is applied, element by element.
      __DEBUG("kv-convert: dequantizing values");
      std::vector<float> dequant_values(cache_size);
      for (size_t i = 0; i < n_layer; i++) {
        const size_t layer_base = i * layer_size;
        for (size_t e = 0; e < layer_size; e++) {
          const size_t loc = layer_base + e;
          dequant_values[loc] =
                  (static_cast<float>(value_cache[loc]) - 128) * value_scales[i];
        }
      }

      __DEBUG("kv-convert: storing converted KV to file");
      // Set the n_tokens processed during prompt processing, rewrite the header at
      // the start of the file and append the converted caches right after it.
      cpu_spec.update_size = nsp_spec.update_size;
      cpu_fs.seekp(0, std::ios::beg);
      cpu_fs.write(reinterpret_cast<const char*>(&cpu_spec), sizeof(cpu_spec));
      cpu_fs.write(reinterpret_cast<const char*>(dequant_keys.data()),
                   static_cast<std::streamsize>(dequant_keys.size() * sizeof(float)));
      cpu_fs.write(reinterpret_cast<const char*>(dequant_values.data()),
                   static_cast<std::streamsize>(dequant_values.size() * sizeof(float)));
      cpu_fs.flush();

      if (!cpu_fs) {
        __ERROR("kv-convert: failed to write {}", cpu_cache_path.string());
        State::error("failed to save secondary kv-cache");
        return false;
      }
      cpu_fs.close();

      __DEBUG("kv-convert: done converting {} to {} in {} usec",
              nsp_cache_path.string(),
              cpu_cache_path.string(),
              start.elapsed_usec());

      return true;
    }

// Registrator instance
    // Registers the "kv-share" dialog type with the Dialog factory. OnLoad presumably
    // runs the lambda during static initialization of this translation unit —
    // confirm in qualla/detail/onload.hpp.
    static OnLoad regy([]() {
        Dialog::__register(
                "kv-share",
                // Explicit return type keeps the factory signature Dialog* without a cast.
                [](std::shared_ptr<Env> env, const std::string& name, const json& conf) -> Dialog* {
                    return new KvShareDialog(env, name, conf);
                }
        );
    });

    // Intentionally empty. NOTE(review): presumably called from elsewhere so the
    // linker keeps this object file (and thus the registrar above) — confirm at call site.
    void needKvShareDialog() {}

} // namespace qualla
