//==============================================================================
//
//  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
//  All Rights Reserved.
//  Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

#define _USE_MATH_DEFINES // Used for M_PI

#include "qualla/env.hpp"
#include "qualla/detail/timer.hpp"
#include "qualla/detail/cache-file.hpp"

#include "fmt/format.h"
#include "fmt/ranges.h"
#include "fmt/os.h"
#include <iostream>
#include "nsp-model.hpp"

#include <set>
#include <cstring>
#include <fstream>
#include <sstream>
#include <cassert>
#include <cstdio>
#include <span>
#include "fp16/fp16.h"

namespace fs = std::filesystem;

// Logging shorthands. Every macro below expands to a call on `_env.logger()`, so a variable
// named `_env` must be in scope wherever one of them is used.
//
// __INFO / __WARN / __ERROR build the message with fmt::format at the call site and post the
// resulting string.
// NOTE(review): identifiers containing a double underscore are reserved for the implementation
// in C++; renaming these macros would be safer but touches every call site.
#define __INFO(__fmt, ...)  _env.logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
#define __WARN(__fmt, ...)  _env.logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
#define __ERROR(__fmt, ...) _env.logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
// __KPIS / __DEBUG / __TRACE post a lambda that performs the formatting instead of a ready-made
// string (presumably so the logger can skip the formatting work when the level is disabled --
// confirm against Logger::post). The lambda captures by reference, so arguments must stay alive
// for the duration of the post() call.
#define __KPIS(__fmt, ...)                                                                         \
    _env.logger().post(Logger::ENGINE_KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
#define __DEBUG(__fmt, ...)                                                                        \
    _env.logger().post(Logger::ENGINE_DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
#define __TRACE(__fmt, ...)                                                                        \
    _env.logger().post(Logger::ENGINE_TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })

namespace qualla {

// Constructs the NSP model wrapper from user-supplied parameters.
//
// The constructor copies the configuration out of `params`, resolves every entry of
// `params.model_list` against `model_basedir`, optionally starts the worker thread pool and
// creates the IOTensor buffer manager that is handed to QnnApi.
//
// Throws std::runtime_error when a model file is not a regular file or when the debug
// directory is requested but cannot be created. Both checks run before the thread pool is
// started: the destructor (which stops the pool) does not run for a throwing constructor, so
// nothing that needs explicit shutdown may be live when an exception leaves this function.
QnnNspModel::QnnNspModel(Env& env, const Params& params)
    : _env(env), model_basedir(params.model_basedir) {
    // Initialize QnnAPI
    m_qnnApi = std::make_unique<QnnApi>();

    spill_fill_buffer_size  = params.spill_fill_bufsize;
    m_kv_dim                = params.kv_dim;
    m_use_mmap              = params.use_mmap;
    m_use_async_Init        = params.use_async_Init;
    mmap_budget             = params.mmap_budget;
    m_ctx_size              = params.ctx_size;
    m_pad_token             = params.pad_token;
    lmhead_weight_dir       = params.lmhead_weight_dir;
    graph_switching         = params.graph_switching;
    load_select_graphs      = params.load_select_graphs;
    lora_conf               = params.lora_config_type;
    embedding_length        = params.embedding_length;
    embedding_datatype      = params.embedding_datatype;
    m_disableKvCache        = params.disable_kv_cache;
    m_embd_size             = params.n_embd;
    m_modelArchitectureType = params.modelArchitectureType;

    // Positional encoding parameters
    m_positional_encoding = params.positional_encoding_params;
    if (m_positional_encoding.type == PositionalEncoding::ROPE) {
        // Save m_pos_dim for easy access
        m_pos_dim = m_positional_encoding.rope_params.dims;
    }

    // Debug flags
    _debug_path    = params.debug_path;
    _debug_specs   = params.debug_specs;
    _debug_tensors = params.debug_tensors;
    _debug_outputs = params.debug_outputs;
    _debug_qnn     = params.debug_qnn;

    _backend_lib      = params.backend_lib;
    _backend_ext_conf = params.backend_ext_conf;

    if (graph_switching && !m_use_mmap)
        __WARN("Graph switching with non-mmaped implementation can cause high sustained memory usage"
        );

    variant_latency = params.variant_latency;

    // pooled_output is only meaningful for encoder-style models.
    if (m_modelArchitectureType == ModelArchitectureType::ENCODER) {
        m_pooled_output = params.pooled_output;
    }

    exec_select_graphs = params.exec_select_graphs;
    if (!exec_select_graphs.empty())
        __DEBUG("qnn-htp : Execute selected graphs = {}", exec_select_graphs);

    // Anything other than "POINTER_SHIFT" falls back to SHIFT_CONCAT.
    _kv_update_method = (params.kv_update_method == "POINTER_SHIFT") ? POINTER_SHIFT : SHIFT_CONCAT;
    __DEBUG("qnn-htp : NSP KV$ Update Method = {}",
            (_kv_update_method == POINTER_SHIFT) ? "POINTER_SHIFT" : "SHIFT_CONCAT");

    // Set up filename list. Relative entries are resolved against the model base directory.
    for (const auto& entry : params.model_list) {
        fs::path model_path = fs::path(entry);
        if (model_path.is_relative()) model_path = model_basedir / fs::path(entry);
        if (!fs::is_regular_file(model_path)) {
            __ERROR("NSPModel: Can't access model file : {}", model_path.string());
            throw std::runtime_error("NSPModel: Can't access model file : " + model_path.string());
        }
        model_filelist.push_back(model_path.string());
    }

    // Make sure the debug dump directory exists before any worker thread is started (see the
    // note on exceptions above). create_directories() returns false for an already existing
    // directory, hence the explicit exists() check.
    if (params.debug_specs || params.debug_tensors) {
        if (!fs::exists(params.debug_path) && !fs::create_directories(params.debug_path))
            throw std::runtime_error("Could not create debug directory : " + params.debug_path);
    }

    if (lora_conf != LoraConfigType::LORA_DISABLE) {
        lora_config.insert(params.lora_param.begin(), params.lora_param.end());
    }

    if (params.n_threads > 0) {
        _threaded = true;
        _cpumask  = params.cpumask;
        __DEBUG("qnn-htp: starting threadpool : n_threads {} cpumask {:#x} poll {}",
                params.n_threads,
                _cpumask,
                params.poll);
        threadpool.start(params.n_threads, _cpumask, params.poll);
    }

    // Initialize QNN IO Tensor
    // NOTE(review): m_sharedBuffer is not assigned in this constructor; it is assumed to carry
    // a default member initializer -- confirm in the class declaration.
    m_ioTensor = std::make_unique<IOTensor>(
            m_sharedBuffer ? BufferAlloc::SHARED_BUFFER : BufferAlloc::DEFAULT,
            m_sharedBuffer ? m_qnnApi->getQnnInterfaceVer() : nullptr
    );

    m_qnnApi->setIOTensorBufferMgr(m_ioTensor.get());
    m_qnnApi->setKVDim(m_kv_dim);
    m_qnnApi->setContextSize(m_ctx_size);
    m_qnnApi->setKVUpdateMethod(_kv_update_method);
}

// Tears the model down: stops the worker pool when one was started, releases the cached RoPE
// memory and logs how long the teardown took.
QnnNspModel::~QnnNspModel() {
    qualla::Timer teardown_timer;

    // The thread pool is stopped before anything else is released.
    if (_threaded) {
        __DEBUG("qnn-htp: stopping threadpool");
        threadpool.stop();
    }

    // Free cached RoPE memory. free() is a no-op for null pointers, so no guard is required.
    free(rope_sin);
    free(rope_cos);

    __DEBUG("qnn-htp: model destruct complete: {} usec", teardown_timer.elapsed_usec());
}
// Converts `numElements` float32 values from `in` into IEEE half-precision values written
// back-to-back into `out` (two bytes per element, host byte order).
//
// Returns false without touching `out` when either pointer is null or `numElements` is zero;
// returns true otherwise. `out` must provide at least 2 * numElements bytes.
bool QnnNspModel::float32ToFloat16(uint8_t *out, float *in, size_t numElements) {
    if (out == nullptr || in == nullptr || numElements == 0) return false;

    for (size_t i = 0; i < numElements; ++i) {
        const uint16_t half = fp16_ieee_from_fp32_value(in[i]);
        // `out` is a byte buffer with no alignment guarantee for uint16_t, so store through
        // memcpy instead of a reinterpreting pointer cast.
        std::memcpy(out + i * sizeof(half), &half, sizeof(half));
    }
    return true;
}
// Loads and initializes the QNN runtime libraries and the model files collected by the
// constructor, then groups the loaded graphs into one QnnNspGraph per split.
//
// Returns false when
//   - the backend extension config file cannot be accessed (non-internal SDK builds),
//   - QnnApi initialization fails,
//   - no graph is left after the latency-map / exec-select-graphs filters, or
//   - the latency map names an AR-n variant that was not loaded.
bool QnnNspModel::initializeModel(void) {
    qualla::Timer start;

    __DEBUG("qnn-htp: model init start");

    // Default backends
#ifdef _WIN32
    const std::string m_backend                = _backend_lib.empty() ? "QnnHtp.dll" : _backend_lib;
    const std::string m_systemLib              = "QnnSystem.dll";
    const std::string backendExtensionsLibPath = "QnnHtpNetRunExtensions.dll";
#else
    const std::string m_backend   = _backend_lib.empty() ? "libQnnHtp.so" : _backend_lib;
    const std::string m_systemLib = "libQnnSystem.so";
    const std::string backendExtensionsLibPath = "libQnnHtpNetRunExtensions.so";
#endif
#ifdef QUALLA_INTERNAL_QNN_SDK
    if (_backend_ext_conf.empty()) {
        __INFO("No backend extension config provided");
    }
    fs::path m_backendExtensionsConfigPath = fs::path(_backend_ext_conf);
#else
    fs::path m_backendExtensionsConfigPath =
            _backend_ext_conf.empty() ? fs::path("data") / "htp_backend_ext_config.json"
                                      : fs::path(_backend_ext_conf);

    if (m_backendExtensionsConfigPath.is_relative())
        m_backendExtensionsConfigPath = fs::path(model_basedir) / m_backendExtensionsConfigPath;

    if (!fs::is_regular_file(m_backendExtensionsConfigPath)) {
        __ERROR("Cannot access {}", m_backendExtensionsConfigPath.string());
        return false;
    }
#endif
    __INFO("Backend library : {}", m_backend);
    __INFO("System library  : {}", m_systemLib);
    __INFO("Model dir   : {}", model_basedir.string());
    __INFO("Model files : {}", model_filelist);
    __INFO("Backend extensions lib path : {}", backendExtensionsLibPath);
    __INFO("Backend extensions config path : {}", m_backendExtensionsConfigPath.string());

    if (!m_qnnApi->initialize(
                m_backend,
                model_filelist,
                BackendExtensionsConfigs(
                        backendExtensionsLibPath, m_backendExtensionsConfigPath.string()
                ),
                PerfProfile::BURST,
                ContextConfigs(Qnn_Priority_t::QNN_PRIORITY_DEFAULT),
                {},          // graphConfigs
                true,        // loadFromCachedBinary
                m_systemLib, // systemLibraryPath
                false,
                spill_fill_buffer_size,
                m_use_mmap,
                m_use_async_Init,
                mmap_budget,
                _debug_qnn,
                graph_switching,
                exec_select_graphs,
                load_select_graphs
        )) {
        __ERROR("qnn-api initialization failed!");
        return false;
    }

    int32_t n_splits = 0;
    m_num_graphs     = m_qnnApi->getGraphsCount();

    __INFO("qnn-api initialized with {} graph(s)", m_num_graphs);

    GraphInfo_t** graphs_info = m_qnnApi->getGraphsInfo();
    // Reserving up front keeps the addresses stored in m_graph_map valid while
    // m_variant_list grows inside the loop below.
    m_variant_list.reserve(m_num_graphs);
    // AR-n (n_tokens) -> names of the graphs accepted for that variant
    std::map<int32_t, std::vector<std::string>> graph_names;
    for (size_t graph_idx = 0; graph_idx < m_num_graphs; graph_idx++) {
        GraphInfo_t* const graph_info = graphs_info[graph_idx];
        GraphVariant       graph(graph_info, m_qnnApi->getContexts(graph_info), m_ctx_size, m_layerNames);
        __DEBUG("qnn-htp: Graph {}", graph.graph_name);

        // Filter 1: a non-empty latency map acts as an allow-list of AR-n variants.
        if (!variant_latency.empty() && !variant_latency.contains(graph.n_tokens)) {
            __WARN("qnn-htp: Disabling {} based on conf file", graph.graph_name);
            continue;
        }

        // Filter 2: a non-empty exec-select list acts as an allow-list of graph names.
        if (exec_select_graphs.size() != 0 &&
            std::find(exec_select_graphs.begin(), exec_select_graphs.end(), graph.graph_name) ==
                    exec_select_graphs.end()) {
            __DEBUG("qnn-htp: Graph {} is not selected to execute based on conf file",
                    graph.graph_name);
            continue;
        }
        m_variant_list.emplace_back(graph);
        n_splits = std::max(n_splits, ++nsp_graph_count[graph.n_tokens]);
        graph_names[graph.n_tokens].push_back(graph.graph_name);
        m_graph_map[std::string(graph_info->graphName)] = &m_variant_list.back();
    }

    // Without a single accepted graph every later stage would index into empty containers,
    // so treat this as an initialization failure instead of only logging it.
    if (graph_names.empty()) {
        if (exec_select_graphs.size() != 0) {
            __ERROR("No matching graphs based on conf file");
        } else {
            __ERROR("qnn-htp: No graphs available after loading and filtering");
        }
        return false;
    }

    // Create NSPGraph for each splits
    m_nsp_graphs.reserve(n_splits);
    for (int idx = 0; idx < n_splits; idx++) {
        m_nsp_graphs.emplace_back(
                idx, _env, m_ctx_size, m_qnnApi.get(), m_ioTensor.get(), _threaded
        );
        m_nsp_graphs.back().setDebugMode(_debug_specs, _debug_tensors, _debug_path);
    }

    // Insert all GraphVariants into corresponding NSPGraph. Within one AR-n variant the
    // graphs are assigned to splits in lexicographic order of their names.
    for (auto& [n_tokens, graphs] : graph_names) {
        std::sort(graphs.begin(), graphs.end());
        for (size_t idx = 0; idx < graphs.size(); idx++)
            m_nsp_graphs.at(idx).addGraph(m_graph_map.at(graphs[idx]));
    }

    if (_debug_specs) dumpTensorSpecs();

    __INFO("qnn-htp: Graphs loaded (AR-n: #splits): {}", nsp_graph_count);

    // Check if latency map matches the graphs loaded
    for (const auto& [variant, latency] : variant_latency) {
        if (!nsp_graph_count.contains(variant)) {
            __ERROR("Latency map (AR-n: #latency_ms): {}", variant_latency);
            __ERROR("AR-{} present in latency map but not loaded!", variant);
            __ERROR("Fix latency-map in the conf file, must map from AR-n to latency (ms)");
            return false;
        }
    }

    __DEBUG("qnn-htp: Model Init complete: {} usec", start.elapsed_usec());

    return true;
}

// Once the model has been loaded, initialize IO Tensors
// m_ioTensors is initialized by the context for now
bool QnnNspModel::initializeIOTensors() {

    if(m_use_async_Init == false){ // IO Tensor Mem Registration is already done within the
                                   // model_initailize by Qnn_API for Sync Init.

        // set lmHeadWeightsEnabled and loraWeights Enabled
        _lmhead_weight_input = m_qnnApi->getLmHeadWeightInputEnabled();
        _lora_enabled = m_qnnApi->getLoraWeightEnabled();
         for (auto it = nsp_graph_count.rbegin(); it != nsp_graph_count.rend(); ++it) {
            for (QnnNspGraph& graph : m_nsp_graphs) {
                // TensorAllocInfo is added to each NSP graph.
                // Needed by Pointer_SHIFT Registration During Execute.
                graph.tensor_alloc_info = m_qnnApi->getTensorAllocInfo();
                if(graph.tensor_alloc_info == NULL){
                    __ERROR("Error Tensor Allocation Failed.");
                    return false;
                }
            }
        }
        return true;
    }

    // This path is used in case of use Async Init is true.
    qualla::Timer start;



    __DEBUG("qnn-htp: init IO tensors start");

    // Ideally, we should create and initalize m_ioTensor for each context, but we want to
    // be able to see/use all the buffers in every contexts so that they can be connected
    // with each other. Hence, we are using only the first context to initialize the m_ioTensor
    // and use it for all graphs/contexts.
    __DEBUG("qnn-htp: init IO tensor using {}", m_graph_map.begin()->first);
    if (true != m_ioTensor->initialize(m_graph_map.begin()->second->context_handle)) {
        __ERROR("qnn-htp: failure to initialize IOTensor");
        return false;
    }



    // Technical note: unordered_map is faster thans map but map makes debug logs easier to read
    // The runtime impact shouldn't be very large since max size < #tensors

    typedef int CtxBitVector;
    // Maps context bitVector to a map{tensor_name -> max_tensor_size}
    std::map<CtxBitVector, std::map<std::string, size_t>> ctx_alloc_map;
    // Maps tensor_name to context bitVector, each bit representing a context the tensor exists in
    std::map<std::string, CtxBitVector> tensor_ctx_map;
    // Maps a ContextHandle to a one-hot encoded bitVector (e.g. 1, 2, 4, ...)
    std::map<Qnn_ContextHandle_t, CtxBitVector> ctx_to_hash;

    // Iterate over all tensors in all GraphVariants to figure out allocations
    for (auto& variant : m_variant_list) {
        // Map the context handle to a hashed bitVector
        if (!ctx_to_hash.contains(variant.context_handle)) {
            ctx_to_hash[variant.context_handle] = 1 << ctx_to_hash.size();
        }
        for (auto& tensor_specs : {variant.input_specs, variant.output_specs}) {
            for (auto& [tname, tspec] : tensor_specs) {
                size_t       size     = tspec.dims.getAlignedSize();
                CtxBitVector tcontext = ctx_to_hash[variant.context_handle];

                // Check if it's LoRA enabled model
                if (!_lora_enabled && tname.find("lora") != std::string::npos) _lora_enabled = true;
                // Check if graph has lmhead weight input
                if (!_lmhead_weight_input && tname.compare("weight") == 0)
                    _lmhead_weight_input = true;

                // Allocate KV Tensors as in+out
                if (tname.starts_with("past_")) {
                    if (tname.ends_with("_in")) continue; // kv_in is processed along with kv_out

                    // For kv_out, add the size of kv_in as well
                    const std::string tname_in = tname.substr(0, tname.rfind('_')).append("_in");
                    if (auto tensor = variant.getInput(tname_in))
                        size += tensor->dims.getAlignedSize();

                    d_kv = QnnUtils::DataType(tspec.tensor);

                    // Allocate extra buffer for pointer shift
                    // 1024-n for keys (1024-n)*128 for values
                    // For aligned size, we might as well use 1024 and 128*1024
                    if (_kv_update_method == POINTER_SHIFT)
                        size += (tname.starts_with("past_key")) ? m_ctx_size * d_kv.bw()
                                                                : m_ctx_size * m_kv_dim * d_kv.bw();
                }

                if (tensor_ctx_map.contains(tname)) { // For duplicate tensor names, link them
                    CtxBitVector context_bitvec = tensor_ctx_map.at(tname);
                    size = std::max(ctx_alloc_map[context_bitvec][tname], size);
                    if ((context_bitvec & tcontext) == 0) // Set of contexts needs to be updated
                        ctx_alloc_map[context_bitvec].erase(tname);

                    tcontext |= context_bitvec;
                }

                ctx_alloc_map[tcontext][tname] = size;
                tensor_ctx_map[tname]          = tcontext;
            }
        }

        // Cleanup is essential in case of very large number of splits
        for (auto it = ctx_alloc_map.cbegin(); it != ctx_alloc_map.cend();)
            it = (it->second.empty()) ? ctx_alloc_map.erase(it) : ++it;
    }



    _env.logger().compose(Logger::MALLOC_DEBUG, [&](Logger::Helper w) {
        for (auto& [tcontext, tensor_alloc_map] : ctx_alloc_map) {
            w.write(fmt::format("qnn-htp: ctx_alloc_map[{}] = {{", tcontext));
            for (auto& [tname, tsize] : tensor_alloc_map)
                w.write(fmt::format("\t{} : {},", tname, tsize));
            w.write("}");
        }
    });

    // Calculate total allocation sizes and offset of each tensor within its allocated buffer
    if (m_ioTensor->allocateBuffers(ctx_alloc_map, tensor_alloc_info) == false) return false;

    _env.logger().compose(Logger::MALLOC_DEBUG, [&](Logger::Helper w) {
        w.write("tensor_alloc_info = {");
        for (auto& [tname, toffset] : tensor_alloc_info)
            w.write(fmt::format("\t{}: [{}, {}],", tname, toffset.first, toffset.second));
        w.write("}");
    });

    // For each variant, map tensor name to its allocated buffer, i/o and offset within the buffer
    // TODO: Check why we aren't just looping over all variants here!
    for (auto it = nsp_graph_count.rbegin(); it != nsp_graph_count.rend(); ++it) {

        for (QnnNspGraph& graph : m_nsp_graphs) {

            // TODO: Remove this reference
            graph.tensor_alloc_info = &tensor_alloc_info;

            auto variant = graph[it->first];

            std::map<std::string, std::tuple<int, size_t, size_t>> graph_allocs;


            for (auto& [tname, tspec] : variant->input_specs) {
                if (tname.starts_with("past_")) continue;
                auto& [alloc_idx, offset] = tensor_alloc_info.at(tname);
                graph_allocs[tname]       = {alloc_idx, offset, tspec.dims.getAlignedSize()};
            }

            for (auto& [tname, tspec] : variant->output_specs) {
                size_t kv_offset = 0;
                size_t size      = tspec.dims.getAlignedSize();

                auto& [alloc_idx, offset] = tensor_alloc_info.at(tname);
                if (tname.starts_with("past_")) {
                    auto in_name = tname.substr(0, tname.rfind("_")).append("_in");
                    if (auto kv_in = variant->getInput(in_name)) {
                        kv_offset = kv_in->dims.getAlignedSize();
                        if (_kv_update_method == POINTER_SHIFT)
                            kv_offset += (tname.starts_with("past_key"))
                                                 ? m_ctx_size * d_kv.bw()
                                                 : m_ctx_size * m_kv_dim * d_kv.bw();
                        graph_allocs[in_name] = {alloc_idx, offset, kv_offset};
                    }
                }

                graph_allocs[tname] = {alloc_idx, offset + kv_offset, size};
            }

            if (!m_ioTensor->mapFusedBufferOffset(
                        variant->graph_info, variant->context_handle, graph_allocs
                )) {

                __ERROR("Error mapping tensor to allocation buffers");
                return false;
            }
        }
    }



    __DEBUG("qnn-htp: init IO tensors complete : {} usec", start.elapsed_usec());

    return true;
}

// Compares a tensor's dimensions and bit width against the expected values.
//
// An expected value of -1 acts as a wildcard for that field. A null `tensor` is accepted
// (returns true). On a mismatch a {"ShapeError", tensor_name, description} entry is appended
// to `errors` and false is returned.
static bool checkShape(
        const std::string&                                              tensor_name,
        const QnnUtils::Tensor*                                         tensor,
        int32_t                                                         height,
        int32_t                                                         width,
        int32_t                                                         channel,
        int32_t                                                         bitWidth,
        std::vector<std::tuple<std::string, std::string, std::string>>& errors
) {
    if (!tensor) return true;

    const QnnUtils::Dims& actual = tensor->dims;

    // True when the expectation is the wildcard or equals the value found in the tensor.
    const auto agrees = [](int32_t expected, auto found) {
        return expected == -1 || expected == found;
    };

    const bool shape_ok = agrees(height, actual.height) && agrees(width, actual.width) &&
                          agrees(channel, actual.channel) && agrees(bitWidth, actual.bitWidth);
    if (shape_ok) return true;

    std::ostringstream description;
    description << "Expected [ " << height << ", " << width << ", " << channel << "] "
                << "bitWidth=" << bitWidth;
    description << ". Found [ " << actual.height << ", " << actual.width << ", "
                << actual.channel << "] "
                << "bitWidth=" << actual.bitWidth;

    errors.emplace_back("ShapeError", tensor_name, description.str());
    return false;
}

// Run all validations for the model here so we can exit early
bool QnnNspModel::validateModel() {
    // Checks we will be running
    // 1a. input_ids or inputs_embeds exists in the first split
    // 1b. token_type_ids should exists in case of Bert
    // 2. logits exists in the last split
    // 3. Shapes for all named tensors are correct
    // 4. All tensors with identical names (incl kv_in/kv_out) have identical quantization params
    // Missing check : Shape of tensor between splits match up

    // Support for 16-bit KV Tensors is temporarily disabled
    // If you need this, please refer to past commits (QuaLLA <= v0.3.22)

    // Important : These variables need to be set correctly
    // m_vocab_size  - Calculated as max(logits.shape) since len()
    // m_kv_dim      - Calculated in this function before usage
    // m_ctx_size    - Provided by the user as n_ctx

    std::vector<std::tuple<std::string, std::string, std::string>> errors;

    QnnUtils::Tensor* tt;

    //default input type is token
    m_inputType = InputType::TOKENS;

    // Check 1 - input layer exists
    for (auto& [n_tokens, variant] : m_nsp_graphs.front().variants) {
        // Update model expectations for E2T if an inputs_embeds layer is present. marks the input Type
        if ((tt = variant->getInput("inputs_embeds")) != nullptr) {
            m_layerNames[LayerType::INPUT] = "inputs_embeds";
            m_inputType = InputType::EMBEDDINGS;
        }
        if ((tt = variant->getInput(m_layerNames[LayerType::INPUT])) == nullptr) {
            errors.push_back({variant->graph_name, m_layerNames[LayerType::INPUT], "Tensor not found"});
        } else {
            input_bitWidth = tt->dtype.bw();
            checkShape(m_layerNames[LayerType::INPUT], tt, -1, -1, -1, input_bitWidth, errors);

            if (embedding_datatype == "float32") {
                m_embeddingBufferSize = m_embd_size * sizeof(float);
            } else {
                m_embeddingBufferSize = m_embd_size * input_bitWidth;
            }

            // For embedding inputs, the expected count is multiplied by the embedding size.
            size_t expectedElementCount = (m_inputType == InputType::TOKENS) ? n_tokens : n_tokens * m_embd_size;
            if (tt->dims.getNumElements() != expectedElementCount)
                errors.push_back({variant->graph_name, m_layerNames[LayerType::INPUT], "Wrong input shape"});
        }
    }

    // Check 1b - In case of BERT :-> token_type_ids
    if(m_modelArchitectureType == ModelArchitectureType::ENCODER) {
        for (auto &[n_tokens, variant]: m_nsp_graphs.front().variants) {
            if ((tt = variant->getInput(m_layerNames[LayerType::TOKEN_TYPE_IDS])) == nullptr)
                errors.push_back({variant->graph_name, m_layerNames[LayerType::TOKEN_TYPE_IDS], "Tensor not found"});
            else {
                checkShape(m_layerNames[LayerType::TOKEN_TYPE_IDS], tt, -1, -1, -1, 4, errors);
                if (tt->dims.getNumElements() != n_tokens)
                    errors.push_back({variant->graph_name, m_layerNames[LayerType::TOKEN_TYPE_IDS],
                                      "Wrong token_type_ids shape"});
            }
        }
    }

    // Check 2 - In case of LLama :-> logits exists
    //           In case of BERT :-> pooled_output & sequence_outputs exists
    for (auto& [n_tokens, variant] : m_nsp_graphs.back().variants) {
        if (m_modelArchitectureType == ModelArchitectureType::ENCODER) {
            if ((tt = variant->getOutput(m_layerNames[LayerType::POOL_OUTPUT])) == nullptr)
                errors.push_back({variant->graph_name, m_layerNames[LayerType::POOL_OUTPUT], "Tensor not found"});
            else {
                if (tt->dims.getNumElements() != m_embd_size)
                    errors.push_back(
                            {variant->graph_name, m_layerNames[LayerType::POOL_OUTPUT], "Wrong pooled_outputs shape"});

            }
            if (!m_pooled_output) {
                if ((tt = variant->getOutput(m_layerNames[LayerType::SEQ_OUTPUT])) == nullptr)
                    errors.push_back({variant->graph_name, m_layerNames[LayerType::SEQ_OUTPUT], "Tensor not found"});
                else {
                    if (tt->dims.getNumElements() != n_tokens * m_embd_size)
                        errors.push_back({variant->graph_name, m_layerNames[LayerType::SEQ_OUTPUT],
                                          "Wrong sequence_output shape"});

                }
            }
        } else {
            if ((tt = variant->getOutput(m_layerNames[LayerType::OUTPUT])) == nullptr)
                errors.push_back({variant->graph_name, m_layerNames[LayerType::OUTPUT], "Tensor not found"});
            else {
                if (m_vocab_size == -1) m_vocab_size = tt->dims.getMaxDim();
                if (tt->dims.getNumElements() != m_vocab_size &&
                    tt->dims.getNumElements() != n_tokens * m_vocab_size)
                    errors.push_back({variant->graph_name, m_layerNames[LayerType::OUTPUT], "Wrong logits shape"});
            }
        }
    }

    // Check 3 - Shapes for all names tensors are correct
    if (m_kv_dim == -1) { // Deduce KV$ embed_dim if not already available
        for (auto& variant : m_variant_list) {
            for (auto& [tname, tspec] : variant.output_specs)
                if (tname.starts_with("past_key")) m_kv_dim = tspec.dims.width;
            if (m_kv_dim != -1) break;
        }
    }

    for (auto& variant : m_variant_list) {
        auto& n_tokens = variant.n_tokens;
        if(m_modelArchitectureType == ModelArchitectureType::ENCODER){
            checkShape(m_layerNames[LayerType::ATTN_MASK], variant.getInput(m_layerNames[LayerType::ATTN_MASK]), 1, 1, m_ctx_size, -1, errors);
        }
        else{
            checkShape(m_layerNames[LayerType::ATTN_MASK], variant.getInput(m_layerNames[LayerType::ATTN_MASK]), 1, n_tokens, m_ctx_size, -1, errors);
        }
        if (m_positional_encoding.type == PositionalEncoding::ROPE) {
            checkShape(m_layerNames[LayerType::POS_SIN], variant.getInput(m_layerNames[LayerType::POS_SIN]), 1, n_tokens, m_pos_dim, -1, errors);
            checkShape(m_layerNames[LayerType::POS_COS], variant.getInput(m_layerNames[LayerType::POS_COS]), 1, n_tokens, m_pos_dim, -1, errors);
        } else if (m_positional_encoding.type == PositionalEncoding::ABSOLUTE) {
            checkShape(m_layerNames[LayerType::POS_IDS], variant.getInput(m_layerNames[LayerType::POS_IDS]), 1, 1, n_tokens, -1, errors);
        } else if (m_positional_encoding.type == PositionalEncoding::ALIBI) {
            checkShape(m_layerNames[LayerType::POS_IDS], variant.getInput(m_layerNames[LayerType::POS_IDS]), 1, n_tokens, m_ctx_size, -1, errors);
        }

        if(m_modelArchitectureType != ModelArchitectureType::ENCODER) {
            for (auto &[tname, tspec]: variant.input_specs) {
                if (tname.starts_with("past_key"))
                    checkShape(tname, &tspec, -1, m_kv_dim, m_ctx_size - n_tokens, 1, errors);
                else if (tname.starts_with("past_value"))
                    checkShape(tname, &tspec, -1, m_ctx_size - n_tokens, m_kv_dim, 1, errors);
            }

            for (auto &[tname, tspec]: variant.output_specs) {
                if (tname.starts_with("past_key"))
                    checkShape(tname, &tspec, -1, m_kv_dim, n_tokens, 1, errors);
                else if (tname.starts_with("past_value"))
                    checkShape(tname, &tspec, -1, n_tokens, m_kv_dim, 1, errors);
            }
        }
    }

    // skip check in case of BERT architecture since no KV cache tensors are existing
    if(m_modelArchitectureType != ModelArchitectureType::ENCODER) {
        // Check 4 - Quantization parameter match
        std::unordered_map<std::string, QnnUtils::QuantParam> quant_params;
        for (auto &variant: m_variant_list) {
            for (auto &tensor_specs: {variant.input_specs, variant.output_specs}) {
                for (auto &[tname, tspec]: tensor_specs) {
                    std::string name = (tname.starts_with("past_") && tname.ends_with("_in"))
                                       ? tname.substr(0, tname.rfind("_")).append("_out")
                                       : tname;
                    if (name.compare(m_layerNames[LayerType::OUTPUT]) == 0) continue;
                    if (quant_params.contains(name)) {
                        if (quant_params.at(name).scale != tspec.quantParam[0].scale ||
                            quant_params.at(name).offset != tspec.quantParam[0].offset)
                            errors.push_back(
                                    {variant.graph_name,
                                     tname,
                                     "Non-identical quantization parameters found for the same tensor"}
                            );
                    } else
                        quant_params[tname] = {tspec.quantParam[0].scale, tspec.quantParam[0].offset};
                }
            }
        }
    }

    if (errors.size() > 0) {
        QNN_ERROR("Model Validation Errors found");
        for (auto& [graph_name, tensor_name, err_msg] : errors) // Log the list of errors
            QNN_ERROR("%s : %s - %s", graph_name.c_str(), tensor_name.c_str(), err_msg.c_str());
        QNN_ERROR("Note: -1 means ignore (i.e. no comparison)");
        QNN_ERROR("Check model i/o specs (set dump-specs=true in config) for debugging");
        return false;
    }

    return true;
}

bool QnnNspModel::initializeKVManager() {

    if(m_disableKvCache){
        return true;
    }

    // Pick the largest variant
    int32_t variant = nsp_graph_count.rbegin()->first;

    int idx = 0;
    for (auto& graph : m_nsp_graphs) {
        auto& specs = (variant == m_ctx_size) ? graph[variant]->output_specs
                                              : graph[variant]->input_specs;

        ThreadPool* _pool = _threaded ? &threadpool : nullptr;
        // clang-format off
        NewNSPKVManager *manager = new NewNSPKVManager( idx++, _env, _pool, m_ioTensor.get(),
                specs, m_ctx_size, m_kv_dim, _kv_update_method);
        // clang-format on
        graph.registerKVManager(manager);

        if (_kv_update_method == POINTER_SHIFT)
            graph.kvmanager->setTensorAllocInfo(&tensor_alloc_info);
    }

    _kv_dispatcher =
            std::unique_ptr<KVDispatcher>(new KVDispatcher(_env, m_nsp_graphs, _threaded, _cpumask)
            );
    _kv_update_count = _kv_dispatcher->dispatch(variant, 0);

    return true;
}

// Binds the shared tensor pointer `t` to the input called `key` in `variant`, or checks that an
// already bound pointer is backed by the same buffer.
//
//   variant : graph variant whose inputs are searched
//   key     : name of the input tensor
//   t       : in/out; assigned the variant's tensor when it is still null
//
// Returns true when the variant has no input of that name, when `t` was bound by this call, or
// when both tensors share one buffer. Returns false, after logging, when the buffers differ.
inline bool QnnNspModel::updateTensorPointer(
        GraphVariant&      variant,
        std::string&       key,
        QnnUtils::Tensor*& t
) {
    QnnUtils::Tensor* const candidate = variant.getInput(key);
    if (candidate == nullptr) return true; // Variant does not have this input

    if (t == nullptr) {
        // First variant that provides the tensor becomes the reference
        t = candidate;
        return true;
    }

    const void* const bound_buffer     = getBuffer(t);
    const void* const candidate_buffer = getBuffer(candidate);
    if (bound_buffer == candidate_buffer) return true;

    // Report the buffer addresses that were compared; the tensor descriptors themselves are
    // distinct objects per variant, so their addresses say nothing about the mismatch.
    __ERROR("{} has different addresses: {} vs {}", key, bound_buffer, candidate_buffer);
    return false;
}

bool QnnNspModel::initializeTensorPointers() {
    // Ideally this needs to be done for all sets of AR-n available, e.g. for AR-1 and AR-1024

    bool status = true;
    for (auto& variant : m_variant_list) {
        status &= updateTensorPointer(variant, m_layerNames[LayerType::INPUT], t_input_ids);
        status &= updateTensorPointer(variant, m_layerNames[LayerType::ATTN_MASK], t_attn_mask);
        status &= updateTensorPointer(variant, m_layerNames[LayerType::POS_SIN], t_position_ids_sin);
        status &= updateTensorPointer(variant, m_layerNames[LayerType::POS_COS], t_position_ids_cos);
        status &= updateTensorPointer(variant, m_layerNames[LayerType::POS_IDS], t_position_ids);
        status &= updateTensorPointer(variant, m_layerNames[LayerType::TOKEN_TYPE_IDS], t_token_type_ids);
    }
    if (!status) __ERROR("qnn-htp: Error in setting up named tensor pointers.");

    status &= !(!t_input_ids || !t_attn_mask);
    if (!t_input_ids) __ERROR("Tensor not found: {}", m_layerNames[LayerType::INPUT]);
    if (!t_attn_mask) __ERROR("Tensor not found: {}", m_layerNames[LayerType::ATTN_MASK]);

    if(m_modelArchitectureType == ModelArchitectureType::ENCODER){ // This input only valid for Encoder only model like bert.
        status &= !(!t_token_type_ids);
        if (!t_token_type_ids) __ERROR("Tensor not found: {}", m_layerNames[LayerType::TOKEN_TYPE_IDS]);
    }

    if (m_positional_encoding.type == PositionalEncoding::ROPE) {
        status &= !(!t_position_ids_sin || !t_position_ids_cos);
        if (!t_position_ids_sin) __ERROR("Tensor not found: {}", m_layerNames[LayerType::POS_SIN]);
        if (!t_position_ids_cos) __ERROR("Tensor not found: {}", m_layerNames[LayerType::POS_COS]);
    } else if (m_positional_encoding.type == PositionalEncoding::ABSOLUTE) {
        status &= !(!t_position_ids);
        if (!t_position_ids) __ERROR("Tensor not found: {}", m_layerNames[LayerType::POS_IDS]);
    } else if (m_positional_encoding.type == PositionalEncoding::ALIBI) {
        status &= !(!t_position_ids);
        if (!t_position_ids) __ERROR("Tensor not found: {}", m_layerNames[LayerType::POS_IDS]);
    } else {
        __ERROR("Unknown Rope Type found for tensor: {}", m_layerNames[LayerType::POS_IDS]);
    }

    // Detect activation bitwidth
    if (status) {
        //Check Input-> Input_ID or Input_Embed
        d_input      = t_input_ids->dtype;
        if (!supported_activations.contains(d_input)) {
            __ERROR("Input Tensor: {} as unsupported activation type {}", m_layerNames[LayerType::INPUT], d_input.str());
            status = false;
        }
        // Check Attention Mask
        d_attn_map   = t_attn_mask->dtype;
        if (!supported_activations.contains(d_attn_map)) {
            __ERROR("attention_mask has unsupported type {}", d_attn_map.str());
            status = false;
        }
        // For Encoder only model, Check for Token_type_ids
        if(m_modelArchitectureType == ModelArchitectureType::ENCODER) {
            d_token_type = t_token_type_ids->dtype;
            if (!supported_activations.contains(d_token_type)) {
                __ERROR("token_type_ids has unsupported type {}", d_token_type.str());
                status = false;
            }
        }

        //For Position_IDs check data bitWidth
        if (m_positional_encoding.type == PositionalEncoding::ROPE)
            d_pos = t_position_ids_sin->dtype;
        else if (m_positional_encoding.type == PositionalEncoding::ABSOLUTE)
            d_pos = t_position_ids->dtype;
        else if (m_positional_encoding.type == PositionalEncoding::ALIBI)
            d_pos = t_position_ids->dtype;

        if (((m_positional_encoding.type == PositionalEncoding::ABSOLUTE ||
                m_positional_encoding.type == PositionalEncoding::ALIBI) &&
                d_pos != QNN_DATATYPE_INT_32) ||
                (m_positional_encoding.type == PositionalEncoding::ROPE &&
                !supported_activations.contains(d_pos))) {
                __ERROR("position encoding tensor has unsupported type {}", d_pos.str());
                status = false;
        }
        __DEBUG("qnn-htp datatypes: d_input {} d_attn_map {} d_pos {} d_kv {}",
                d_input.str(),
                d_attn_map.str(),
                d_pos.str(),
                d_kv.str());

        if (!status) __ERROR("Only 8-bit, 16-bit and 32-bit activations are supported");
    }

    return status;
}
bool QnnNspModel::setupAttentionMaskFP16(bool                     pad_left,
                                         int                      n_tokens,
                                         int                      n_inputs,
                                         int                      n_past,
                                         std::span<const int32_t> attention_map,
                                         size_t                   n_skip_prefix,
                                         size_t                   n_apply_prefix_offset) {
  QnnUtils::Dims t_attn_mask_dims = t_attn_mask->dims;
  size_t numElements = t_attn_mask_dims.getNumElements();
  size_t bufSize = numElements * 2; // (bitwidth = 16, in bytes: 16/8)
  std::vector<unsigned char> attn_mask_vec(bufSize);
  if (!float32ToFloat16((unsigned char *)attn_mask_vec.data(), (float *) getBuffer(t_attn_mask), numElements)) {
    QNN_ERROR("Number of elements is 0");
    return false;
  }
  // Setup attention mask
  {
    uint16_t*    attn_buffer = (uint16_t*)attn_mask_vec.data();
    const int n_valid     = n_past + n_inputs;

    uint16_t pos_val = -1, neg_val = 0;
    pos_val = 0;
    neg_val = -1000;

    // Clear the attention mask
    std::fill_n(attn_buffer, n_tokens * m_ctx_size, neg_val);
    if (attention_map.empty()) {
      uint16_t* cur_ptr = &attn_buffer
      [(pad_left) ? (m_ctx_size - n_valid) * (m_ctx_size + 1)
                  : m_ctx_size - n_past - n_tokens];
      for (int n_masked = n_past + 1; n_masked <= n_valid; n_masked++) {
        std::fill_n(cur_ptr, n_masked, pos_val);
        cur_ptr += m_ctx_size;
      }
    } else if (attention_map.size() == n_inputs) {
      // Only fill in n_inputs. Rest will be padding
      const size_t attn_row_start = m_ctx_size - n_past - n_tokens;
      for (int i = 0; i < n_inputs; i++) {
        uint16_t* cur_ptr = &attn_buffer[i * m_ctx_size + attn_row_start];

        cur_ptr[n_past + i] = pos_val; // Attend to itself
        if (attention_map[i] < 0) {    // If negative, attend to only past tokens
          int32_t n_masked = n_past + attention_map[i] + 1;
          if (i < n_apply_prefix_offset) { // Skip prefix is needed
            cur_ptr += n_skip_prefix;
            n_masked -= n_skip_prefix;
          }
          std::fill_n(cur_ptr, n_masked, pos_val);

        } else { // If positive, copy attention map from (relative to 0th input) parent
          const int32_t pidx       = attention_map[i]; // Parent token index
          uint16_t*        parent_ptr = &attn_buffer[pidx * m_ctx_size + attn_row_start];
          std::memcpy(cur_ptr, parent_ptr, (n_past + pidx + 1) * sizeof(uint16_t));

          // If parent skipped prefix, but this token needs to attend to prefix, add attn
          if (i >= n_apply_prefix_offset && pidx < n_apply_prefix_offset)
            std::fill_n(cur_ptr, n_skip_prefix, pos_val);
        }
      }
    } else if (attention_map.size() == n_valid * n_inputs) {
      uint16_t* cur_ptr = &attn_buffer[m_ctx_size - n_past - n_tokens];
      for (int i = 0; i < n_inputs; i++) {
        for (int j = 0; j < n_valid; j++)
          cur_ptr[j] = (attention_map[i * n_valid + j] == 0) ? neg_val : pos_val;
        cur_ptr += m_ctx_size;
      }
    }
  }

  return true;

}
// Fills the attention-mask input tensor, viewed as an array of DType.
//
// Encoder models get a single row of m_ctx_size entries: 1 for the n_past + n_inputs valid
// positions (right-aligned when pad_left is set) and 0 elsewhere.
//
// All other models get n_tokens rows of m_ctx_size entries, cleared to 0 and then marked with
// DType(-1) according to attention_map:
//   empty                                   -> causal mask
//   size == n_inputs                        -> one entry per token: a negative value attends to
//                                              (n_past + value + 1) leading slots, a non-negative
//                                              value copies the row of that parent token
//   size == (n_past + n_inputs) * n_inputs  -> dense row-major matrix, 0 == masked
// n_skip_prefix / n_apply_prefix_offset control which tokens skip the leading prefix slots.
//
// Always returns true.
template <typename DType>
bool QnnNspModel::setupAttentionMask(
        bool                     pad_left,
        int                      n_tokens,
        int                      n_inputs,
        int                      n_past,
        std::span<const int32_t> attention_map,
        size_t                   n_skip_prefix,
        size_t                   n_apply_prefix_offset
) {
    DType* const mask        = (DType*)getBuffer(t_attn_mask);
    const int    total_valid = n_past + n_inputs;

    DType attend = -1;
    DType ignore = 0;

    if (m_modelArchitectureType == ModelArchitectureType::ENCODER) {
        attend = 1; // BGE model is using 1 to set attention mask and 0 to unset.
        std::memset(mask, ignore, 1 * m_ctx_size * sizeof(DType));
        size_t first_valid = pad_left ? m_ctx_size - total_valid : 0;
        std::fill_n(&mask[first_valid], total_valid, attend);
        return true;
    }

    // Start from a fully masked buffer
    std::fill_n(mask, n_tokens * m_ctx_size, ignore);

    if (attention_map.empty()) {
        // Causal layout: each following row attends to one more slot than the previous one
        DType* row = pad_left ? &mask[(m_ctx_size - total_valid) * (m_ctx_size + 1)]
                              : &mask[m_ctx_size - n_past - n_tokens];
        for (int row_idx = 0; row_idx < n_inputs; ++row_idx, row += m_ctx_size) {
            std::fill_n(row, n_past + 1 + row_idx, attend);
        }
        return true;
    }

    if (attention_map.size() == n_inputs) {
        // One link per input token; rows beyond n_inputs stay masked (padding)
        const size_t row_offset = m_ctx_size - n_past - n_tokens;
        for (int i = 0; i < n_inputs; i++) {
            DType*        row  = &mask[i * m_ctx_size + row_offset];
            const int32_t link = attention_map[i];

            row[n_past + i] = attend; // Every token attends to itself

            if (link >= 0) {
                // Inherit everything the parent token (index relative to input 0) attends to
                DType* parent_row = &mask[link * m_ctx_size + row_offset];
                std::memcpy(row, parent_row, (n_past + link + 1) * sizeof(DType));

                // Parent skipped the prefix but this token must see it: re-enable the prefix
                if (i >= n_apply_prefix_offset && link < n_apply_prefix_offset) {
                    std::fill_n(row, n_skip_prefix, attend);
                }
                continue;
            }

            // Negative link: attend to past tokens only
            int32_t span_len = n_past + link + 1;
            if (i < n_apply_prefix_offset) { // This token skips the prefix
                row += n_skip_prefix;
                span_len -= n_skip_prefix;
            }
            std::fill_n(row, span_len, attend);
        }
        return true;
    }

    if (attention_map.size() == total_valid * n_inputs) {
        // Dense layout: translate each flag of the row-major matrix
        DType* row = &mask[m_ctx_size - n_past - n_tokens];
        for (int i = 0; i < n_inputs; i++, row += m_ctx_size) {
            for (int j = 0; j < total_valid; j++) {
                row[j] = (attention_map[i * total_valid + j] != 0) ? attend : ignore;
            }
        }
    }

    return true;
}
    bool QnnNspModel::setupRopePositionEmbeddingFP16(
            bool                     pad_left,
            int                      n_tokens,
            int                      n_inputs,
            int                      n_past,
            std::span<const int32_t> attention_map,
            size_t                   n_skip_prefix,
            size_t                   n_apply_prefix_offset
    ) {
      const int n_valid = n_past + n_inputs;

      // Cast RoPE embeddings to proper dtype
      // The following two buffers are already converted to fp16
      uint16_t* typed_rope_sin = (uint16_t*)rope_sin;
      uint16_t* typed_rope_cos = (uint16_t*)rope_cos;

      // These two need conversion

      QnnUtils::Dims t_position_ids_cos_dims = t_position_ids_cos->dims;
      size_t numElements = t_position_ids_cos_dims.getNumElements();
      size_t bufSize = numElements * 2; // (bitwidth = 16, in bytes: 16/8)
      std::vector<unsigned char> position_ids_cos_vec(bufSize);
      if (!float32ToFloat16((unsigned char *)position_ids_cos_vec.data(), (float *) getBuffer(t_position_ids_cos), numElements)) {
        QNN_ERROR("Number of elements is 0");
        return false;
      }
      uint16_t* cos_buffer = (uint16_t*)position_ids_cos_vec.data();

      QnnUtils::Dims t_position_ids_sin_dims = t_position_ids_sin->dims;
      numElements = t_position_ids_sin_dims.getNumElements();
      bufSize = numElements * 2; // (bitwidth = 16, in bytes: 16/8)
      std::vector<unsigned char> position_ids_sin_vec(bufSize);
      if (!float32ToFloat16((unsigned char *)position_ids_sin_vec.data(), (float *) getBuffer(t_position_ids_sin), numElements)) {
        QNN_ERROR("Number of elements is 0");
        return false;
      }
      uint16_t* sin_buffer = (uint16_t*)position_ids_sin_vec.data();

      // Clear out all position_ids as position_sin/cos[0]
      const size_t pos_row_size = m_pos_dim * sizeof(uint16_t);
      for (int i = 0; i < n_tokens; i++) {
        std::memcpy(&sin_buffer[i * m_pos_dim], typed_rope_sin, pos_row_size);
        std::memcpy(&cos_buffer[i * m_pos_dim], typed_rope_cos, pos_row_size);
      }

      // Copy in position embeddings [0:(n_valid-1)] to input sin/cos buffer
      const size_t pos_buf_offset = m_pos_dim * ((pad_left) ? m_ctx_size - n_valid : 0);
      if (attention_map.size() == n_inputs) {
        // Copy embeddings one by one based on the attention map
        std::vector<int32_t> pos_ids(n_inputs, 0);
        auto                 sin = &sin_buffer[pos_buf_offset];
        auto                 cos = &cos_buffer[pos_buf_offset];

        // 1st token
        pos_ids[0] = m_nPast - n_skip_prefix;
        std::memcpy(sin, &typed_rope_sin[pos_ids[0] * m_pos_dim], pos_row_size);
        std::memcpy(cos, &typed_rope_cos[pos_ids[0] * m_pos_dim], pos_row_size);
        sin += m_pos_dim;
        cos += m_pos_dim;

        // Rest
        for (int i = 1; i < n_inputs; i++) {
          auto parent_index = attention_map[i];
          pos_ids[i]        = pos_ids[parent_index] + 1;
          std::memcpy(sin, &typed_rope_sin[pos_ids[i] * m_pos_dim], pos_row_size);
          std::memcpy(cos, &typed_rope_cos[pos_ids[i] * m_pos_dim], pos_row_size);
          sin += m_pos_dim;
          cos += m_pos_dim;
        }
      } else if (attention_map.size() == (n_past + n_inputs) * n_inputs) {
        // For now, simply have the same position ID across the variant
        auto sin = &sin_buffer[0];
        auto cos = &cos_buffer[0];

        // Calculate position based on number of items this index is attending to
        for (int i = 0; i < n_inputs; i++) {
          auto    attn_row = attention_map.subspan(i * n_valid, n_valid);
          int32_t pos_id =
                  std::accumulate(attn_row.begin() + n_skip_prefix, attn_row.end(), 0) - attn_row[n_past + i];

          // __DEBUG("PositionID [ i={}, n_past={}, pos_id={} ]", i, n_past, pos_id);

          std::memcpy(sin, &typed_rope_sin[pos_id * m_pos_dim], pos_row_size);
          std::memcpy(cos, &typed_rope_cos[pos_id * m_pos_dim], pos_row_size);
          sin += m_pos_dim;
          cos += m_pos_dim;
        }
      } else {
        const size_t pos_dat_offset = m_pos_dim * (n_past - n_skip_prefix);
        const size_t pos_cpy_amt    = pos_row_size * ((pad_left) ? n_valid : n_tokens);
        std::memcpy(&sin_buffer[pos_buf_offset], &typed_rope_sin[pos_dat_offset], pos_cpy_amt);
        std::memcpy(&cos_buffer[pos_buf_offset], &typed_rope_cos[pos_dat_offset], pos_cpy_amt);
      }

      return true;
    }
// Fills the RoPE sin/cos input tensors, viewed as arrays of DType, from the precomputed
// rope_sin / rope_cos tables (m_pos_dim entries per position).
//
// All n_tokens rows are first reset to table row 0. The rows of the real inputs are then chosen
// according to attention_map:
//   size == n_inputs                        -> token 0 uses position m_nPast - n_skip_prefix and
//                                              every other token uses its parent's position + 1
//   size == (n_past + n_inputs) * n_inputs  -> position = number of attended entries after the
//                                              skipped prefix, excluding the token itself
//   anything else                           -> consecutive table rows starting at
//                                              n_past - n_skip_prefix, copied in one block
//
// Always returns true.
template <typename DType>
bool QnnNspModel::setupRopePositionEmbedding(
        bool                     pad_left,
        int                      n_tokens,
        int                      n_inputs,
        int                      n_past,
        std::span<const int32_t> attention_map,
        size_t                   n_skip_prefix,
        size_t                   n_apply_prefix_offset
) {
    const int total_valid = n_past + n_inputs;

    // Lookup tables reinterpreted with the tensor's element type
    DType* const sin_table = (DType*)rope_sin;
    DType* const cos_table = (DType*)rope_cos;

    DType* const cos_out = (DType*)getBuffer(t_position_ids_cos);
    DType* const sin_out = (DType*)getBuffer(t_position_ids_sin);

    const size_t row_bytes = m_pos_dim * sizeof(DType);

    // Copies the sin/cos table rows of `position` into one output row each
    const auto copy_position = [&](DType* sin_dst, DType* cos_dst, int32_t position) {
        std::memcpy(sin_dst, &sin_table[position * m_pos_dim], row_bytes);
        std::memcpy(cos_dst, &cos_table[position * m_pos_dim], row_bytes);
    };

    // Reset every row to position 0
    for (int row = 0; row < n_tokens; ++row) {
        copy_position(&sin_out[row * m_pos_dim], &cos_out[row * m_pos_dim], 0);
    }

    // With left padding the valid rows sit at the end of the context window
    const size_t out_offset = m_pos_dim * ((pad_left) ? m_ctx_size - total_valid : 0);

    if (attention_map.size() == n_inputs) {
        // Parent-link layout: derive each position from the parent's position
        std::vector<int32_t> positions(n_inputs, 0);
        DType*               sin_dst = &sin_out[out_offset];
        DType*               cos_dst = &cos_out[out_offset];

        // NOTE(review): the first position reads the member m_nPast, not the n_past argument.
        positions[0] = m_nPast - n_skip_prefix;
        copy_position(sin_dst, cos_dst, positions[0]);

        for (int i = 1; i < n_inputs; i++) {
            sin_dst += m_pos_dim;
            cos_dst += m_pos_dim;
            positions[i] = positions[attention_map[i]] + 1;
            copy_position(sin_dst, cos_dst, positions[i]);
        }
        return true;
    }

    if (attention_map.size() == (n_past + n_inputs) * n_inputs) {
        // Dense layout. For now, simply have the same position ID across the variant
        DType* sin_dst = &sin_out[0];
        DType* cos_dst = &cos_out[0];

        for (int i = 0; i < n_inputs; i++) {
            // Position = how many entries this token attends to, not counting itself
            const auto    attn_row = attention_map.subspan(i * total_valid, total_valid);
            const int32_t position =
                    std::accumulate(attn_row.begin() + n_skip_prefix, attn_row.end(), 0) -
                    attn_row[n_past + i];

            copy_position(sin_dst, cos_dst, position);
            sin_dst += m_pos_dim;
            cos_dst += m_pos_dim;
        }
        return true;
    }

    // No usable map: positions are consecutive, so copy them as a single block
    const size_t table_offset = m_pos_dim * (n_past - n_skip_prefix);
    const size_t block_bytes  = row_bytes * ((pad_left) ? total_valid : n_tokens);
    std::memcpy(&sin_out[out_offset], &sin_table[table_offset], block_bytes);
    std::memcpy(&cos_out[out_offset], &cos_table[table_offset], block_bytes);

    return true;
}

// Fills the position-id tensor with ALiBi distances, viewed as n_tokens rows of m_ctx_size DType.
//
// Every entry is first set to m_ctx_size. For input token i the n_past slots in front of the
// new-token area receive the distances i + n_past ... i + 1 (left to right) and the slots of the
// new tokens 0..i receive i ... 0, so the distance grows by one per step to the left.
//
//   pad_left : inputs are right-aligned, which shifts both the rows and the new-token columns
//              by n_tokens - n_inputs
//
// Always returns true.
template <typename DType>
bool QnnNspModel::setupAlibiPositionEmbedding(
        bool pad_left,
        int  n_tokens,
        int  n_inputs,
        int  n_past
) {
    using ReverseIt = std::reverse_iterator<DType*>;

    DType* const buffer = (DType*)getBuffer(t_position_ids);
    const DType  filler = m_ctx_size;

    // Reset the whole buffer to the padding distance
    std::fill_n(buffer, n_tokens * m_ctx_size, filler);

    // Columns [0, m_ctx_size-n_tokens) hold past tokens, [m_ctx_size-n_tokens, m_ctx_size) new ones
    DType* past_row = buffer;
    DType* new_row  = buffer + m_ctx_size - n_tokens;

    // For non SMART_MASK, the past tokens/KV$ are left-padded: skip the unused leading slots
    past_row += m_ctx_size - n_tokens - n_past;

    if (pad_left) {
        // Right-aligned inputs: skip the unused rows, and the unused columns of the new area
        past_row += (n_tokens - n_inputs) * m_ctx_size;
        new_row += n_tokens - n_inputs;
        new_row += (n_tokens - n_inputs) * m_ctx_size;
    }

    // One row per input token; both cursors advance by a full row each iteration
    for (int i = 0; i < n_inputs; ++i, past_row += m_ctx_size, new_row += m_ctx_size) {
        // Past tokens: i+1 at the rightmost past slot, increasing to the left
        std::iota(ReverseIt(past_row + n_past), ReverseIt(past_row), i + 1);
        // New tokens: 0 at the token itself, increasing to the left
        std::iota(ReverseIt(new_row + i + 1), ReverseIt(new_row), 0);
    }

    return true;
}

// Prepares all graph input tensors for one inference driven by token IDs.
//
// The variant (run_info.n_tokens) and the number of real tokens (run_info.n_processed) are taken
// from run_info. The token-id buffer is filled with m_pad_token and the real tokens are copied in
// (right-aligned when the variant spans the whole context). The attention mask, the encoder-only
// token-type IDs and the positional-encoding tensors are then set up for their datatypes.
//
//   tokens                : token IDs; the first run_info.n_processed entries are used
//   n_past                : number of tokens already processed
//   attention_map         : optional attention layout, forwarded to the mask/RoPE helpers
//   n_skip_prefix         : forwarded to the mask/RoPE helpers
//   n_apply_prefix_offset : forwarded to the mask/RoPE helpers
//
// Returns false when more tokens than the variant holds are requested, when a tensor datatype is
// unsupported, or when one of the setup helpers reports a failure; true otherwise.
bool QnnNspModel::setupInputTensors(
        std::span<int32_t>       tokens,
        int32_t                  n_past,
        std::span<const int32_t> attention_map,
        size_t                   n_skip_prefix,
        size_t                   n_apply_prefix_offset
) {
    qualla::Timer start;

    const int n_tokens = run_info.n_tokens;
    const int n_inputs = run_info.n_processed;
    __TRACE("qnn-htp: setup-input-tensors with {} tokens for AR-{}", n_inputs, n_tokens);

    const bool pad_left = (n_tokens == m_ctx_size);
    if (n_inputs > n_tokens) {
        __ERROR("qnn-htp: setup-input-tensors too many tokens: {} on AR-{}", n_inputs, n_tokens);
        return false;
    }

    // Setup input id tensor
    {
        uint32_t* input_id_buffer = (uint32_t*)getBuffer(t_input_ids);
        std::fill_n(input_id_buffer, n_tokens, static_cast<uint32_t>(m_pad_token));

        size_t in_buf_offset = pad_left ? n_tokens - n_inputs : 0;
        std::memcpy(&input_id_buffer[in_buf_offset], tokens.data(), n_inputs * sizeof(uint32_t));
    }

    // Setup attention mask; helper failures are propagated instead of being dropped
    bool mask_ok = false;
    // clang-format off
    switch (d_attn_map) {
    case QNN_DATATYPE_UFIXED_POINT_8:
        mask_ok = setupAttentionMask<uint8_t>(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break;
    case QNN_DATATYPE_UFIXED_POINT_16:
        mask_ok = setupAttentionMask<uint16_t>(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break;
    case QNN_DATATYPE_INT_32:
        mask_ok = setupAttentionMask<int32_t>(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break;
    case QNN_DATATYPE_FLOAT_16:
        mask_ok = setupAttentionMaskFP16(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break;
    default: __ERROR("Unsupported attention mask dtype {}", d_attn_map.str()); return false;
    }
    // clang-format on
    if (!mask_ok) {
        __ERROR("qnn-htp: setup-input-tensors failed to set up the attention mask");
        return false;
    }

    // Setup token type IDs
    if (m_modelArchitectureType == ModelArchitectureType::ENCODER) {
        // BERT Specific
        uint32_t* token_type_id_buffer = (uint32_t*)getBuffer(t_token_type_ids);
        std::memset(token_type_id_buffer, 0, n_tokens * sizeof(uint32_t));
    }

    // Setup position IDs
    bool position_ok = true;
    if (m_positional_encoding.type == PositionalEncoding::ROPE) {
        // clang-format off
        switch (d_pos) {
        case QNN_DATATYPE_UFIXED_POINT_8:
            position_ok = setupRopePositionEmbedding<uint8_t>(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break;
        case QNN_DATATYPE_UFIXED_POINT_16:
            position_ok = setupRopePositionEmbedding<uint16_t>(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break;
        case QNN_DATATYPE_FLOAT_16:
            position_ok = setupRopePositionEmbeddingFP16(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break;
        default: __ERROR("Unsupported rope position dtype {}", d_pos.str()); return false;
        }
        // clang-format on
    } else if (m_positional_encoding.type == PositionalEncoding::ABSOLUTE) {
        uint32_t* position_id_buffer = (uint32_t*)getBuffer(t_position_ids);
        std::memset(position_id_buffer, 0, n_tokens * sizeof(uint32_t));

        // Fill up position_ids buffer with n_past, n_past + 1, ...
        uint32_t* pos_id_start = &position_id_buffer[pad_left ? n_tokens - n_inputs : 0];
        uint32_t* pos_id_end   = pos_id_start + n_inputs;
        std::iota(pos_id_start, pos_id_end, n_past);
    } else if (m_positional_encoding.type == PositionalEncoding::ALIBI) {
        position_ok = setupAlibiPositionEmbedding<int32_t>(pad_left, n_tokens, n_inputs, n_past);
    }
    if (!position_ok) {
        __ERROR("qnn-htp: setup-input-tensors failed to set up the position encoding");
        return false;
    }

    __TRACE("qnn-htp: setup-input-tensors complete : {} usec", start.elapsed_usec());
    return true;
}


// Prepares all graph input tensors for one inference driven by input embeddings.
//
// The variant (run_info.n_tokens) and the number of real inputs (run_info.n_processed) are taken
// from run_info. Padding rows of the input tensor are filled with m_eosEmbedding and the supplied
// embeddings are placed behind them (right-aligned when the variant spans the whole context).
// "float32" embeddings are quantized through quantizeInput(); "native" embeddings are copied
// byte-wise. The attention mask, the encoder-only token-type IDs and the positional-encoding
// tensors are then set up exactly like in the token overload.
//
//   embedding             : raw embedding bytes for the real inputs
//   n_past                : number of tokens already processed
//   attention_map         : optional attention layout, forwarded to the mask/RoPE helpers
//   n_skip_prefix         : forwarded to the mask/RoPE helpers
//   n_apply_prefix_offset : forwarded to the mask/RoPE helpers
//
// Returns false when more inputs than the variant holds are requested, when a tensor datatype is
// unsupported, or when one of the setup helpers reports a failure; true otherwise.
bool QnnNspModel::setupInputTensors(
        std::span<uint8_t>       embedding,
        int32_t                  n_past,
        std::span<const int32_t> attention_map,
        size_t                   n_skip_prefix,
        size_t                   n_apply_prefix_offset
) {
    qualla::Timer start;

    const int n_tokens = run_info.n_tokens;
    const int n_inputs = run_info.n_processed;
    __TRACE("qnn-htp: setup-input-tensors with {} tokens for AR-{}", n_inputs, n_tokens);

    const bool pad_left = (n_tokens == m_ctx_size);
    if (n_inputs > n_tokens) {
        __ERROR("qnn-htp: setup-input-tensors too many tokens: {} on AR-{}", n_inputs, n_tokens);
        return false;
    }

    // Setup input embeds tensor
    {
        // Quantize and fill, don't make double copy
        size_t in_buf_offset = pad_left ? n_tokens - n_inputs : 0;
        // Padding rows: [0, in_buf_offset) when left-padded, [n_inputs, n_tokens) otherwise
        size_t startIdx = pad_left ? 0 : n_inputs;
        size_t endIdx   = pad_left ? in_buf_offset : n_tokens;

        // NOTE(review): any other embedding_datatype leaves the input tensor untouched.
        if (embedding_datatype == "float32") {
            // First flush the buffer with eos token embedding
            for (size_t i = startIdx; i < endIdx; i++) {
                quantizeInput((float*)m_eosEmbedding.data(), i*m_embd_size, m_embd_size);
            }

            // Quantize the data input vector
            quantizeInput((float*)embedding.data(), in_buf_offset*m_embd_size, n_inputs * m_embd_size);
        } else if (embedding_datatype == "native") {
            // Size of the buffer for one embedding vector.
            const size_t embedBufSize = m_embeddingBufferSize;
            // First flush the buffer with eos token embedding
            uint8_t* embeddingSrc = static_cast<uint8_t*>(m_eosEmbedding.data());
            for (size_t i = startIdx; i < endIdx; i++) {
                std::copy(embeddingSrc, embeddingSrc + embedBufSize, (uint8_t*)getBuffer(t_input_ids) + i*embedBufSize);
            }

            // Copy the data input vector
            embeddingSrc = static_cast<uint8_t*>(embedding.data());
            std::copy(embeddingSrc, embeddingSrc + embedding.size(), (uint8_t*)getBuffer(t_input_ids) + in_buf_offset*embedBufSize);
        }
    }

    // Don't modify attention mask it should work out of the box.
    // Helper failures are propagated instead of being dropped.
    bool mask_ok = false;
    // clang-format off
    switch (d_attn_map) {
        case QNN_DATATYPE_UFIXED_POINT_8:
            mask_ok = setupAttentionMask<uint8_t>(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break;
        case QNN_DATATYPE_UFIXED_POINT_16:
            mask_ok = setupAttentionMask<uint16_t>(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break;
        case QNN_DATATYPE_INT_32:
            mask_ok = setupAttentionMask<int32_t>(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break;
        case QNN_DATATYPE_FLOAT_16:
            mask_ok = setupAttentionMaskFP16(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break;
        default: __ERROR("Unsupported attention mask dtype {}", d_attn_map.str()); return false;
    }
    // clang-format on
    if (!mask_ok) {
        __ERROR("qnn-htp: setup-input-tensors failed to set up the attention mask");
        return false;
    }

    // Setup token type IDs
    if (m_modelArchitectureType == ModelArchitectureType::ENCODER) {
        // BERT Specific
        uint32_t* token_type_id_buffer = (uint32_t*)getBuffer(t_token_type_ids);
        std::memset(token_type_id_buffer, 0, n_tokens * sizeof(uint32_t));
    }

    // Setup position IDs
    bool position_ok = true;
    if (m_positional_encoding.type == PositionalEncoding::ROPE) {
        // clang-format off
        switch (d_pos) {
            case QNN_DATATYPE_UFIXED_POINT_8:
                position_ok = setupRopePositionEmbedding<uint8_t>(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break;
            case QNN_DATATYPE_UFIXED_POINT_16:
                position_ok = setupRopePositionEmbedding<uint16_t>(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break;
            case QNN_DATATYPE_FLOAT_16:
                position_ok = setupRopePositionEmbeddingFP16(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break;
            default: __ERROR("Unsupported rope position dtype {}", d_pos.str()); return false;
        }
        // clang-format on
    } else if (m_positional_encoding.type == PositionalEncoding::ABSOLUTE) {
        uint32_t* position_id_buffer = (uint32_t*)getBuffer(t_position_ids);
        std::memset(position_id_buffer, 0, n_tokens * sizeof(uint32_t));

        // Fill up position_ids buffer with n_past, n_past + 1, ...
        uint32_t* pos_id_start = &position_id_buffer[pad_left ? n_tokens - n_inputs : 0];
        uint32_t* pos_id_end   = pos_id_start + n_inputs;
        std::iota(pos_id_start, pos_id_end, n_past);
    } else if (m_positional_encoding.type == PositionalEncoding::ALIBI) {
        // Same handling as the token overload: the ALiBi tensor does not depend on the input kind
        position_ok = setupAlibiPositionEmbedding<int32_t>(pad_left, n_tokens, n_inputs, n_past);
    }
    if (!position_ok) {
        __ERROR("qnn-htp: setup-input-tensors failed to set up the position encoding");
        return false;
    }

    __TRACE("qnn-htp: setup-input-tensors complete : {} usec", start.elapsed_usec());
    return true;
}

// Executes the currently selected graph variant on every NSP graph split and accumulates
// the per-split wait/exec statistics.
//
// The variant, the number of tokens being processed and the token chunk are read from
// run_info, which runInference() fills in before calling this helper.
//
// pipeline   : when true, a KV$ update is dispatched right after each split executes
//              (unless the KV cache is disabled); afterwards the processed tokens are
//              appended to token_history (TOKENS input only) and m_nPast is advanced.
// total_wait : accumulator; receives each split's wait time plus the time spent
//              dispatching KV$ updates.
// total_exec : accumulator; receives each split's execution time.
//
// Returns false if any split fails to execute, or if pipelining with an UNKNOWN input type.
bool QnnNspModel::runInferenceHelper(bool pipeline, int32_t* total_wait, int32_t* total_exec) {
    // run_info is set in runInference. Only the scalars are copied; the token vector is
    // bound by reference so it is not duplicated on every call. Depending on the input
    // type, the token vector may be empty (embedding input).
    auto        variant     = run_info.n_tokens;
    auto        n_processed = run_info.n_processed;
    const auto& tokens      = run_info.tokens;

    int32_t idx                  = 0;
    int32_t wait_kv_update_count = _kv_update_count;

    for (auto& nsp_graph : m_nsp_graphs) {
        //__DEBUG("execute({}, {}, {})", variant, m_inference_count, wait_kv_update_count);
        if (!nsp_graph.execute(variant, m_inference_count, wait_kv_update_count)) return false;
        auto [cur_wait, cur_exec] = nsp_graph.getExecutionStats();

        // If we are pipelining execution with KV$Update, dispatch KV$ update jobs
        if (pipeline) {
            qualla::Timer timer;

            int32_t n_past = static_cast<int32_t>(m_nPast + n_processed);
            if (!m_disableKvCache)
                _kv_update_count = _kv_dispatcher->dispatch(idx, variant, n_past);
            // Time spent dispatching is accounted as wait time
            cur_wait += timer.elapsed_usec();
        }

        *total_exec += cur_exec;
        *total_wait += cur_wait;
        idx++;
    }

    if (pipeline) {
        if (m_inputType == InputType::TOKENS) {
            // Used tokens for processing, save them. Pointer arithmetic on data() is used
            // instead of &tokens[n_processed]: indexing one past the last element through
            // operator[] is undefined behaviour.
            token_history.insert(token_history.end(), tokens.data(), tokens.data() + n_processed);
        } else if (m_inputType == InputType::UNKNOWN) {
            __ERROR("Unknown input type found");
            return false;
        }
        m_nPast += n_processed;
    }

    if (_debug_outputs) {
        // Dump the output tensor(s) of the last graph split for offline inspection.
        // A failed dump is only logged; it does not fail the inference.
        if (m_modelArchitectureType == ModelArchitectureType::ENCODER) {
            if (!debugOutputs(m_nsp_graphs.back().variants[run_info.n_tokens]->getOutput(m_layerNames[LayerType::POOL_OUTPUT]), m_layerNames[LayerType::POOL_OUTPUT])) {
                __DEBUG("qnn-htp : Failed to save {} tensor", m_layerNames[LayerType::POOL_OUTPUT]);
            }
            if (!debugOutputs(m_nsp_graphs.back().variants[run_info.n_tokens]->getOutput(m_layerNames[LayerType::SEQ_OUTPUT]), m_layerNames[LayerType::SEQ_OUTPUT])) {
                __DEBUG("qnn-htp : Failed to save {} tensor", m_layerNames[LayerType::SEQ_OUTPUT]);
            }
        } else {
            if (!debugOutputs(m_nsp_graphs.back().variants[variant]->getOutput(m_layerNames[LayerType::OUTPUT]), m_layerNames[LayerType::OUTPUT])) {
                __DEBUG("qnn-htp : Failed to save {} tensor", m_layerNames[LayerType::OUTPUT]);
            }
        }
    }

    m_inference_count++;
    return true;
}

// Writes the raw contents of an output tensor to
// "<_debug_path>/<outTensorName>/<inference count, 3 digits>".
//
// For encoder models the full ctx_size * embd_size buffer is written as-is. For all
// other models only the rows of the tokens processed in the current run are written
// (n_processed * vocab_size elements), skipping the leading rows when the graph variant
// spans the whole context.
//
// Returns false only when outTensor is null; the result of the file write itself is not
// checked.
bool QnnNspModel::debugOutputs(QnnUtils::Tensor* outTensor, std::string& outTensorName){

    if (outTensor == nullptr) {
        __DEBUG("qnn-htp : Encountered NULL Tensor");
        return false;
    }

    // Read only the fields that are needed; copying run_info as a whole would also
    // duplicate its token vector.
    const auto variant     = run_info.n_tokens;
    const auto n_processed = run_info.n_processed;

    // Element width of the output tensor; it is used below as a per-element byte count
    // (detects 8-bit vs 16-bit logits).
    int      output_bw     = outTensor->dtype.bw();
    uint8_t* output_buffer = (uint8_t*)getBuffer(outTensor);

    // When the variant covers the whole context, the processed tokens sit at the end of
    // the output buffer, so the first (ctx_size - n_processed) rows are skipped.
    int32_t offset  = (variant == m_ctx_size) ? (m_ctx_size - n_processed) : 0;
    int32_t bufsize = 0;
    if (m_modelArchitectureType == ModelArchitectureType::ENCODER) {
        // Bert is saving complete out buffer as it is.
        bufsize = m_ctx_size * m_embd_size * output_bw; // ctx * embed_size * element width
    } else {
        // Reducing buffer to number of processed tokens and each token is of vocab_size
        bufsize = n_processed * m_vocab_size * output_bw;   // processed_token * vocab_size * element width
        output_buffer += offset * m_vocab_size * output_bw; // shift by offset * vocab_size * element width
    }

    std::string fname = fmt::format("{}/{}/{:03d}", _debug_path, outTensorName, m_inference_count);
    QnnUtils::writeRawData(output_buffer, bufsize, fname);
    return true;

}

// Picks the graph variant (AR-N) with the lowest estimated cost for processing n_inputs
// new tokens on top of n_past cached tokens.
//
// The cost of a variant is latency * number_of_iterations, plus a fixed penalty when it
// differs from cur_variant. Latencies come from variant_latency.
//
// Returns the cheapest eligible variant, or cur_variant if no variant is eligible.
int32_t QnnNspModel::selectVariantStrategy(int32_t n_inputs, int32_t n_past, int32_t cur_variant) {
    int32_t       best_variant = cur_variant;
    int32_t       best_cost    = INT32_MAX;
    const int32_t switch_cost  = 10; // Currently hard-coded to 10ms

    for (auto [variant, latency] : variant_latency) {
        // The request as a whole must fit into the context
        if (n_past + n_inputs > m_ctx_size) continue;

        // If variant cannot support the n_past, it is a non-starter
        // e.g. AR-128 with ctx_size=1024 can only support upto n_past=896 since it uses 128 output
        // runInference() rejects exactly this combination, so selecting such a variant
        // would turn a servable request into an error. The full-context variant is
        // exempt, mirroring the check in runInference().
        if (variant != m_ctx_size && n_past + variant > m_ctx_size) continue;

        const int32_t n_iters = 1 + ((n_inputs - 1) / variant);
        const int32_t cost    = latency * n_iters + ((variant == cur_variant) ? 0 : switch_cost);
        if (cost < best_cost) {
            best_variant = variant;
            best_cost    = cost;
        }
    }

    __DEBUG("qnn-htp : Variant selected AR={} (~ {} ms)", best_variant, best_cost);
    return best_variant;
}

// Runs inference on a sequence of token ids.
//
// The input is split into chunks of at most `variant` tokens (the AR-N graph variant in
// use); each chunk is set up via setupInputTensors() and executed via runInferenceHelper().
//
// in_tokens     : token ids to process. An empty vector is a no-op that returns 0.
// attention_map : optional. Must be empty, 1D with n_inputs entries, or 2D with
//                 n_inputs * (n_past + n_inputs) entries; it is re-chunked per iteration.
// output        : resized and filled with the dequantized logits (or embeddings for
//                 encoder models).
// output_all    : when true, results for every input token are returned, otherwise only
//                 the last one. For encoder models this is overridden by !m_pooled_output.
//
// Returns the number of result rows written to `output`, or 0 on empty input or failure.
// Throws std::runtime_error if the model is not configured for token input.
size_t QnnNspModel::runInference(
        const std::vector<int32_t>& in_tokens,
        const std::vector<int32_t>& attention_map,
        std::vector<float>&         output,
        bool                        output_all
) {
    qualla::Timer start;

    __TRACE("runInference logits_all={} in_tokens={}", output_all, in_tokens);

    if(m_inputType != InputType::TOKENS) {
        throw std::runtime_error("Wrong Type of input is supplied for token type query.");
    }

    if (in_tokens.size() == 0) return 0;

    // Select variant based on variant_latency, or default to current variant
    std::vector<int32_t> tokens(in_tokens);
    if (!variant_latency.empty() && !m_disableKvCache) {
        const int32_t cur_variant = _kv_dispatcher->getCurVariant();
        const int32_t new_variant = selectVariantStrategy(tokens.size(), m_nPast, cur_variant);
        if (cur_variant != new_variant) // Switch variant if necessary
            _kv_update_count = _kv_dispatcher->dispatch(new_variant, m_nPast);
    }

    // If variant selected in BERT-Mode, append token history to current request
    int32_t variant = 0;
    if(!m_disableKvCache)
        variant = _kv_dispatcher->getCurVariant();
    else
        variant = nsp_graph_count.rbegin()->first; // pick largest variant
    if (variant == m_ctx_size && m_nPast != 0)
        tokens.insert(tokens.begin(), token_history.begin(), token_history.end());

    const int32_t n_inputs = static_cast<int32_t>(tokens.size());
    const int32_t n_past   = static_cast<int32_t>(m_nPast);
    const int32_t n_valid  = n_past + n_inputs;
    run_info.n_tokens      = variant;
    if (variant != m_ctx_size && m_nPast + variant > m_ctx_size) {
        __ERROR("qnn-htp: exceeding ctx_size! : {} + {} > {}", m_nPast, variant, m_ctx_size);
        return 0;
    }

    // Calculate number of batches for run-inference
    const int32_t num_iters = 1 + ((n_inputs - 1) / variant);
    __DEBUG("qnn-htp: run-inference : {} tokens (AR-{} * {} iters)", n_inputs, variant, num_iters);

    // Validate attention_map size
    if (!attention_map.empty() && attention_map.size() != n_inputs &&
        attention_map.size() != n_inputs * (n_past + n_inputs)) {
        // clang-format off
        __ERROR("qnn-htp: attention_map must be 1D(n_inputs) or 2D(n_inputs * (n_past + n_inputs))"
                "but has size={} for n_past={} n_inputs={}", attention_map.size(), n_past, n_inputs);
        // clang-format on
        return 0;
    }
    std::vector<int32_t> chunked_attn_map;

    // Technical note: int32_t can hold upto 596 hours
    // Even int16_t should be sufficient here - it holds upto 32.8 seconds
    int32_t total_wait = 0;
    int32_t total_exec = 0;

    // user choice overwrites the default behaviour in case of Embedding models
    if(m_modelArchitectureType == ModelArchitectureType::ENCODER)
        output_all = !m_pooled_output;

    // Reset logit accumulator
    size_t output_count = output_all ? n_inputs : 1; // actual number of logits

    if(m_modelArchitectureType == ModelArchitectureType::ENCODER)
        output.resize(output_count * m_embd_size);
    else
        output.resize(output_count * m_vocab_size);

    for (int i = 0; i < num_iters; i++) {
        const int32_t update_size = std::min(variant, n_inputs - i * variant);
        run_info.n_processed      = update_size;
        // Iterator arithmetic instead of &tokens[...]: for the last chunk the end of the
        // range is one past the last element, which must not be formed via operator[].
        const auto chunk_begin = tokens.begin() + i * variant;
        run_info.tokens.assign(chunk_begin, chunk_begin + update_size);

        int32_t n_skip_prefix =
                (i * variant < _offset_to_apply_kv_prefix) ? _size_to_skip_kv_prefix : 0;
        int32_t n_apply_prefix_offset = 0;
        if (i * variant < _offset_to_apply_kv_prefix)
            n_apply_prefix_offset = std::min(variant, _offset_to_apply_kv_prefix - i * variant);

        // Chunk inputs and attention mask
        std::span<int32_t> tokens_chunk   = std::span{tokens.data(),tokens.size()}.subspan(i * variant, update_size);
        std::span<int32_t> attn_map_chunk = std::span<int32_t>();
        if (attention_map.size() == n_inputs) {
            chunked_attn_map.resize(update_size);
            // Take exactly update_size elements. Be mindful to decrease offset already processed
            for (int j = 0; j < update_size; j++)
                chunked_attn_map[j] = attention_map[i * variant + j] - (i * variant);
            attn_map_chunk = std::span{chunked_attn_map.data(),chunked_attn_map.size()};
        } else if (attention_map.size() == n_inputs * (n_past + n_inputs)) {
            chunked_attn_map.clear();
            chunked_attn_map.resize(update_size * (m_nPast + update_size));

            for (int j = 0; j < update_size; j++) {
                // Be mindful. m_nPast changes each iteration.
                // n_tokens is total #tokens called. update_size is the n_tokens for this iteration
                // n_past is the initial m_nPast. n_valid = n_past + n_tokens
                std::memcpy(
                        &chunked_attn_map[j * (m_nPast + update_size)],
                        &attention_map[i * variant * n_valid + j * n_valid],
                        (m_nPast + update_size) * sizeof(int32_t)
                );
            }
            attn_map_chunk = std::span{chunked_attn_map.data(),chunked_attn_map.size()};
        }

        if (!setupInputTensors(
                    tokens_chunk,
                    (variant == m_ctx_size) ? 0 : m_nPast,
                    attn_map_chunk,
                    n_skip_prefix,
                    n_apply_prefix_offset
            ))
            return 0;

        // Run Inference and pipeline KV$ update iff n_inputs is exactly 1 or we have more batches
        bool pipeline = (n_inputs == 1 || i < num_iters - 1);
        if (!runInferenceHelper(pipeline, &total_wait, &total_exec)) return 0;

        if (m_modelArchitectureType != ModelArchitectureType::ENCODER && output_all) {
            // Accumulate logits
            const size_t logit_offset = i * variant * m_vocab_size;
            const size_t logit_count = update_size * m_vocab_size;
            getDequantLogits(std::span{output.data(), output.size()}.subspan(logit_offset, logit_count),
                             output_all);
        }
    }

    // Return last logit if not accumulating
    if(m_modelArchitectureType != ModelArchitectureType::ENCODER) {
        if(!output_all)
            getDequantLogits(std::span{output.data(), output.size()}, output_all);
    }
    else
        getEmbeddings(std::span{output.data(), output.size()});

    __DEBUG("qnn-htp: run-inference complete : {} usec : wait {} exec {}",
            start.elapsed_usec(),
            total_wait,
            total_exec);

    // threadpool.suspend();
    return output_count;
}

// Quantizes `length` float values from `in` into the input tensor's buffer, starting at
// element index `tensorOffset`, using the input tensor's first quantization parameter
// (scale/offset).
//
// Returns false if the input tensor is missing or its datatype is not 8/16-bit
// unsigned fixed point.
bool QnnNspModel::quantizeInput(float* in, size_t tensorOffset ,size_t length) {

    if(t_input_ids  == nullptr) {
        __ERROR("Input Tensor {} not found during execute", m_layerNames[LayerType::INPUT]);
        return false;
    }

    const auto scale = t_input_ids->quantParam[0].scale;
    const auto offset = t_input_ids->quantParam[0].offset;

    // tensorOffset is applied after the cast, so it counts elements of the tensor's
    // datatype rather than bytes.
    // clang-format off
    switch (t_input_ids->dtype) {
        case QNN_DATATYPE_UFIXED_POINT_8: QnnUtils::quantizeTensorPtr(in, (uint8_t*)getBuffer(t_input_ids) + tensorOffset, offset, scale, length); break;
        case QNN_DATATYPE_UFIXED_POINT_16: QnnUtils::quantizeTensorPtr(in, (uint16_t*)getBuffer(t_input_ids) + tensorOffset, offset, scale, length); break;
        default: __ERROR("Unsupported input tensor dtype {}", t_input_ids->dtype.str()); return false;
    }
    // clang-format on

    return true;
}

// Returns the value stored in m_embeddingBufferSize. runInference() for embedding input
// divides the incoming buffer length by this value to obtain the token count.
size_t QnnNspModel::getEmbeddingBufferSize() {
    const size_t embedding_buffer_size = m_embeddingBufferSize;
    return embedding_buffer_size;
}

// Runs inference on a buffer of pre-computed input embeddings.
//
// The number of tokens is embedding.size() / m_embeddingBufferSize. The input is split
// into chunks of at most `variant` tokens (the AR-N graph variant in use); each chunk is
// set up via setupInputTensors() and executed via runInferenceHelper().
//
// embedding     : raw embedding bytes, m_embeddingBufferSize bytes per token. An empty
//                 buffer is a no-op that returns 0.
// attention_map : optional. Must be empty, 1D with n_inputs entries, or 2D with
//                 n_inputs * (n_past + n_inputs) entries; it is re-chunked per iteration.
// output        : resized and filled with the dequantized logits.
// output_all    : when true, logits for every input token are returned, otherwise only
//                 the last one.
//
// Returns the number of logit rows written to `output`, or 0 on empty input or failure.
// Throws std::runtime_error if the model is not configured for embedding input.
size_t QnnNspModel::runInference(
        std::vector<uint8_t>&       embedding,
        const std::vector<int32_t>& attention_map,
        std::vector<float>&         output,
        bool                        output_all
) {
    qualla::Timer start;

    __DEBUG("qnn-htp: run-inference start : n_Embd {}", embedding.size());

    if(m_inputType != InputType::EMBEDDINGS) {
        throw std::runtime_error("Embedding input type is not supported by the model.");
    }

    // Nothing to process: report zero logits (same as the token overload)
    if (embedding.size() == 0) return 0;

    size_t embedBufSize = m_embeddingBufferSize;
    if (embedBufSize == 0) {
        // Guard the division below
        __ERROR("qnn-htp: embedding buffer size is 0, cannot derive token count");
        return 0;
    }

    // Select variant based on variant_latency, or default to current variant
    int32_t curTokenCount = embedding.size() / embedBufSize;
    if (!variant_latency.empty() && !m_disableKvCache) {
        const int32_t cur_variant = _kv_dispatcher->getCurVariant();
        const int32_t new_variant = selectVariantStrategy(curTokenCount, m_nPast, cur_variant);
        if (cur_variant != new_variant) // Switch variant if necessary
            _kv_update_count = _kv_dispatcher->dispatch(new_variant, m_nPast);
    }

    // NOTE(review): unlike the token overload, the dispatcher is queried here even when
    // m_disableKvCache is set - confirm that combination cannot occur for embedding input.
    const int32_t variant = _kv_dispatcher->getCurVariant();

    // We will never be maintaining history for the embedding

    const int32_t n_inputs = static_cast<int32_t>(curTokenCount);
    const int32_t n_past   = static_cast<int32_t>(m_nPast);
    const int32_t n_valid  = n_past + n_inputs;
    run_info.n_tokens = variant;

    if (variant != m_ctx_size && m_nPast + variant > m_ctx_size) {
        __ERROR("qnn-htp: exceeding ctx_size! : {} + {} > {}", m_nPast, variant, m_ctx_size);
        return 0;
    }

    // Calculate number of batches for run-inference
    const int32_t num_iters = 1 + ((n_inputs - 1) / variant);
    __DEBUG("qnn-htp: run-inference : {} tokens (AR-{} * {} iters)",
            n_inputs,
            variant,
            num_iters);

    // Validate attention_map size
    if (!attention_map.empty() && attention_map.size() != n_inputs &&
        attention_map.size() != n_inputs * (n_past + n_inputs)) {
        // clang-format off
        __ERROR("qnn-htp: attention_map must be 1D(n_inputs) or 2D(n_inputs * (n_past + n_inputs))"
                "but has size={} for n_past={} n_inputs={}", attention_map.size(), n_past, n_inputs);
        // clang-format on
        return 0;
    }
    std::vector<int32_t> chunked_attn_map;

    // Technical note: int32_t can hold upto 596 hours
    // Even int16_t should be sufficient here - it holds upto 32.8 seconds
    int32_t total_wait = 0;
    int32_t total_exec = 0;

    // Reset logit accumulator
    size_t output_count = output_all ? n_inputs : 1; // actual number of logits

    output.resize(output_count * m_vocab_size);

    for (int i = 0; i < num_iters; i++) {
        const int32_t update_size = std::min(variant, n_inputs - i * variant);
        run_info.n_processed      = update_size;
        // Byte offset of this chunk; kept in size_t so a large embedding buffer cannot
        // overflow a 32-bit index.
        const size_t startIdx = static_cast<size_t>(i) * variant * embedBufSize;

        int32_t n_skip_prefix =
                (i * variant < _offset_to_apply_kv_prefix) ? _size_to_skip_kv_prefix : 0;
        int32_t n_apply_prefix_offset = 0;
        if (i * variant < _offset_to_apply_kv_prefix)
            n_apply_prefix_offset = std::min(variant, _offset_to_apply_kv_prefix - i * variant);

        // Chunk inputs and attention mask
        std::span<uint8_t> embedding_chunk  = std::span{embedding.data(),embedding.size()}.subspan(startIdx, update_size*embedBufSize);
        std::span<int32_t> attn_map_chunk = std::span<int32_t>();
        if (attention_map.size() == n_inputs) {
            chunked_attn_map.resize(update_size);
            // Take exactly update_size elements. Be mindful to decrease offset already processed
            for (int j = 0; j < update_size; j++)
                chunked_attn_map[j] = attention_map[i * variant + j] - (i * variant);
            attn_map_chunk = std::span{chunked_attn_map.data(),chunked_attn_map.size()};
        } else if (attention_map.size() == n_inputs * (n_past + n_inputs)) {
            chunked_attn_map.clear();
            chunked_attn_map.resize(update_size * (m_nPast + update_size));

            for (int j = 0; j < update_size; j++) {
                // Be mindful. m_nPast changes each iteration.
                // n_tokens is total #tokens called. update_size is the n_tokens for this iteration
                // n_past is the initial m_nPast. n_valid = n_past + n_tokens
                std::memcpy(
                        &chunked_attn_map[j * (m_nPast + update_size)],
                        &attention_map[i * variant * n_valid + j * n_valid],
                        (m_nPast + update_size) * sizeof(int32_t)
                );
            }
            attn_map_chunk = std::span{chunked_attn_map.data(),chunked_attn_map.size()};
        }

        if (!setupInputTensors(
                embedding_chunk,
                (variant == m_ctx_size) ? 0 : m_nPast,
                attn_map_chunk,
                n_skip_prefix,
                n_apply_prefix_offset
        ))
            return 0;

        // Run Inference and pipeline KV$ update iff n_inputs is exactly 1 or we have more batches
        bool pipeline = (n_inputs == 1 || i < num_iters - 1);
        if (!runInferenceHelper(pipeline, &total_wait, &total_exec)) return 0;

        if (output_all) {
            // Accumulate logits
            const size_t logit_offset = i * variant * m_vocab_size;
            const size_t logit_count = update_size * m_vocab_size;
            getDequantLogits(std::span{output.data(), output.size()}.subspan(logit_offset, logit_count),
                             output_all);
        }
    }

    // Return last logit if not accumulating
    if(!output_all)
        getDequantLogits(std::span{output.data(), output.size()}, output_all);

    __DEBUG("qnn-htp: run-inference complete : {} usec : wait {} exec {}",
            start.elapsed_usec(),
            total_wait,
            total_exec);

    return output_count;
}

// Stores a copy of the given EOS embedding bytes in m_eosEmbedding. Always returns true.
bool QnnNspModel::cacheEosEmbedding(std::vector<uint8_t>& eosEmbedding) {
    const std::vector<uint8_t>& source_bytes = eosEmbedding;
    m_eosEmbedding                           = source_bytes;
    return true;
}

bool QnnNspModel::setKVCacheNPast(size_t n_past, const std::vector<bool>& selected) {
    __TRACE("setKVCacheNPast (m_nPast={} -> n_past={})", m_nPast, n_past);
    if (n_past == m_nPast && n_past != 0) return true;

    if (m_nPast + run_info.n_processed < n_past) {
        __ERROR("qnn-htp: set-kv n_past update larger than number of processed tokens : n_past {} n_proc {}",
                n_past,
                m_nPast + run_info.n_processed);
        return false;
    }

    if (m_inputType == InputType::TOKENS) {
        if (n_past == 0) {
            int32_t new_variant = nsp_graph_count.rbegin()->first;
            _kv_update_count = _kv_dispatcher->dispatch(new_variant, 0, selected);
            token_history.clear();

        } else if (n_past < m_nPast) {
            auto [variant, update_size, tokens] = run_info;
            _kv_update_count = _kv_dispatcher->dispatch(variant, n_past);
            token_history.resize(n_past);
        } else {
            int32_t new_variant = nsp_graph_count.begin()->first;
            _kv_update_count = _kv_dispatcher->dispatch(new_variant, n_past, selected);

            auto [variant, update_size, tokens] = run_info;

            if (variant == m_ctx_size) {
                token_history.assign(&tokens[0], &tokens[n_past]);
            } else if (selected.empty()) {
                token_history.insert(token_history.end(), &tokens[0], &tokens[n_past - m_nPast]);
            } else {
                for (auto i = 0; i < tokens.size(); ++i) {
                    if (selected[i]) token_history.push_back(tokens[i]);
                }
            }
        }
    }
    else if (m_inputType == InputType::EMBEDDINGS) { // Don't add embedding history, It is costly maintenance to do.
        if (n_past == 0) {
            int32_t new_variant = nsp_graph_count.rbegin()->first;
            _kv_update_count = _kv_dispatcher->dispatch(new_variant, 0, selected);
        } else if (n_past < m_nPast) {
            auto [variant, update_size, tokens] = run_info;
            _kv_update_count = _kv_dispatcher->dispatch(variant, n_past);
        } else {
            int32_t new_variant = nsp_graph_count.begin()->first;
            _kv_update_count = _kv_dispatcher->dispatch(new_variant, n_past, selected);
        }
    }
    else
    {
        __ERROR("Wrong type of input is found.");
        return false;
    }

    m_nPast = n_past;
    return true;
}

// Dequantizes the first `count` values of `inputs` into `outputs` using
//   outputs[k] = (T(inputs[k]) + offset) * scale
// No bounds checking is performed against outputs.size().
template <typename U, typename T>
inline void deQuantizeOutputs(
        U*            inputs,
        std::span<T>& outputs,
        const double  scale,
        const int32_t offset,
        const int     count
) {
#pragma clang loop vectorize(enable) interleave(enable)
    for (int idx = 0; idx < count; ++idx) {
        // Shift by the zero-point first, then apply the scale
        const auto shifted = static_cast<T>(inputs[idx]) + offset;
        outputs[idx]       = shifted * scale;
    }
}

// Converts the first `numElements` values of `inputs` into `outputs` without applying
// any quantization parameters:
//   bitWidth == 2 : every input is treated as IEEE fp16 bits and widened to fp32.
//   bitWidth == 4 : every input is assigned to the output as-is (implicit conversion).
// Any other bitWidth leaves `outputs` untouched.
template <typename U, typename T>
inline void castOutputs(U* inputs, std::span<T>& outputs, const int numElements, const int bitWidth) {
    switch (bitWidth) {
    case 2: {
#pragma clang loop vectorize(enable) interleave(enable)
        for (int idx = 0; idx < numElements; ++idx) {
            outputs[idx] = fp16_ieee_to_fp32_value(inputs[idx]);
        }
        break;
    }
    case 4: {
#pragma clang loop vectorize(enable) interleave(enable)
        for (size_t idx = 0; idx < numElements; idx++) {
            outputs[idx] = inputs[idx];
        }
        break;
    }
    default:
        break;
    }
}

// Converts the logits produced by the last graph split of the most recent run into
// floats and writes them to `dequant_logits`.
//
// dequant_logits : destination; must hold at least vocab_size * (number of returned
//                  logit rows) floats.
// logits_all     : when true, one row per processed token (run_info.n_processed) is
//                  returned; otherwise only the row of the last processed token.
//
// Returns the number of logit rows written, or 0 if the output tensor is missing, the
// destination is too small or the tensor datatype is unsupported.
// Throws std::runtime_error if all logits are requested but the graph output holds a
// single row.
size_t QnnNspModel::getDequantLogits(std::span<float> dequant_logits, bool logits_all) {
    qualla::Timer start;

    QnnUtils::Tensor* const logit_spec =
            m_nsp_graphs.back().variants[run_info.n_tokens]->getOutput(m_layerNames[LayerType::OUTPUT]);
    if (logit_spec == nullptr) {
        __ERROR("qnn-htp: output tensor {} not found", m_layerNames[LayerType::OUTPUT]);
        return 0;
    }

    const int return_size      = logits_all ? run_info.n_processed : 1;
    const auto [scale, offset] = logit_spec->quantParam[0];

    auto d_logits = QnnUtils::DataType(logit_spec->tensor);

    int logit_bw = logit_spec->dtype.bw();

    uint8_t* logit_buffer = (uint8_t*)getBuffer(logit_spec);
    if (logit_spec->dims.getNumElements() == m_vocab_size) {
        // BERT Mode graph may return only the last logit
        // If only one logit is returned, simply return the last logit
        if (return_size > 1)
            throw std::runtime_error("Requested all logits, but graph only produces one logit");
    } else {
        // If multiple logits are returned, offset to the correct location in the buffer
        if (run_info.n_tokens == m_ctx_size) {
            // This was left-padded, logits are at [n_tokens - n_processed, n_tokens]
            logit_buffer += (run_info.n_tokens - return_size) * m_vocab_size * d_logits.bw();
        } else if (logits_all == false) {
            // This was right-padded, logits are at indexes [0, n_processed]
            logit_buffer += (run_info.n_processed - 1) * m_vocab_size * d_logits.bw();
        }
    }
    const int n_logits = static_cast<int>(m_vocab_size * return_size);
    __TRACE("qnn-htp: get-logits logits_all={} for {} tokens. Returning {}*{}",
            logits_all,
            run_info.n_processed,
            return_size,
            m_vocab_size);

    // The conversion helpers index the span without bounds checks, so verify up front
    // that the destination can hold everything that will be written.
    if (n_logits < 0 || dequant_logits.size() < static_cast<size_t>(n_logits)) {
        __ERROR("qnn-htp: logits buffer too small : need {} have {}",
                n_logits,
                dequant_logits.size());
        return 0;
    }

    switch (d_logits) {
    case QNN_DATATYPE_UFIXED_POINT_8:
        deQuantizeOutputs((uint8_t*)logit_buffer, dequant_logits, scale, offset, n_logits);
        break;
    case QNN_DATATYPE_UFIXED_POINT_16:
        deQuantizeOutputs((uint16_t*)logit_buffer, dequant_logits, scale, offset, n_logits);
        break;
    case QNN_DATATYPE_FLOAT_16: {
        castOutputs((uint16_t*)logit_buffer, dequant_logits, n_logits, logit_bw);
        break;
    }
    default:
        // Nothing was written; do not report any logits as produced
        __ERROR("Unsupported logits dtype {}", d_logits.str());
        return 0;
    }

    __DEBUG("qnn-htp: getDequantLogits complete : {} usec (return_size={})",
            start.elapsed_usec(),
            return_size);
    return return_size;
}

// Precomputes the RoPE sin/cos lookup tables (rope_sin / rope_cos) for every position in
// [0, m_ctx_size) and every rotary dimension in [0, m_pos_dim).
//
// The base inverse frequencies are 1 / theta^(j / m_pos_dim). They are optionally
// rescaled according to the configured rope scaling type (LLAMA3 or LONGROPE; LONGROPE
// additionally scales sin/cos by an attention factor). Values are stored in the position
// tensor's datatype: quantized with the cos tensor's scale/offset for 8/16-bit fixed
// point, or converted to fp16 bits for FLOAT_16.
//
// Returns true when the positional encoding is not ROPE (nothing to do) or the tables
// were built; false if allocation fails or the position datatype is unsupported.
// Throws std::runtime_error if the LONGROPE factor list does not have m_pos_dim entries.
bool QnnNspModel::calculate_rope_embeddings(void) {
    if (m_positional_encoding.type != PositionalEncoding::ROPE) return true;

    const size_t nmemb  = m_ctx_size * m_pos_dim;
    const int    pos_bw = d_pos.bw();

    const double             theta        = m_positional_encoding.rope_params.theta;
    const RopeScalingParams& rope_scaling = m_positional_encoding.rope_params.rope_scaling;

    rope_sin = malloc(nmemb * pos_bw);
    rope_cos = malloc(nmemb * pos_bw);
    // The tables are written unconditionally below, so an allocation failure must be
    // caught here. malloc(0) is allowed to return nullptr and is not treated as failure.
    if (nmemb * pos_bw != 0 && (rope_sin == nullptr || rope_cos == nullptr)) {
        __ERROR("qnn-htp: failed to allocate {} bytes for rope tables", nmemb * pos_bw);
        free(rope_sin);
        free(rope_cos);
        rope_sin = nullptr;
        rope_cos = nullptr;
        return false;
    }

    auto [q_scale, q_offset] = t_position_ids_cos->quantParam[0];
    if (d_pos == QNN_DATATYPE_FLOAT_16) { // If floating point, don't quantize!
        q_scale  = 1.0;
        q_offset = 0;
    }

    // Calculate inv_freq array
    std::vector<double> inv_freq(m_pos_dim);
    const double        exponent = 1.0 / static_cast<double>(m_pos_dim);
    for (int j = 0; j < m_pos_dim; j++)
        inv_freq[j] = 1.0 / pow(theta, j * exponent);
    double attention_factor = 1.0;
    if (rope_scaling.rope_type == RopeScalingParams::ROPE_LLAMA3) {
        // Implemented from HuggingFace
        // https://github.com/huggingface/transformers/blob/47c29ccfaf56947d845971a439cbe75a764b63d7/src/transformers/modeling_rope_utils.py#L298
        const double& factor           = rope_scaling.llama3_params.factor;
        const double& low_freq_factor  = rope_scaling.llama3_params.low_freq_factor;
        const double& high_freq_factor = rope_scaling.llama3_params.high_freq_factor;
        const int&    old_context_len = rope_scaling.llama3_params.original_max_position_embeddings;

        const double low_freq_wavelen  = old_context_len / low_freq_factor;
        const double high_freq_wavelen = old_context_len / high_freq_factor;

        for (int j = 0; j < m_pos_dim; j++) {
            const double wavelen = 2 * M_PI / inv_freq[j];
            if (wavelen < high_freq_wavelen) // wavelen < high_freq_wavelen: do nothing
                continue;
            else if (wavelen > low_freq_wavelen) // wavelen > low_freq_wavelen: divide by factor
                inv_freq[j] = 1.0 / static_cast<double>(factor * pow(theta, j * exponent));
            else { // otherwise: interpolate between the two, using a smooth factor
                assert(low_freq_wavelen != high_freq_wavelen);
                const double smooth =
                        (static_cast<double>(old_context_len) / wavelen - low_freq_factor) /
                        (high_freq_factor - low_freq_factor);
                inv_freq[j] = ((1 - smooth) * inv_freq[j] / factor + smooth * inv_freq[j]);
            }
        }
    } else if (rope_scaling.rope_type == RopeScalingParams::ROPE_LONGROPE) {
        // Validate factor >= 1.0, len(long_factor) == rope-dim and len(short_factor) == rope-dim
        const double& factor       = rope_scaling.longrope_params.factor;
        const int& old_context_len = rope_scaling.longrope_params.original_max_position_embeddings;

        // The long factors apply once the context exceeds the original training length
        const auto& inv_factors = (m_ctx_size > old_context_len)
                                          ? rope_scaling.longrope_params.long_factor
                                          : rope_scaling.longrope_params.short_factor;

        if (inv_factors.size() != m_pos_dim)
            throw std::runtime_error(fmt::format(
                    "long-factor (len={}) and short-factor (len={}) must have length rope-dim={}",
                    rope_scaling.longrope_params.long_factor.size(),
                    rope_scaling.longrope_params.short_factor.size(),
                    m_pos_dim
            ));

        for (int j = 0; j < m_pos_dim; j++)
            inv_freq[j] = inv_freq[j] / inv_factors[j];

        attention_factor =
                std::sqrt(1.0 + std::log(factor) / std::log(static_cast<double>(old_context_len)));
    }
    for (int i = 0; i < m_ctx_size; i++) {
        for (int j = 0; j < m_pos_dim; j++) {
            const double freq = i * inv_freq[j];

            // Quantize: value / scale - offset (scale=1, offset=0 for floating point)
            const double sin_val = ((sin(freq) * attention_factor) / q_scale) - q_offset;
            const double cos_val = ((cos(freq) * attention_factor) / q_scale) - q_offset;

            // round() instead of floor() seems to produce an accuracy drop. To debug later
            switch (d_pos) {
            case QNN_DATATYPE_UFIXED_POINT_8:
                ((uint8_t*)rope_sin)[i * m_pos_dim + j] = static_cast<uint8_t>(sin_val);
                ((uint8_t*)rope_cos)[i * m_pos_dim + j] = static_cast<uint8_t>(cos_val);
                break;
            case QNN_DATATYPE_UFIXED_POINT_16:
                ((uint16_t*)rope_sin)[i * m_pos_dim + j] = static_cast<uint16_t>(sin_val);
                ((uint16_t*)rope_cos)[i * m_pos_dim + j] = static_cast<uint16_t>(cos_val);
                break;
            case QNN_DATATYPE_FLOAT_16:
                ((uint16_t*)rope_sin)[i * m_pos_dim + j] = fp16_ieee_from_fp32_value(sin_val);
                ((uint16_t*)rope_cos)[i * m_pos_dim + j] = fp16_ieee_from_fp32_value(cos_val);
                break;
            default:
                __ERROR("Unsupported position ids datatype {}", d_pos.str());
                return false;
            }
        }
    }

    if (_debug_tensors) {
        // Dump both tables; the file name carries the element width in bits
        std::string fname_sin = fmt::format("{}/position_ids_sin.{}.dat", _debug_path, pos_bw * 8);
        std::string fname_cos = fmt::format("{}/position_ids_cos.{}.dat", _debug_path, pos_bw * 8);
        QnnUtils::writeRawData(rope_sin, nmemb * pos_bw, fname_sin);
        QnnUtils::writeRawData(rope_cos, nmemb * pos_bw, fname_cos);
    }

    return true;
}

// Loads the LM-head weights from disk and writes them, quantized, into every graph input
// tensor named "weight".
//
// The weights are read as raw fp32 values from
// <model_basedir>/<lmhead_weight_dir>/weight.raw and quantized per-width with the
// tensor's quantization parameters.
//
// Returns true when the feature is disabled (_lmhead_weight_input is false) or all
// matching tensors were filled; false if the weight directory is not configured, the
// file cannot be opened, or it holds fewer values than the tensor has elements.
bool QnnNspModel::load_lmhead_weight_as_input(void) {
    if (!_lmhead_weight_input) return true;
    if (lmhead_weight_dir.empty()) {
        __ERROR("NSPModel: LMhead weight file not found");
        return false;
    }
    for (auto& variant : m_variant_list) {
        for (auto& [tname, tspec] : variant.input_specs) {
            if (tname.compare("weight") == 0) {
                // weight tensor file name should be in same format as tensor name present in graph
                std::string weight_file =
                        (model_basedir / fs::path(lmhead_weight_dir) / fs::path(tname + ".raw"))
                                .string();

                QnnUtils::Dims dims        = tspec.dims;
                size_t         numElements = dims.getNumElements();

                size_t size = sizeof(float);
                // Temporary storage for the fp32 values. The vector must be *sized*, not
                // just reserved: fread writes through data(), and writing beyond size()
                // is undefined behaviour.
                std::vector<float> weight_f32(numElements);

                // "rb": the file holds raw binary floats; text mode would corrupt the
                // data on platforms that translate line endings.
                FILE* fp = fopen(weight_file.c_str(), "rb");
                if (fp == NULL) {
                    __ERROR("NSPModel: Error opening file: {}", weight_file);
                    return false;
                }

                size_t count = fread(weight_f32.data(), size, numElements, fp);
                fclose(fp);

                if (count != numElements) {
                    __ERROR("NSPModel: Could not load {} - expected file size {}",
                            weight_file,
                            numElements * size);
                    return false;
                }

                int8_t* weight_buffer = (int8_t*)getBuffer(tspec);
                // Quantize the values, per width quantization
                QnnUtils::perWidthQuantizeTensorPtr(
                        weight_f32.data(),
                        weight_buffer,
                        tspec.quantParam,
                        dims.height,
                        dims.width,
                        dims.channel
                );
            }
        }
    }
    return true;
}

// Overwrites every LoRA input tensor (any input whose name contains "lora") in
// every graph variant with the encoded representation of zero.
//
// Because the tensors are quantized, zero is not stored as a literal 0: the
// value written is the negated quantization offset of the tensor, converted to
// the tensor's storage type.
//
// Returns false when LoRA is not enabled for the model, when a LoRA tensor has
// no buffer, or when a LoRA tensor has an unsupported datatype; true otherwise.
bool QnnNspModel::flushLoraWeightsBuffers(void) {
    if (!_lora_enabled) {
        __ERROR("qnn-htp: Model does not support LoRA weights.");
        return false;
    }

    for (auto& variant : m_variant_list) {
        for (auto& [tname, tspec] : variant.input_specs) {
            // Only LoRA weight tensors are flushed; everything else is left untouched
            if (tname.find("lora") == std::string::npos) continue;

            // Look the buffer up once and reuse it for the fill below
            auto* buffer = getBuffer(tspec);
            if (buffer == nullptr) return false;

            const size_t numElements = tspec.dims.getNumElements();
            const auto   offset      = tspec.quantParam[0].offset;

            // Since values needs to be quantized so zero is going to get translated.
            // clang-format off
            switch (tspec.dtype) {
                case QNN_DATATYPE_UFIXED_POINT_8:  std::fill_n((uint8_t*)buffer, numElements, static_cast<uint8_t>(-offset));  break;
                case QNN_DATATYPE_UFIXED_POINT_16: std::fill_n((uint16_t*)buffer, numElements, static_cast<uint16_t>(-offset));  break;
                case QNN_DATATYPE_FLOAT_16: {
                    // The converted value is identical for every element, so compute it once
                    // and fill with a size_t count instead of looping with a signed index.
                    const uint16_t zero_fp16 = fp16_ieee_from_fp32_value(-offset);
                    std::fill_n((uint16_t*)buffer, numElements, zero_fp16);
                    break;
                }
                default: __ERROR("Unsupported {} datatype for {} tensor", tspec.dtype.str(), tname); return false;
            }
            // clang-format on
        }
    }
    return true;
}

// Loads the LoRA weight set registered under `lora_weights_name` into the LoRA
// input tensors of every graph variant.
//
// For each input tensor whose name contains "lora" (except the alpha tensor,
// which is written through applyLoraStrength), the file
// <model_basedir>/<configured path>/<tensor name>.raw is read as raw float
// values and converted into the tensor buffer according to the tensor dtype.
//
// Returns false when LoRA is not enabled, when the LoRA config type is not
// input-weight based, when the name is not configured, when its path is empty,
// when the alpha tensor cannot be applied, when a tensor has no buffer, when a
// weight file cannot be opened or holds fewer elements than the tensor, or
// when the tensor dtype is unsupported. Returns true otherwise.
bool QnnNspModel::applyLoraWeights(const std::string& lora_weights_name) {
    if (!_lora_enabled) {
        __ERROR("qnn-htp: Model does not support LoRA weights.");
        return false;
    }
    if (lora_conf != LoraConfigType::LORA_INPUT_WEIGHT_ENABLE) {
        __ERROR("qnn-htp: LoRA config is not enable for input weights");
        return false;
    }

    if (!lora_config.contains(lora_weights_name)) {
        __ERROR("qnn-htp: Could not find lora weights config to apply ");
        return false;
    }

    // The entry is known to exist; look it up once instead of on every use
    auto& lora_entry = lora_config[lora_weights_name];

    if (lora_entry.path.empty()) {
        __ERROR("qnn-htp: LoRA weights dir is empty for {}", lora_weights_name);
        return false;
    }

    if (!applyLoraStrength(lora_entry.alpha_tensor_name, lora_entry.alpha_tensor_val)) {
        __ERROR("qnn-htp: Could not apply Alpha tensor ");
        return false;
    }

    for (auto& variant : m_variant_list) {
        for (auto& [tname, tspec] : variant.input_specs) {
            // Skip non-LoRA inputs and the alpha tensor (already handled above)
            if (tname.find("lora") == std::string::npos) continue;
            if (tname == lora_entry.alpha_tensor_name) continue;

            if (getBuffer(tspec) == nullptr) return false;

            // lora tensor file names should be in same format as tensor names present in graph
            const std::string lora_weights_file =
                    (model_basedir / fs::path(lora_entry.path) / fs::path(tname + ".raw")).string();

            const size_t numElements = tspec.dims.getNumElements();
            const auto   scale       = tspec.quantParam[0].scale;
            const auto   offset      = tspec.quantParam[0].offset;
            const size_t size        = sizeof(float);

            // Temporary fp32 staging area. It must be *sized* (not merely reserved):
            // fread() writes through data(), and writing past size() is undefined.
            std::vector<float> lora_weights_f32(numElements);

            // Binary mode: the file holds raw floats, so text-mode translation
            // (newline conversion / EOF marker on some platforms) must not apply.
            FILE* fp = fopen(lora_weights_file.c_str(), "rb");
            if (fp == NULL) {
                __ERROR("NSPModel: Error opening file: {}", lora_weights_file);
                return false;
            }

            const size_t count = fread(lora_weights_f32.data(), size, numElements, fp);
            fclose(fp);

            if (count != numElements) {
                __ERROR("NSPModel: Could not load {} - expected file size {}",
                        lora_weights_file,
                        numElements * size);
                return false;
            }

            // Quantize the values
            // clang-format off
            switch (tspec.dtype) {
                case QNN_DATATYPE_UFIXED_POINT_8: QnnUtils::quantizeTensorPtr(lora_weights_f32.data(), (uint8_t*)getBuffer(tspec), offset, scale, numElements); break;
                case QNN_DATATYPE_UFIXED_POINT_16: QnnUtils::quantizeTensorPtr(lora_weights_f32.data(), (uint16_t*)getBuffer(tspec), offset, scale, numElements); break;
                case QNN_DATATYPE_FLOAT_16: float32ToFloat16((uint8_t *)getBuffer(tspec), lora_weights_f32.data(), numElements); break;
                default: __ERROR("Unsupported {} datatype for {} tensor", tspec.dtype.str(), tname); return false;
            }
            // clang-format on
        }
    }
    return true;
}

// Writes one JSON file per graph variant, named
// <_debug_path>/spec.<graph name>.json, listing the graph name followed by its
// input and output tensors (name, dims, bitwidth, dtype, scales, offsets).
//
// Throws std::runtime_error when an output file cannot be opened.
void QnnNspModel::dumpTensorSpecs() {
    // One JSON object per tensor; each line ends with ",\n" which is trimmed after the last entry
    static const char* tensorLineFmt =
            "\t\t{ \"name\": \"%s\", \"dims\": [1, %d, %d, %d], "
            "\"bitwidth\": %d, \"dtype\": \"%s\", \"scale\": [%s], \"offset\": [%s] },\n";

    for (GraphVariant& variant : m_variant_list) {
        GraphInfo_t* info = variant.graph_info;

        // Create output spec file and open it
        const std::string specPath =
                fmt::format("{}/spec.{}.json", _debug_path, info->graphName);

        FILE* out = fopen(specPath.c_str(), "w");
        if (out == NULL) throw std::runtime_error("Error opening file : " + specPath);

        fprintf(out, "{\n\t\"graph_name\" : \"%s\",\n", variant.graph_name.c_str());

        // First pass dumps the inputs, second pass dumps the outputs
        for (bool isInput : {true, false}) {
            const uint32_t tensorCount =
                    isInput ? info->numInputTensors : info->numOutputTensors;
            Qnn_Tensor_t*        tensors = isInput ? info->inputTensors : info->outputTensors;
            QnnUtils::TensorMap& specs   = isInput ? variant.input_specs : variant.output_specs;

            fprintf(out, isInput ? "\t\"inputs\" : [\n" : "\t\"outputs\" : [\n");

            for (uint32_t idx = 0; idx < tensorCount; ++idx) {
                std::string tname = QnnApi::getTensorName(tensors[idx]);

                // Unpack the spec positionally; the leading field of each aggregate is not printed
                auto& [ignoredField, dims, quant_params, dtype] = specs.at(tname);
                auto& [ignoredDim, h, w, c, bw]                 = dims;

                std::string scaleText;
                std::string offsetText;
                QnnUtils::getQuantParamString(quant_params, scaleText, offsetText);
                // clang-format off
                fprintf(out, tensorLineFmt, tname.c_str(), h, w, c, bw, dtype.str(), scaleText.c_str(), offsetText.c_str());
                // clang-format on
            }

            // Step back over the ",\n" left by the last tensor line, then close the array
            fseek(out, -2, SEEK_CUR);
            fprintf(out, "\n\t],\n");
        }

        // Step back over the ",\n" that follows the last array, then close the object
        fseek(out, -2, SEEK_CUR);
        fprintf(out, "\n}");
        fclose(out);
    }
}

// Restores the KV cache (and, for token inputs, the token history) from the
// cache file at `load_path`.
//
// The file is expected to start with a CacheFileSpec header whose magic is
// 0xC0DE and whose dtype matches the model's KV datatype. The key caches of all
// graphs are read first, then the value caches; the quantization scales stored
// after them are skipped, and finally the token history is read when the model
// consumes tokens.
//
// `chooseHigherVariant` selects the largest graph variant for the KV dispatcher
// instead of the smallest one.
//
// Returns the number of valid cache entries recorded in the file
// (spec.update_size), or 0 on any failure: KV cache disabled, file cannot be
// opened, header or token history cannot be read completely, bad magic, dtype
// mismatch, unsupported KV dtype, or unknown input type.
size_t QnnNspModel::loadKVCache(const std::string& load_path, bool chooseHigherVariant) {

    if (m_disableKvCache) {
        __ERROR("KV cache is disabled, loading KV cache is not allowed");
        return 0;
    }

    std::ifstream fs(load_path, std::ios::in | std::ios::binary);
    if (fs.fail()) {
        // TODO: replace with proper error handling
        __ERROR("qnn-htp: load-kv errror reading file {}", load_path);
        return 0;
    }

    CacheFileSpec spec;
    fs.read((char*)&spec, sizeof(spec));
    if (fs.fail()) {
        // A short read leaves the header only partially filled; do not interpret it
        __ERROR("qnn-htp: load-kv could not read cache header from {}", load_path);
        return 0;
    }
    if (spec.magic != 0xC0DE) {
        __ERROR("qnn-htp: load-kv expected 0xC0DE found {:#x}", spec.magic);
        return 0;
    }

    bool dtype_check = true;
    // clang-format off
    switch (d_kv) {
    case QNN_DATATYPE_UFIXED_POINT_8: dtype_check = spec.dtype == CacheFileSpec::UINT8_T; break;
    case QNN_DATATYPE_UFIXED_POINT_16: dtype_check = spec.dtype == CacheFileSpec::UINT16_T; break;
    case QNN_DATATYPE_FLOAT_16: dtype_check = spec.dtype ==  CacheFileSpec::FLOAT16_T; break;
    default: __ERROR("Unsupported KV$ datatype {}", d_kv.str()); return 0;
    }
    // clang-format on

    if (!dtype_check) {
        __ERROR("Model has KV$ Dtype {} but found {} in cache", d_kv.str(), int(spec.dtype));
        return 0;
    }

    // clang-format off
    __DEBUG("qnn-htp: load-kv {{ num_tensors {}, magic {}, dtype {}, n_heads {}, embed_dim {} update_size {} }}",
        spec.num_tensors, spec.magic, int(spec.dtype), spec.n_heads, spec.embed_dim, spec.update_size);
    // clang-format on

    const int32_t n_valid = static_cast<int32_t>(spec.update_size);
    int32_t variant = nsp_graph_count.begin()->first; // Set KVManager to smallest variant
    if (chooseHigherVariant) variant = nsp_graph_count.rbegin()->first; // Ideal for loading KV prefix cache
    _kv_dispatcher->setVariant(variant);

    // Lock, load KeyCache then ValueCache, unlock
    for (auto& nsp_graph : m_nsp_graphs)
        nsp_graph.waitForLock("loadKVCache", _kv_update_count, false);
    for (auto& nsp_graph : m_nsp_graphs)
        nsp_graph.kvmanager->loadCache(&fs, true, n_valid, variant, spec.n_heads);
    for (auto& nsp_graph : m_nsp_graphs)
        nsp_graph.kvmanager->loadCache(&fs, false, n_valid, variant, spec.n_heads);
    for (auto& nsp_graph : m_nsp_graphs)
        nsp_graph.releaseLock("loadKVCache");

    // Skip over the block of doubles (one per tensor) stored after the caches
    fs.seekg(spec.num_tensors * sizeof(double), std::ios::cur);

    // Loading previous runs history input only applicable in case of tokens.
    // Embeddings history maintenance is costly in terms of memory and time.
    if (m_inputType == InputType::TOKENS) {
        token_history.clear();
        token_history.resize(n_valid);
        fs.read((char*)token_history.data(), n_valid * sizeof(int32_t));
        if (fs.fail()) {
            // Truncated or unreadable file: report it instead of continuing with a bogus history
            __ERROR("qnn-htp: load-kv could not read token history from {}", load_path);
            return 0;
        }
    } else if (m_inputType == InputType::UNKNOWN) {
        __ERROR("Wrong type of input is found.");
        return 0;
    }
    fs.close();

    m_nPast = n_valid;
    return spec.update_size;
}

bool QnnNspModel::saveKVCache(const std::string& save_path) {

    if(m_disableKvCache){
        __ERROR("KV cache is disabled, saving KV cache is not allowed");
        return false;
    }

    std::ofstream fs(save_path, std::ios::out | std::ios::binary);
    if (fs.fail()) {
        __ERROR("qnn-htp: save-kv error opening file : {}", save_path);
        throw std::runtime_error("Failed to write to cache file. Please re-check path");
    }

    const uint16_t n_valid = static_cast<uint16_t>(m_nPast);

    auto dtype = CacheFileSpec::UINT8_T;
    // clang-format off
    switch (d_kv) {
    case QNN_DATATYPE_UFIXED_POINT_8: dtype = CacheFileSpec::UINT8_T; break;
    case QNN_DATATYPE_UFIXED_POINT_16: dtype = CacheFileSpec::UINT16_T; break;
    case QNN_DATATYPE_FLOAT_16: dtype = CacheFileSpec::FLOAT16_T; break;
    default: __ERROR("Unsupported KV$ datatype {}", d_kv.str()); return false;
    }
    // clang-format on

    // Pre-calculate #tensors and n_heads to guide memory allocations
    uint32_t n_tensors = 0;
    int32_t  n_heads   = 0;
    for (auto& nsp_graph : m_nsp_graphs) {
        nsp_graph.waitForLock("saveKVCache", _kv_update_count, false);
        n_tensors += nsp_graph.kvmanager->getNumKVTensors();
        n_heads = std::max(n_heads, nsp_graph.kvmanager->getMaxNHeads());
    }

    // Save the cache file metadata
    CacheFileSpec file_spec(
            n_tensors, 0xc0de, dtype, 0x0, static_cast<uint16_t>(n_heads), m_kv_dim, n_valid
    );
    fs.write((char*)&file_spec, sizeof(file_spec));

    // Dump KeyCache and ValueCache
    for (auto& nsp_graph : m_nsp_graphs)
        nsp_graph.kvmanager->dumpCache(&fs, true, n_valid, n_heads);
    for (auto& nsp_graph : m_nsp_graphs)
        nsp_graph.kvmanager->dumpCache(&fs, false, n_valid, n_heads);

    // Dump Quantization parameters - Key scales then Value scales
    for (auto& nsp_graph : m_nsp_graphs) {
        std::vector<double>& key_scales = nsp_graph.kvmanager->getKeyScales();
        fs.write((char*)key_scales.data(), key_scales.size() * sizeof(double));
    }
    for (auto& nsp_graph : m_nsp_graphs) {
        std::vector<double>& value_scales = nsp_graph.kvmanager->getValueScales();
        fs.write((char*)value_scales.data(), value_scales.size() * sizeof(double));
    }

    // Saving previous runs history input only applicable in case of tokens.
    // Embeddings history maintenance is costly in terms of memory and time.
    if(m_inputType == InputType::TOKENS)
        fs.write((char*)token_history.data(), n_valid * sizeof(int32_t));
    else if(m_inputType == InputType::UNKNOWN) {
        __ERROR("Wrong type of input is found.");
        return false;
    }

    // Release the lock
    for (auto& nsp_graph : m_nsp_graphs)
        nsp_graph.releaseLock("saveKVCache");

    fs.flush();
    fs.close();

    return true;
}

bool QnnNspModel::applyBinarySections(std::vector<std::string>& binsection_list) {
    //apply binarysection for lora config
    for (int i = 0; i < binsection_list.size(); i++) {
        __DEBUG("qnn-htp: applyBinarySections adapters {}", binsection_list.at(i));
        if (!m_qnnApi->applyBinarySection(i, binsection_list.at(i),m_use_mmap,graph_switching)) {
            __ERROR("qnn-htp: Error in applyBinarySections {}", i);
            return false;
        }
    }
    return true;
}

// Writes the LoRA strength `alpha_val` into the input tensor named
// `alpha_tensor_name`, converted to that tensor's datatype.
//
// Only the first graph variant that exposes the tensor is written.
//
// Returns true when `alpha_tensor_name` is empty (nothing to do) or when the
// value was written. Returns false when no variant has such an input, when the
// tensor has no buffer, or when the tensor dtype is unsupported.
bool QnnNspModel::applyLoraStrength(const std::string& alpha_tensor_name, const float alpha_val) {
    if (alpha_tensor_name.empty()) return true;

    for (auto& variant : m_variant_list) {
        if (!variant.input_specs.contains(alpha_tensor_name)) continue;

        auto& tspec = variant.input_specs.at(alpha_tensor_name);

        // Guard against a missing buffer before writing through it
        auto* alpha_buffer = getBuffer(tspec);
        if (alpha_buffer == nullptr) {
            __ERROR("qnn-htp: No buffer found for alpha tensor {}", alpha_tensor_name);
            return false;
        }

        auto [scale, offset] = tspec.quantParam[0];

        // clang-format off
        switch (tspec.dtype) {
        case QNN_DATATYPE_UFIXED_POINT_8: QnnUtils::quantizeTensorPtr(&alpha_val, (uint8_t*)alpha_buffer, offset, scale, 1); break;
        case QNN_DATATYPE_UFIXED_POINT_16: QnnUtils::quantizeTensorPtr(&alpha_val, (uint16_t*)alpha_buffer, offset, scale, 1); break;
        case QNN_DATATYPE_FLOAT_16: *(uint16_t *)alpha_buffer = fp16_ieee_from_fp32_value(alpha_val); break;
        default: __ERROR("Unsupported alpha tensor dtype {}", tspec.dtype.str()); return false;
        }
        // clang-format on
        __DEBUG("qnn-htp: applyAlphaTensor alpha = {}", alpha_val);
        return true; // Each lora bin section should have only one alpha tensor
    }
    return false;
}

// Activates the LoRA adapter registered under `lora_adapter_name`:
// writes its alpha tensor, applies its binary sections, then asks every graph
// variant to refresh its tensor quantization parameters.
//
// Returns false when the LoRA config type is not adapter based, when the name
// is not configured, or when the alpha tensor or the binary sections could not
// be applied; true otherwise.
bool QnnNspModel::applyLoraAdapter(const std::string& lora_adapter_name) {
    if (lora_conf != LoraConfigType::LORA_ADAPTER_WEIGHT_ENABLE) {
        __ERROR("qnn-htp: Lora config is not enable for adapters");
        return false;
    }

    if (!lora_config.contains(lora_adapter_name)) {
        __ERROR("qnn-htp: Could not find lora adapters config to apply ");
        return false;
    }

    // Entry is known to exist at this point
    auto& adapter_cfg = lora_config[lora_adapter_name];

    const bool alpha_applied =
            applyLoraStrength(adapter_cfg.alpha_tensor_name, adapter_cfg.alpha_tensor_val);
    if (!alpha_applied) {
        __ERROR("qnn-htp: Could not apply Alpha tensor ");
        return false;
    }

    const bool sections_applied = applyBinarySections(adapter_cfg.binsection_list);
    if (!sections_applied) {
        __ERROR("qnn-htp: Could not apply binary Sections ");
        return false;
    }

    // Refresh the quantization parameters of every variant of every graph
    for (auto& graph : m_nsp_graphs) {
        for (auto& [variant_key, graph_variant] : graph.variants) {
            graph_variant->refreshTensorQuantParams();
        }
    }

    return true;
}

// Copies the embedding output of the last graph into `embds` as floats.
//
// The pooled output layer is read when m_pooled_output is set (one embedding),
// otherwise the sequence output layer is read (run_info.n_processed embeddings).
// Quantized outputs are dequantized with the tensor's first scale/offset pair;
// float outputs are cast.
//
// Returns the number of float values produced (embeddings * m_embd_size).
// Throws std::runtime_error when the selected output layer does not exist.
// An unsupported output datatype is only logged; the length is still returned.
size_t QnnNspModel::getEmbeddings(std::span<float> embds) {
    qualla::Timer start;

    // Pick the output layer according to the pooling mode
    const auto output_layer =
            m_pooled_output ? LayerType::POOL_OUTPUT : LayerType::SEQ_OUTPUT;
    QnnUtils::Tensor* output_spec =
            m_nsp_graphs.back().variants[run_info.n_tokens]->getOutput(m_layerNames[output_layer]);

    if (output_spec == nullptr) {
        __ERROR("encountered null buffer");
        throw std::runtime_error("Model is not supporting per token embedding");
    }

    const auto scale  = output_spec->quantParam[0].scale;
    const auto offset = output_spec->quantParam[0].offset;

    auto      output_datatype = QnnUtils::DataType(output_spec->tensor);
    const int output_bw       = output_spec->dtype.bw();

    uint8_t* output_buffer = (uint8_t*)getBuffer(output_spec);

    const int return_size = m_pooled_output ? 1 : run_info.n_processed;

    if (!m_pooled_output) {
        // If multiple tokens embedding are returned, offset to the correct location in the buffer
        const bool left_padded = (run_info.n_tokens == m_ctx_size);
        if (left_padded) {
            // This was left-padded, tokens embedding are at [n_tokens - n_processed, n_tokens]
            output_buffer += (run_info.n_tokens - return_size) * m_embd_size * output_bw;
        } else {
            // This was right-padded, tokens embedding are at indexes [0, n_processed]
            // NOTE(review): this advances past the first n_processed - 1 rows although
            // return_size rows are read from here on — confirm the intended start offset.
            output_buffer += (run_info.n_processed - 1) * m_embd_size * output_bw;
        }
    }

    const int output_len = static_cast<int>(return_size * m_embd_size);
    __TRACE("qnn-htp: get-embds for {} tokens. scale = {}, offset = {}, Returning {}",
            run_info.n_processed,
            scale,
            offset,
            output_len);

    // Convert from the tensor's storage type to float
    switch (output_datatype) {
        case QNN_DATATYPE_UFIXED_POINT_8: {
            deQuantizeOutputs((uint8_t*)output_buffer, embds, scale, offset, output_len);
            break;
        }
        case QNN_DATATYPE_UFIXED_POINT_16: {
            deQuantizeOutputs((uint16_t*)output_buffer, embds, scale, offset, output_len);
            break;
        }
        case QNN_DATATYPE_FLOAT_16: {
            castOutputs((uint16_t*)output_buffer, embds, output_len, output_bw);
            break;
        }
        case QNN_DATATYPE_FLOAT_32: {
            castOutputs((float*)output_buffer, embds, output_len, output_bw);
            break;
        }
        default: {
            __ERROR("Unsupported output datatype");
            break;
        }
    }

    __DEBUG("qnn-htp: getEmbeddings complete : {} usec (return_size={})",
            start.elapsed_usec(),
            output_len);
    return output_len;
}

// Utility functions to convert structs from/to json for parsing/dumping
// Parses rope scaling parameters from JSON.
//
// "rope-type" is optional and defaults to RopeScalingParams::DEFAULT. For the
// llama3 and longrope types the type-specific keys are read; a missing or
// malformed key is reported as std::runtime_error carrying the JSON error and
// the offending config. Other types read nothing further.
void from_json(const json& j, RopeScalingParams& p) {
    p.rope_type = Config::optional(j, "rope-type", RopeScalingParams::DEFAULT);

    if (p.rope_type == RopeScalingParams::ROPE_LLAMA3) {
        auto& llama3 = p.llama3_params;
        try {
            j.at("factor").get_to(llama3.factor);
            j.at("low-freq-factor").get_to(llama3.low_freq_factor);
            j.at("high-freq-factor").get_to(llama3.high_freq_factor);
            j.at("original-max-position-embeddings")
                    .get_to(llama3.original_max_position_embeddings);
        } catch (const json::exception& e) {
            // clang-format off
            throw std::runtime_error(fmt::format( "Parsing error for llama3 rope scaling - {}\n"
                    "llama3 requires keys ['original-max-position-embeddings', 'factor', 'low-freq-factor', 'high-freq-factor'].\n"
                    "Found config - {}", e.what(), j.dump()));
            // clang-format on
        }
        return;
    }

    if (p.rope_type == RopeScalingParams::ROPE_LONGROPE) {
        auto& longrope = p.longrope_params;
        try {
            j.at("original-max-position-embeddings")
                    .get_to(longrope.original_max_position_embeddings);
            j.at("long-factor").get_to(longrope.long_factor);
            j.at("short-factor").get_to(longrope.short_factor);

            if (j.contains("factor")) {
                j.at("factor").get_to(longrope.factor);
            } else {
                // No explicit factor: derive it from the ratio of the position-embedding limits
                longrope.factor = j.at("max-position-embeddings").get<double>() /
                                  longrope.original_max_position_embeddings;
            }
        } catch (const json::exception& e) {
            // clang-format off
            throw std::runtime_error(fmt::format( "Parsing error for longrope scaling - {}\n"
                    "LongRope requires keys ['original-max-position-embeddings', 'factor' or 'max-position-embeddings', 'long-factor', 'short-factor'].\n"
                    "Found config - {}", e.what(), j.dump()));
            // clang-format on
        }
    }
}

// Serializes rope scaling parameters to JSON: always "rope-type", plus the
// type-specific keys for the llama3 and longrope types.
void to_json(json& j, const RopeScalingParams& p) {
    j["rope-type"] = p.rope_type;

    if (p.rope_type == RopeScalingParams::ROPE_LLAMA3) {
        const auto& llama3 = p.llama3_params;

        j["factor"]                           = llama3.factor;
        j["low-freq-factor"]                  = llama3.low_freq_factor;
        j["high-freq-factor"]                 = llama3.high_freq_factor;
        j["original-max-position-embeddings"] = llama3.original_max_position_embeddings;
        return;
    }

    if (p.rope_type == RopeScalingParams::ROPE_LONGROPE) {
        const auto& longrope = p.longrope_params;

        j["factor"]                           = longrope.factor;
        j["long-factor"]                      = longrope.long_factor;
        j["short-factor"]                     = longrope.short_factor;
        j["original-max-position-embeddings"] = longrope.original_max_position_embeddings;
    }
}

// Parses a positional encoding description from JSON.
//
// "type" is optional and defaults to PositionalEncoding::ROPE. For ROPE,
// "rope-dim" is mandatory, "rope-theta" defaults to 10000 and "rope-scaling"
// defaults to a default-constructed RopeScalingParams. Other types read
// nothing further.
void from_json(const json& j, PositionalEncoding& p) {
    p.type = Config::optional(j, "type", PositionalEncoding::ROPE);
    if (p.type != PositionalEncoding::ROPE) return;

    auto& rope        = p.rope_params;
    rope.dims         = Config::mandatory<int32_t>(j, "rope-dim");
    rope.theta        = Config::optional<int32_t>(j, "rope-theta", 10000);
    rope.rope_scaling = Config::optional<RopeScalingParams>(j, "rope-scaling", {});
}

// Serializes a positional encoding description to JSON: always "type", plus the
// rope keys ("rope-dim", "rope-theta", "rope-scaling") when the type is ROPE.
void to_json(json& j, const PositionalEncoding& p) {
    j["type"] = p.type;
    if (p.type != PositionalEncoding::ROPE) return;

    const auto& rope  = p.rope_params;
    j["rope-dim"]     = rope.dims;
    j["rope-theta"]   = rope.theta;
    j["rope-scaling"] = rope.rope_scaling;
}

} // namespace qualla
