//==============================================================================
//
//  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
//  All Rights Reserved.
//  Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

#include "qualla/env.hpp"
#include "qualla/detail/timer.hpp"
#include "qualla/detail/cache-file.hpp"

#include "fmt/format.h"
#include "fmt/ranges.h"

#include "qnn-utils.hpp"
#include "cpu-model.hpp"

#include <cassert>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <memory>
#include <set>
#include <sstream>

namespace fs = std::filesystem;

// fmt-style logging helpers. They post to `_env.logger()`, so they can only be expanded where an
// `_env` member is in scope (here: QnnCpuModel member functions).
// __INFO / __WARN / __ERROR format the message immediately.
// __KPIS / __DEBUG / __TRACE pass a lambda instead; presumably Logger::post only invokes it (and
// pays the formatting cost) when that level is enabled -- TODO confirm against Logger.
#define __INFO(__fmt, ...)  _env.logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
#define __WARN(__fmt, ...)  _env.logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
#define __ERROR(__fmt, ...) _env.logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
#define __KPIS(__fmt, ...)                                                                         \
    _env.logger().post(Logger::ENGINE_KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
#define __DEBUG(__fmt, ...)                                                                        \
    _env.logger().post(Logger::ENGINE_DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
#define __TRACE(__fmt, ...)                                                                        \
    _env.logger().post(Logger::ENGINE_TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })

namespace qualla {

// Captures the model configuration from `params` and precomputes the tensor shapes that
// initializeModel() hands to QnnApi::initialize():
//   m_input_dim  : {1, ctx_size}
//   m_kv_dim     : {n_layer, n_heads, ctx_size + 1, head_dim}   (K$ / V$ caches)
//   m_output_dim : {n_logits, vocab_size} for LOGITS, {ctx_size, n_embd} for EMBEDDINGS
// For EMBEDDINGS models m_numLogits is overridden with ctx_size.
// Throws std::runtime_error if n_heads is zero (head_dim = n_embd / n_heads would divide by 0).
QnnCpuModel::QnnCpuModel(Env& env, const Params& params)
    : _env(env), model_basedir(params.model_basedir), op_package(params.op_package),
      backend_lib(params.backend_lib), model_bin_path(params.model_bin_path), model(params.model),
      m_ctx_size(params.ctx_size), m_num_threads(params.n_threads), m_num_tokens(params.ctx_size),
      m_numLogits(params.n_logits), m_vocab_size(params.n_vocab_size), m_num_layer(params.n_layer),
      m_embd(params.n_embd), m_num_heads(params.n_heads), m_use_mmap(params.use_mmap),
      model_output(params.model_output) {
    if (m_num_heads == 0) {
        throw std::runtime_error("qnn-cpu: n_heads must be non-zero");
    }

    // Initialize QnnAPI
    m_qnnApi   = std::make_unique<QnnApi>();
    m_head_dim = m_embd / m_num_heads;

    m_input_dim.push_back(1);
    m_input_dim.push_back(m_ctx_size);

    // K$, V$ 4D Tensor {n_layer, n_heads, n_ctx, n_head_dim}; the context axis has one extra slot.
    m_kv_dim.push_back(m_num_layer);
    m_kv_dim.push_back(m_num_heads);
    m_kv_dim.push_back(m_ctx_size + 1);
    m_kv_dim.push_back(m_head_dim);

    if (model_output == ModelOutput::LOGITS) {
        m_output_dim.push_back(m_numLogits);
        m_output_dim.push_back(m_vocab_size);
    } else if (model_output == ModelOutput::EMBEDDINGS) {
        // Embedding models return one row per context position.
        m_numLogits = m_ctx_size;
        m_output_dim.push_back(m_numLogits);
        m_output_dim.push_back(m_embd);
    }

    m_loraConfigType = params.lora_config_type;
    m_lora_alpha_val = 1.0f;

    if (m_loraConfigType == LoraConfigType::LORA_ADAPTER_WEIGHT_ENABLE) {
        m_loraConfig.insert(params.lora_config.begin(), params.lora_config.end());
    }
}

// Releases the logits scratch buffer and tears down each graph's input and output tensor banks.
// Lookups use find() instead of operator[] so the destructor never inserts into (and therefore
// never allocates in) the bookkeeping maps; graphs whose banks were never created are skipped.
QnnCpuModel::~QnnCpuModel() {
    free(dequant_logits_ptr); // free(nullptr) is a no-op

    if (!m_ioTensor) return;

    QNN_DEBUG("Tearing Down Input Tensors Bank");

    // Tears down the bank registered for `graph_name` in `banks`, sized by its entry in `specs`.
    auto tear_down = [this](auto& banks, auto& specs, const std::string& graph_name) {
        const auto bank_it = banks.find(graph_name);
        if (bank_it == banks.end() || bank_it->second == nullptr) return;
        const auto   spec_it     = specs.find(graph_name);
        const size_t num_tensors = (spec_it == specs.end()) ? 0 : spec_it->second.size();
        m_ioTensor->tearDownTensors(bank_it->second, num_tensors);
    };

    for (const auto& graph_name : model_order) {
        tear_down(m_input_tensors, m_input_specs, graph_name);
        tear_down(m_output_tensors, m_output_specs, graph_name);
    }
}

// Given a filename, initializeModel load and initializes QNN runtime libraries and the model
bool QnnCpuModel::initializeModel(void) {
    // prepare params
    Qnn_Param_t params[5];
    params[0].paramType               = QNN_PARAMTYPE_SCALAR;
    params[0].name                    = (char*)("model_bin_path");
    params[0].scalarParam.dataType    = QNN_DATATYPE_STRING;
    params[0].scalarParam.stringValue = model_bin_path.c_str();

    params[1].paramType               = QNN_PARAMTYPE_SCALAR;
    params[1].name                    = (char*)("num_thread");
    params[1].scalarParam.dataType    = QNN_DATATYPE_UINT_32;
    params[1].scalarParam.uint32Value = m_num_threads;

    params[2].paramType               = QNN_PARAMTYPE_SCALAR;
    params[2].name                    = (char*)("num_context");
    params[2].scalarParam.dataType    = QNN_DATATYPE_UINT_32;
    params[2].scalarParam.uint32Value = m_ctx_size;

    params[3].paramType               = QNN_PARAMTYPE_SCALAR;
    params[3].name                    = (char*)("num_last_logits");
    params[3].scalarParam.dataType    = QNN_DATATYPE_UINT_32;
    params[3].scalarParam.uint32Value = m_numLogits;

    params[4].paramType               = QNN_PARAMTYPE_SCALAR;
    params[4].name                    = (char*)("use_mmap");
    params[4].scalarParam.dataType    = QNN_DATATYPE_BOOL_8;
    params[4].scalarParam.uint32Value = m_use_mmap;

    if (true != m_qnnApi->initialize(
                        backend_lib,
                        model,
                        op_package,
                        ContextConfigs(),
                        {},
                        m_input_dim.data(),
                        m_input_dim.size(),
                        m_output_dim.data(),
                        m_output_dim.size(),
                        m_kv_dim.data(),
                        m_kv_dim.size(),
                        params,
                        5,
                        false
                )) {
        QNN_ERROR("Backend library : %s", backend_lib.c_str());
        throw std::runtime_error("QNN initialization failed!");
    }

    // Initialize QNN IO Tensor
    m_ioTensor   = std::unique_ptr<IOTensor>(new IOTensor());
    m_num_graphs = m_qnnApi->getGraphsCount();
    QNN_DEBUG("QNN initialized with %u graph(s)", m_num_graphs);

    auto graphs_info = m_qnnApi->getGraphsInfo();
    for (size_t graph_idx = 0; graph_idx < m_num_graphs; graph_idx++) {
        GraphInfo_t* const& graph_info = graphs_info[graph_idx];
        char*               graph_name = graph_info->graphName;
        std::string         graph_str  = std::string(graph_name);

        QNN_DEBUG("Loaded graph[%lu] = %s", graph_idx, graph_name);
        model_order.push_back(graph_str);
        model_context[graph_str] =
                m_qnnApi->getContexts()[graph_idx / m_qnnApi->getGraphCountPerContext()];
    }

    // CPU support KV cache mode
    m_mode = ExecutionMode::KV_ONLY;

    return true;
}

// Once the model has been loaded (initializeModel()), creates the input and output tensor banks
// for every graph and records each tensor's spec (QNN tensor, dims, quantization params) in
// m_input_specs / m_output_specs, keyed by graph name then tensor name.
// Returns false if the shared IOTensor or any graph's tensor bank cannot be set up.
bool QnnCpuModel::initializeIOTensors() {
    QNN_DEBUG("Create input tensors bank");

    // Ideally, we should create and initialize m_ioTensor for each context, but we want to
    // be able to see/use all the buffers in every context so that they can be connected
    // with each other. Hence, we use only the first context to initialize m_ioTensor
    // and share it across all graphs/contexts.
    if (true != m_ioTensor->initialize(m_qnnApi->getContexts()[0])) {
        QNN_ERROR("Failure to initialize IOTensor");
        return false;
    }

    // Getting graph info and its count needed for subsequent steps
    GraphInfo_t** const& graphsInfo       = m_qnnApi->getGraphsInfo();
    const auto           graphsPerContext = m_qnnApi->getGraphCountPerContext();

    for (size_t graphIdx = 0; graphIdx < m_num_graphs; graphIdx++) {
        GraphInfo_t* const& graphInfo = graphsInfo[graphIdx];
        std::string         graphName = std::string(graphInfo->graphName);

        // A context may own several graphs: resolve the owning context with the same
        // graph -> context mapping used by initializeModel(), rather than indexing the context
        // list by graph index (which reads past its end when graphs share a context).
        auto graphContext = m_qnnApi->getContexts()[graphIdx / graphsPerContext];

        // Setup Inputs
        {
            std::unordered_map<std::string, size_t> inputTensorsSize;
            for (size_t tensorIdx = 0; tensorIdx < graphInfo->numInputTensors; tensorIdx++) {
                std::string         tensor_name;
                std::vector<size_t> tensorDims;

                auto& tensor = graphInfo->inputTensors[tensorIdx];
                m_qnnApi->getTensorNameAndShape(tensor_name, tensorDims, tensor);
                std::vector<QnnUtils::QuantParam> quantParams;
                if (!m_qnnApi->getTensorQuantParams(&tensor, quantParams)) {
                    QNN_DEBUG("Couldn't get tensor quant params : %s", tensor_name.c_str());
                    quantParams.emplace_back(0, 0);
                }

                auto dims                     = QnnUtils::Dims(tensorDims);
                inputTensorsSize[tensor_name] = dims.getAlignedSize();

                m_input_specs[graphName][tensor_name] = {&tensor, dims, quantParams};
            }

            Qnn_Tensor_t*                          tensor_bank = nullptr;
            std::unordered_map<std::string, void*> tensor_ptr_map;
            if (true != m_ioTensor->setupInputTensors(
                                &tensor_bank,
                                tensor_ptr_map,
                                *graphInfo,
                                inputTensorsSize,
                                graphContext,
                                false
                        )) {
                QNN_ERROR("Error in setting up Input Tensors for graph %s", graphName.c_str());
                return false;
            }

            // Point each spec at the tensor allocated in the bank rather than the graph's own.
            m_input_tensors[graphName] = tensor_bank;
            for (auto& [tensor_name, tensor_ptr] : tensor_ptr_map) {
                m_input_specs[graphName][tensor_name].tensor = (Qnn_Tensor_t*)tensor_ptr;
            }
        }

        // Setup Outputs
        {
            std::unordered_map<std::string, size_t> outputTensorsSize;
            for (size_t tensorIdx = 0; tensorIdx < graphInfo->numOutputTensors; tensorIdx++) {
                std::string         tensor_name;
                std::vector<size_t> tensorDims;

                auto& tensor = graphInfo->outputTensors[tensorIdx];
                m_qnnApi->getTensorNameAndShape(tensor_name, tensorDims, tensor);
                std::vector<QnnUtils::QuantParam> quantParams;
                if (!m_qnnApi->getTensorQuantParams(&tensor, quantParams)) {
                    QNN_DEBUG("Couldn't get tensor quant params : %s", tensor_name.c_str());
                    quantParams.emplace_back(0, 0);
                }

                auto dims                      = QnnUtils::Dims(tensorDims);
                outputTensorsSize[tensor_name] = dims.getAlignedSize();

                m_output_specs[graphName][tensor_name] = {&tensor, dims, quantParams};
            }

            Qnn_Tensor_t*                          tensor_bank = nullptr;
            std::unordered_map<std::string, void*> tensor_ptr_map;
            if (true != m_ioTensor->setupOutputTensors(
                                &tensor_bank,
                                tensor_ptr_map,
                                *graphInfo,
                                outputTensorsSize,
                                graphContext,
                                false
                        )) {
                QNN_ERROR("Error in setting up Output Tensors for graph %s", graphName.c_str());
                return false;
            }

            m_output_tensors[graphName] = tensor_bank;
            for (auto& [tensor_name, tensor_ptr] : tensor_ptr_map) {
                m_output_specs[graphName][tensor_name].tensor = (Qnn_Tensor_t*)tensor_ptr;
            }
        }
    }

#ifdef DUMP_TENSOR_SPECS
    dumpTensorSpecs();
#endif

    return true;
}

// Debug helper: writes one JSON file per graph, "<DEBUG_DUMP_TARGET_PATH>/spec.<graph>.json",
// listing every input and output tensor with its dims, bitwidth and quantization scale/offset.
// Only functional when DEBUG_DUMP_TARGET_PATH is defined at build time; otherwise logs an error.
// Throws std::runtime_error if the target directory or a spec file cannot be created.
void QnnCpuModel::dumpTensorSpecs() {
#ifdef DEBUG_DUMP_TARGET_PATH
    if (true != QnnUtils::CreateDirsIfNotExist(DEBUG_DUMP_TARGET_PATH)) {
        throw std::runtime_error(
                std::string("Could not create directory : ") + DEBUG_DUMP_TARGET_PATH
        );
    }

    // Entries are joined with ",\n" as they are written (instead of writing a trailing comma and
    // seeking back over it), which also keeps the JSON valid when a tensor list is empty.
    static const char* stringFmt =
            "\t\t{ \"name\": \"%s\", \"dims\": [1, %d, %d, %d], \"bitwidth\": %d, \"scale\": [%s], \"offset\": [%s] }";

    GraphInfo_t** const& graphsInfo = m_qnnApi->getGraphsInfo();
    for (size_t graphIdx = 0; graphIdx < m_num_graphs; graphIdx++) {
        GraphInfo_t* const& graphInfo = graphsInfo[graphIdx];
        std::string         graphName = std::string(graphInfo->graphName);

        // Create output spec file and open it; snprintf keeps the path within the buffer.
        char filename[255];
        snprintf(filename, sizeof(filename), "%s/spec.%s.json", DEBUG_DUMP_TARGET_PATH, graphInfo->graphName);

        // RAII: the file is closed even if a helper below throws.
        std::unique_ptr<FILE, decltype(&fclose)> specFile(fopen(filename, "w"), &fclose);
        if (specFile == nullptr) {
            throw std::runtime_error(std::string("Error opening file : ") + filename);
        }
        FILE* out = specFile.get();

        // Writes one tensor entry, preceded by a separator unless it is the first in its list.
        auto writeEntry = [&](size_t idx, const std::string& tensor_name, QnnUtils::Tensor& spec) {
            std::string scales;
            std::string offsets;
            getQuantParamString(spec.quantParam, scales, offsets);
            if (idx != 0) fputs(",\n", out);
            fprintf(out,
                    stringFmt,
                    tensor_name.c_str(),
                    static_cast<int>(spec.dims.height),
                    static_cast<int>(spec.dims.width),
                    static_cast<int>(spec.dims.channel),
                    static_cast<int>(spec.dims.bitWidth),
                    scales.c_str(),
                    offsets.c_str());
        };

        fprintf(out, "{\n\t\"graph_name\" : \"%s\",\n\t\"inputs\" : [\n", graphName.c_str());

        std::string         tensor_name;
        std::vector<size_t> tensorDims;

        for (size_t tensorIdx = 0; tensorIdx < graphInfo->numInputTensors; tensorIdx++) {
            m_qnnApi->getTensorNameAndShape(tensor_name, tensorDims, graphInfo->inputTensors[tensorIdx]);
            // Specs are keyed by the name without any "_converted" suffix.
            const std::string fixed_tensor_name = tensor_name.substr(0, tensor_name.find("_converted"));
            writeEntry(tensorIdx, tensor_name, m_input_specs[graphName][fixed_tensor_name]);
        }

        // Dump out output tensor specs
        fprintf(out, "\n\t],\n\t\"outputs\" : [\n");

        for (size_t tensorIdx = 0; tensorIdx < graphInfo->numOutputTensors; tensorIdx++) {
            m_qnnApi->getTensorNameAndShape(tensor_name, tensorDims, graphInfo->outputTensors[tensorIdx]);
            const std::string fixed_tensor_name = tensor_name.substr(0, tensor_name.find("_converted"));
            writeEntry(tensorIdx, tensor_name, m_output_specs[graphName][fixed_tensor_name]);
        }

        fprintf(out, "\n\t]\n}");
    }
#else
    QNN_ERROR(
            "Requested dump tensor specs, but DEBUG_DUMP_TARGET_PATH not set. Please check nsp-model.h"
    );
#endif
}

template <bool PrintError = true, typename ValType>
inline bool findTensor(std::unordered_map<std::string, ValType>& map, std::string key) {
    if (map.find(key) == map.end()) {
        if constexpr (PrintError == true) QNN_ERROR("Cannot find %s\n", key.c_str());
        return false;
    }
    return true;
}

template <bool PrintError = false, typename ValType>
inline ValType* getTensor(std::unordered_map<std::string, ValType>& map, std::string key) {
    if (map.find(key) == map.end()) {
        if constexpr (PrintError == true) QNN_ERROR("Cannot find %s\n", key.c_str());
        return nullptr;
    }
    return &map[key];
}

// Hook for model checks that should fail fast, before any inference is attempted.
// The CPU backend performs no checks yet, so validation always succeeds.
bool QnnCpuModel::validateModel() {
    constexpr bool kModelIsValid = true;
    return kModelIsValid;
}

bool QnnCpuModel::initializeTensorPointers() {
    auto& input_specs         = m_input_specs[model_order.back()];
    t_input_ids               = &input_specs["x0"];
    t_input_ids_num_token     = &input_specs["x1"];
    t_input_ids_reset_kvcache = &input_specs["x2"];
    t_input_ids_k_cache       = &input_specs["x3"];
    t_input_ids_v_cache       = &input_specs["x4"];
    t_input_ids_n_past        = &input_specs["x5"];
    t_input_lora_alpha        = &input_specs["x6"];

    auto& output_specs = m_output_specs[model_order.back()];
    t_logits           = &output_specs["output_genAI"];
    t_output_n_past    = &output_specs["output_npast"];
    return true;
}

// Fills the graph inputs for one inference call:
//   t_input_ids               : `tokens`, zero padded to the product of m_input_dim
//   t_input_ids_num_token     : tokens.size()
//   t_input_ids_reset_kvcache : 0
//   t_input_ids_n_past        : current m_nPast
//   t_input_lora_alpha        : m_lora_alpha_val
// Throws std::runtime_error if more tokens are supplied than the model accepts (m_num_tokens).
// `run_bert_mode` is not used by the CPU backend.
void QnnCpuModel::setupInputTensors(const std::vector<int32_t>& tokens, bool run_bert_mode) {
    static_cast<void>(run_bert_mode);

    const size_t num_tokens = m_num_tokens;
    if (tokens.size() > num_tokens) {
        std::string err_msg = "Called inference with more tokens than model supports: ";
        err_msg += std::to_string(tokens.size()) + " vs. " + std::to_string(num_tokens);
        throw std::runtime_error(err_msg);
    }

    // Grab pointers to buffers for access
    auto* input_id_buffer               = reinterpret_cast<uint32_t*>(getBuffer(t_input_ids));
    auto* input_id_num_token_buffer     = reinterpret_cast<uint32_t*>(getBuffer(t_input_ids_num_token));
    auto* input_id_reset_kvcache_buffer = reinterpret_cast<uint32_t*>(getBuffer(t_input_ids_reset_kvcache));
    auto* input_id_n_past_buffer        = reinterpret_cast<uint32_t*>(getBuffer(t_input_ids_n_past));
    auto* input_id_lora_alpha           = reinterpret_cast<float*>(getBuffer(t_input_lora_alpha));

    size_t input_elems = 1;
    for (auto dim : m_input_dim) {
        input_elems *= dim;
    }

    // Zero the whole id buffer so positions past tokens.size() are padding.
    std::memset(input_id_buffer, 0, input_elems * sizeof(uint32_t));
    // tokens.data() may be null for an empty vector; memcpy from null is UB even for 0 bytes.
    if (!tokens.empty()) {
        std::memcpy(input_id_buffer, tokens.data(), tokens.size() * sizeof(uint32_t));
    }

    *input_id_num_token_buffer     = static_cast<uint32_t>(tokens.size());
    *input_id_reset_kvcache_buffer = 0;
    *input_id_n_past_buffer        = static_cast<uint32_t>(m_nPast);
    *input_id_lora_alpha           = m_lora_alpha_val;
}

// Use qnnAPI to execute the model
template <class T1, class T2>
inline bool QnnCpuModel::executeModel(T1& input, T2& output, std::string graph_name) {
    // given that a dnn instance is created and we have input loaded with image data we can get our output
    // for our required app functionality Execute the network with the given single input.
    QNN_DEBUG("Now executing inference for graph %s", graph_name.c_str());

#ifdef INPUT_DUMP
    if (m_inference_count < 5) dumpTensors(graph_name, true); // Dump input tensors
#endif

    bool ret = m_qnnApi->graphExecute(input, output, graph_name, timeLogs);

    if (ret != true) {
        QNN_ERROR("ERROR executing inference: %d for graph %s", ret, graph_name.c_str());
        return false;
    }
#ifdef OUTPUT_DUMP
    if (m_inference_count < 5) dumpTensors(graph_name, false); // Dump output tensors
#endif
    QNN_DEBUG("Execute finished for graph %s", graph_name.c_str());

    return true;
}

// Executes each graph in `exec_models` in order, stopping at the first failure (in which case the
// output totals are left untouched and false is returned). On success writes the summed
// execution time in microseconds to *exec_time_total and 0 to *wait_time_total (nothing on this
// path waits). When pipeline_kv_update is set, m_nPast is advanced by update_size.
bool QnnCpuModel::runInferenceHelper(
        std::vector<std::string>& exec_models,
        int32_t*                  wait_time_total,
        int32_t*                  exec_time_total,
        bool                      pipeline_kv_update,
        size_t                    update_size
) {
    using clock = std::chrono::steady_clock;

    const int32_t wait_us = 0;
    int32_t       exec_us = 0;

    for (auto& model_name : exec_models) {
        const auto began = clock::now();
        if (!executeModel(m_input_tensors[model_name], m_output_tensors[model_name], model_name)) {
            return false;
        }
        const auto finished = clock::now();
        const auto elapsed  = std::chrono::duration_cast<std::chrono::microseconds>(finished - began);
        exec_us += static_cast<int32_t>(elapsed.count());
    }

    if (pipeline_kv_update) m_nPast += update_size;

    *exec_time_total = exec_us;
    *wait_time_total = wait_us;
    return true;
}

// Runs one forward pass over `tokens` through every loaded graph (in model_order).
// n_past is not advanced here (pipeline_kv_update is passed as false); it is updated separately,
// e.g. via setKVCacheNPast(). On success, records the run in prev_run (read later by
// getDequantLogits() and setKVCacheNPast()), bumps m_inference_count and accumulates the
// wall time into timeLogs. Returns false if any graph fails to execute.
bool QnnCpuModel::runInference(const std::vector<int32_t>& tokens, bool logits_all) {
    __DEBUG("qnn-cpu: run-inference start : n_tokens {}", tokens.size());

    auto start = std::chrono::steady_clock::now();

    // runInferenceHelper reports these in microseconds, so an int32_t holds up to ~2147 s
    // (~35.8 minutes) per call -- ample for a single inference step.
    int32_t total_wait_time = 0;
    int32_t total_exec_time = 0;

    // Setup inputs for inference
    setupInputTensors(tokens, false);

    if (!runInferenceHelper(model_order, &total_wait_time, &total_exec_time, false, tokens.size()))
        return false;

    prev_run.num_tokens_processed = tokens.size();
    m_inference_count++;

    prev_run.was_bert_mode  = false;
    prev_run.was_logits_all = logits_all;

    auto stop = std::chrono::steady_clock::now();
    timeLogs["Run Inference (cpp) "].first += static_cast<double>(
            std::chrono::duration_cast<std::chrono::microseconds>(stop - start).count()
    );
    timeLogs["Run Inference (cpp) "].second++;
    QNN_DEBUG("[TIME] Wait[%d] Exec[%d]\n", total_wait_time, total_exec_time);
    return true;
}

// Logs the total inference count and, for every timeLogs entry, its accumulated value divided by
// its sample count. Compiled to a no-op unless NSP_LOG_LEVEL > 1.
void QnnCpuModel::printFinalLogs() {
#if NSP_LOG_LEVEL > 1
    QNN_DEBUG("Total inference count : %d", m_inference_count);
    for (const auto& entry : timeLogs) {
        const auto& stats = entry.second;
        QNN_DEBUG("%s : %lf", entry.first.c_str(), stats.first / stats.second);
    }
#endif
}

// Sets the number of valid KV-cache positions. Moving n_past backwards is always allowed;
// moving it forward is limited to the number of tokens the previous run produced.
// Throws std::runtime_error when a forward move exceeds that limit; otherwise returns true.
bool QnnCpuModel::setKVCacheNPast(size_t n_past) {
    if (n_past > m_nPast) {
        const size_t num_update = n_past - m_nPast;
        if (num_update > prev_run.num_tokens_processed) {
            // Report the limit actually checked (tokens produced), not the context size.
            std::string err_msg = "Requested larger n_past update than #tokens produced by model: ";
            err_msg += std::to_string(num_update) + " vs. " +
                       std::to_string(prev_run.num_tokens_processed);
            throw std::runtime_error(err_msg);
        }
    }

    m_nPast = n_past;
    return true;
}

// Copies the output of the last inference from the "output_genAI" tensor into `dequant_logits`
// (cleared first) and returns the number of rows copied.
//  - LOGITS: with logits_all, every valid row (the output may be left padded, so rows start at
//    m_numLogits - n_processed); otherwise only the last row. A row is m_vocab_size floats.
//  - EMBEDDINGS: logits_all is forced on; n_processed rows of m_embd floats.
// The buffer is read directly as float, so no dequantization is performed here.
size_t QnnCpuModel::getDequantLogits(std::vector<float>& dequant_logits, bool logits_all) {
    // if model is BERT, always return ALL logits
    if (model_output == ModelOutput::EMBEDDINGS) logits_all = true;

    __DEBUG("qnn-cpu: get-dequant-logits logits_all {}", logits_all);

    auto&        logit_spec = m_output_specs[model_order.back()]["output_genAI"];
    float*       logitBuf   = (float*)getBuffer(logit_spec);
    const size_t buf_elems  = getBufferSize(logit_spec) / sizeof(float);

    size_t n_rows = logits_all ? prev_run.num_tokens_processed : 1;
    size_t begin  = 0; // element range [begin, end) of logitBuf to return
    size_t end    = 0;

    dequant_logits.clear();
    if (model_output == ModelOutput::LOGITS) {
        end = buf_elems;
        if (!logits_all) {
            // Return the last processed token logits i.e. [ ..., [1]]
            if (m_numLogits > 1) begin = (m_numLogits - 1) * m_vocab_size;
        } else if (m_numLogits >= prev_run.num_tokens_processed) {
            // Left padded: [0, 0, [n_tokens_processed]] -> skip the padding rows.
            begin = (m_numLogits - prev_run.num_tokens_processed) * m_vocab_size;
        } else {
            // The model keeps only m_numLogits rows; report what is actually returned.
            n_rows = static_cast<size_t>(m_numLogits);
        }
    } else if (model_output == ModelOutput::EMBEDDINGS) {
        // embeddings size = [n_tokens_processed * m_embd], never past the buffer.
        const size_t wanted = prev_run.num_tokens_processed * m_embd;
        end                 = wanted < buf_elems ? wanted : buf_elems;
    }

#ifdef DUMP_LOGITS
    {
        char fname[255];
        snprintf(fname, sizeof(fname), "%s/logits/%03d", DEBUG_DUMP_TARGET_PATH, m_inference_count);
        QnnUtils::writeRawData(getBuffer(logit_spec), getBufferSize(logit_spec), fname);
    }
#endif

    // assign() sizes the vector in elements (the old reserve() was given byte counts).
    if (begin < end) {
        dequant_logits.assign(logitBuf + begin, logitBuf + end);
    }

    return n_rows;
}

// Applies each LoRA binary section in order through QnnApi::applyBinarySection(), passing its
// index and name. Stops and returns false at the first failure; returns true when all succeed.
bool QnnCpuModel::applyBinarySections(std::vector<std::string>& binsection_list) {
    for (size_t i = 0; i < binsection_list.size(); ++i) {
        std::string& section = binsection_list[i];
        __DEBUG("qnn-cpu: applyBinarySections adapters {}", section);
        if (!m_qnnApi->applyBinarySection(i, section)) {
            __ERROR("qnn-cpu: Error in applyBinarySections {}", i);
            return false;
        }
    }
    return true;
}

// Stores the LoRA strength (alpha) for subsequent inferences; setupInputTensors() writes it into
// the t_input_lora_alpha input on every call. The tensor name is not needed on CPU and is
// ignored. Always returns true.
bool QnnCpuModel::applyLoraStrength(const std::string& alpha_tensor_name, const float alpha_val) {
    static_cast<void>(alpha_tensor_name);
    m_lora_alpha_val = alpha_val;
    return true;
}

// Activates the LoRA adapter registered under `lora_adapter_name`: sets its alpha strength, then
// applies its binary sections. Returns false (after logging) if adapter weights are not enabled,
// the adapter is unknown, or either step fails.
bool QnnCpuModel::applyLoraAdapter(const std::string& lora_adapter_name) {
    if (m_loraConfigType != LoraConfigType::LORA_ADAPTER_WEIGHT_ENABLE) {
        __ERROR("qnn-cpu: Lora config is not enable for adapters");
        return false;
    }

    // Single lookup; the adapter entry is reused for both steps below.
    const auto it = m_loraConfig.find(lora_adapter_name);
    if (it == m_loraConfig.end()) {
        __ERROR("qnn-cpu: Could not find lora adapters config to apply ");
        return false;
    }
    auto& adapter = it->second;

    if (!applyLoraStrength(adapter.alpha_tensor_name, adapter.alpha_tensor_val)) {
        __ERROR("qnn-cpu: Could not apply Alpha tensor ");
        return false;
    }

    if (!applyBinarySections(adapter.binsection_list)) {
        __ERROR("qnn-cpu: Could not apply binary Sections ");
        return false;
    }
    return true;
}

// Restores the K$/V$ input caches from a file written by saveKVCache().
// Layout: a CacheFileSpec header (magic 0xC0DE), then for K and then V one block of
// update_size * head_dim floats per (layer, head). Each block is copied to the start of that
// head's slice of the {n_layer, n_heads, ctx_size + 1, head_dim} cache tensor.
// On success sets m_nPast (and prev_run.num_tokens_processed) to update_size and returns it.
// Returns 0 on any error: open/read failure, bad magic, or an update_size larger than a head
// slice. A truncated payload may leave the caches partially overwritten.
size_t QnnCpuModel::loadKVCache(const std::string& load_path) {
    std::ifstream f(load_path, std::ios::in | std::ios::binary);
    if (f.fail()) {
        __ERROR("qnn-cpu: load-kv errror reading file {}", load_path);
        return 0;
    }

    CacheFileSpec spec;
    if (!f.read(reinterpret_cast<char*>(&spec), sizeof(spec))) {
        __ERROR("qnn-cpu: load-kv failed to read cache header from {}", load_path);
        return 0;
    }
    if (spec.magic != 0xC0DE) {
        __ERROR("qnn-cpu: load-kv expected 0xC0DE found {:#x}", spec.magic);
        return 0;
    }
    // clang-format off
    __DEBUG("qnn-cpu: load-kv {{ num_tensors {}, magic {}, dtype {}, n_heads {}, embed_dim {} update_size {} }}",
    spec.num_tensors, spec.magic, int(spec.dtype), spec.update_size == 0 ? spec.n_heads : spec.n_heads, spec.embed_dim, spec.update_size);
    // clang-format on

    // The file is external input: a token count larger than a head slice (ctx_size + 1 entries)
    // would write past the end of the cache buffers.
    const size_t slots_per_head = static_cast<size_t>(m_ctx_size) + 1;
    if (static_cast<size_t>(spec.update_size) > slots_per_head) {
        __ERROR("qnn-cpu: load-kv update_size {} exceeds KV cache capacity {}",
                spec.update_size, slots_per_head);
        return 0;
    }

    const int32_t n_valid   = static_cast<int32_t>(spec.update_size);
    const size_t  copy_size = static_cast<size_t>(n_valid) * m_head_dim;
    const size_t  skip_size = slots_per_head * m_head_dim;

    // Reads one cache (all layers x heads) into `dst`; false if the file ends early.
    auto read_cache = [&](float* dst) {
        for (int i = 0; i < m_num_layer; i++) {
            for (int j = 0; j < m_num_heads; j++) {
                if (!f.read(reinterpret_cast<char*>(dst), copy_size * sizeof(float))) return false;
                dst += skip_size;
            }
        }
        return true;
    };

    float* k_cache = reinterpret_cast<float*>(getBuffer(t_input_ids_k_cache));
    float* v_cache = reinterpret_cast<float*>(getBuffer(t_input_ids_v_cache));
    if (!read_cache(k_cache) || !read_cache(v_cache)) {
        __ERROR("qnn-cpu: load-kv cache file {} is truncated", load_path);
        return 0;
    }

    m_nPast                       = n_valid;
    prev_run.num_tokens_processed = m_nPast;
    return spec.update_size;
}

bool QnnCpuModel::saveKVCache(const std::string& save_path) {
    __DEBUG("qnn-cpu: save-kv path {}", save_path);

    std::ofstream f(save_path, std::ios::out | std::ios::binary);
    if (f.fail()) {
        __ERROR("qnn-cpu: save-kv error opening file : {}", save_path);
        throw std::runtime_error("Failed to write to cache file. Please re-check path");
    }

    const uint32_t  n_valid = static_cast<uint32_t>(m_nPast);
    const CacheFileSpec::DataType dtype   = CacheFileSpec::DataType::FLOAT32_T;

   // Save the cache file metadata
    CacheFileSpec spec(m_num_layer * 2, 0xc0de, dtype, 0x0, m_num_heads, m_head_dim, n_valid);
    f.write((char*)&spec, sizeof(spec)); // as nsp already updated the spec
    if(n_valid > 0) {
        // Dump KeyCache and ValueCache
        float* input_id_k_cache_buffer = (float*)getBuffer(t_input_ids_k_cache);
       float* input_id_v_cache_buffer = (float*)getBuffer(t_input_ids_v_cache);

        // K$, V$ 4D Tensor {n_layer, n_heads, n_ctx, n_head_dim}

        const size_t copy_size = n_valid * m_head_dim;
       const size_t skip_size = (m_ctx_size + 1) * m_head_dim;
        for (int i = 0; i < m_num_layer; i++) {
            for(int j = 0; j < m_num_heads; j++) {
                f.write((char*)input_id_k_cache_buffer, copy_size * sizeof(float));
                input_id_k_cache_buffer += skip_size;
            }
        }

        for (int i = 0; i < m_num_layer; i++) {
            for(int j = 0; j < m_num_heads; j++) {
                f.write((char*)input_id_v_cache_buffer, copy_size * sizeof(float));
                input_id_v_cache_buffer += skip_size;
            }
        }
    }

    f.flush();
    f.close();

    return true;
}

} // namespace qualla
