//==============================================================================
//
//  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
//  All Rights Reserved.
//  Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

#ifndef __QNN_HTP_H__
#define __QNN_HTP_H__

#include <vector>
#include <string>

#include <qualla/engine.hpp>
#include <qualla/detail/config.hpp>
#include <qualla/detail/timer.hpp>
#include <qualla/detail/onload.hpp>

#include <fmt/format.h>

#include "nsp-model.hpp"

namespace qualla {

class NspEngine : public Engine {
  protected:
    QnnNspModel::Params _params;

    std::unique_ptr<QnnNspModel> _model;

  public:
    NspEngine(Context& ctx, const qualla::json& json);
    virtual ~NspEngine();

    virtual size_t process(
            const std::vector<int32_t>& tokens,
            std::vector<float>&         logits,
            bool                        logits_all
    ) override;

    virtual size_t process(
            const std::vector<int32_t>& tokens,
            const std::vector<int32_t>& attention_map,
            std::vector<float>&         logits,
            bool                        logits_all
    ) override;

    virtual size_t process(
        std::vector<uint8_t>&       embeddings,
        const std::vector<int32_t>& attention_map,
        std::vector<float>&         logits,
        bool                        logits_all
    ) override;
    
    /** Stores a precomputed EOS embedding vector. */
    virtual bool cacheEosEmbedding(std::vector<uint8_t>& eosEmbedding) override;

    void getInputQuantParam(double& scale, int& offset) {

        auto tmp = _model->t_input_ids->quantParam[0];
        scale    = tmp.scale;
        offset   = tmp.offset;
    }

    virtual qualla::InputType getInputType() override;

    virtual size_t getEmbeddingBufferSize() override;

    virtual bool   updateKV(size_t n_past) override;
    virtual bool   updateKV(size_t n_past, const std::vector<bool>& selected) override;
    virtual bool   save(const std::string& name) override;
    virtual size_t restore(const std::string& name, bool chooseHigherVariant) override;
    virtual void   reset() override;

    virtual bool         set(qualla::json data) override;
    virtual qualla::json get() override;

    virtual bool load() override;
    virtual bool unload() override;

    virtual bool applyLoraAdapter(std::string lora_adapter_name) override;
    virtual bool applyLoraStrength(std::string tensor_name, float tensor_val) override;
};

} // namespace qualla

#endif
