//==============================================================================
//
//  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
//  All Rights Reserved.
//  Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

#ifndef __QUALLA_NSP_MODEL_H_
#define __QUALLA_NSP_MODEL_H_

#include <atomic>
#include <cstdint>
#include <filesystem>
#include <map>
#include <memory>
#include <span>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "qualla/env.hpp"
#include "qualla/detail/threadpool.hpp"

#include "QnnApi.hpp"
#include "IOTensor.hpp"

#include "nsp-kvdispatcher.hpp"
#include "qnn-utils.hpp"
#include "nsp-graph.hpp"

namespace qualla {

// Overall architecture of the loaded model; stored in one byte.
enum ModelArchitectureType : uint8_t {
    DECODER = 0,
    ENCODER = 1,
};

// LoRA configuration mode; LORA_DISABLE (0) turns LoRA off, the other two values
// select which LoRA weight-loading path is enabled.
enum LoraConfigType : uint8_t {
    LORA_DISABLE = 0,
    LORA_INPUT_WEIGHT_ENABLE = 1,
    LORA_ADAPTER_WEIGHT_ENABLE = 2,
};

// Set of QNN activation datatypes this runner supports.
// NOTE(review): presumably checked against model IO tensor types at the use sites — confirm.
//
// Declared `inline` (C++17 inline variable) rather than `static`: a `static` object in a
// header gives every including translation unit its own copy, each constructed at startup.
// `inline` yields exactly one shared definition while keeping lookups identical.
inline const std::unordered_set<Qnn_DataType_t> supported_activations = {
        QNN_DATATYPE_UFIXED_POINT_8,
        QNN_DATATYPE_UFIXED_POINT_16,
        QNN_DATATYPE_INT_32,
        QNN_DATATYPE_FLOAT_16
};

// RoPE scaling configuration. `rope_type` selects the scaling scheme (DEFAULT unless set).
// Both parameter sets below are zero/empty-initialized on construction.
struct RopeScalingParams {
    enum RopeType { DEFAULT, ROPE_LLAMA3, ROPE_LONGROPE };
    RopeType rope_type{DEFAULT};

    // Ideally the two parameter sets would share storage in a union, but the std::vector
    // members of longrope_params make its destructor / copy constructor non-trivial,
    // which a plain union cannot handle without extra boilerplate.

    // Presumably consulted only when rope_type == ROPE_LLAMA3.
    struct {
        double factor;
        double low_freq_factor;
        double high_freq_factor;
        int    original_max_position_embeddings;
    } llama3_params{};

    // Presumably consulted only when rope_type == ROPE_LONGROPE.
    struct {
        double              factor;
        std::vector<double> long_factor;
        std::vector<double> short_factor;
        int                 original_max_position_embeddings;
    } longrope_params{};

    RopeScalingParams() {}
};

// JSON <-> RopeScalingParams::RopeType mapping: "default", "llama3", "longrope".
// NOTE: per the nlohmann/json documentation for NLOHMANN_JSON_SERIALIZE_ENUM, the FIRST
// pair is used as the fallback for values/strings not in the list. An unrecognized rope
// type string therefore decodes to DEFAULT; keep that pair first.
NLOHMANN_JSON_SERIALIZE_ENUM(
        RopeScalingParams::RopeType,
        {{RopeScalingParams::DEFAULT, "default"},
         {RopeScalingParams::ROPE_LLAMA3, "llama3"},
         {RopeScalingParams::ROPE_LONGROPE, "longrope"}}
)

// Positional-encoding configuration. `type` defaults to ROPE and rope_params is
// zero-initialized (its rope_scaling member is default-constructed).
struct PositionalEncoding {
    enum EncodingType : uint8_t { ROPE = 0x0, ABSOLUTE = 0x1, ALIBI = 0x2, UNDEFINED = 0xff };
    EncodingType type;

    // RoPE settings; presumably only meaningful when type == ROPE.
    struct {
        int32_t           dims;
        double            theta;
        RopeScalingParams rope_scaling;
    } rope_params{};

    PositionalEncoding() : type(ROPE) {}
};

// JSON <-> PositionalEncoding::EncodingType mapping: "undefined", "rope", "absolute", "alibi".
// NOTE: per the nlohmann/json documentation for NLOHMANN_JSON_SERIALIZE_ENUM, the FIRST
// pair is the fallback for unmapped values/strings. An unknown encoding string therefore
// decodes to UNDEFINED rather than silently becoming ROPE; keep UNDEFINED first.
NLOHMANN_JSON_SERIALIZE_ENUM(
        PositionalEncoding::EncodingType,
        {{PositionalEncoding::UNDEFINED, "undefined"},
         {PositionalEncoding::ROPE, "rope"},
         {PositionalEncoding::ABSOLUTE, "absolute"},
         {PositionalEncoding::ALIBI, "alibi"}}
)

// nlohmann/json ADL hooks for the RoPE / positional-encoding config structs.
// Only declared here; the definitions are out of line.
void from_json(const json& j, RopeScalingParams& p);
void to_json(json& j, const RopeScalingParams& p);

void from_json(const json& j, PositionalEncoding& p);
void to_json(json& j, const PositionalEncoding& p);

class QnnNspModel {
  protected:
    Env& _env;

    // Populated by allocateTensors()
    // Maps tensor name to allocation block index and block offset
    std::map<std::string, std::pair<int, size_t>> tensor_alloc_info;
    bool float32ToFloat16(uint8_t* out,
                         float* in,
                         size_t numElements);

    int32_t input_width    = 1;
    int32_t input_channel  = 1;
    int32_t input_bitWidth = 4;

    int32_t embedding_length = -1;
    std::string embedding_datatype{"float32"};

    // Maps layers to their tensor names.
    std::map<LayerType, std::string> m_layerNames {
      {LayerType::INPUT, "input_ids"},
      {LayerType::OUTPUT, "logits"},
      {LayerType::TOKEN_TYPE_IDS, "token_type_ids"},
      {LayerType::POOL_OUTPUT,"pooled_output"},
      {LayerType::SEQ_OUTPUT,"sequence_output"},
      {LayerType::ATTN_MASK, "attention_mask"},
      {LayerType::POS_SIN, "position_ids_sin"},
      {LayerType::POS_COS, "position_ids_cos"},
      {LayerType::POS_IDS, "position_ids"}
    };

    std::vector<uint8_t> m_eosEmbedding;
  public:
    struct LoraConfig {
        std::string              lora_name;
        std::vector<std::string> binsection_list;   //loarv2 adapter bins filenames
        std::string              path;              //lorav1 weights directory.
        std::string              alpha_tensor_name; // loarv2 alpha tensor names
        float                    alpha_tensor_val;  //loarv2 alpha tensor values
    };
    struct Params {
        ModelArchitectureType      modelArchitectureType; // Model architecture
        std::filesystem::path      model_basedir;      // model basedir
        std::vector<std::string>   model_list;         // model filenames
        std::map<int32_t, int32_t> variant_latency;    // latency for different variants
        std::vector<std::string>   exec_select_graphs; // Execute selected graphs
        bool load_select_graphs; // Load only graphs mentioned in exec_select_graphs from the context bin, by default all graphs are loaded

        bool                              use_mmap;
        bool                              use_async_Init;
        uint64_t                          mmap_budget;
        int64_t                           spill_fill_bufsize;
        int32_t                           ctx_size;
        int32_t                           kv_dim;
        int32_t                           pad_token;
        size_t                            n_embd;
        uint32_t                          n_threads{0};
        uint64_t                          cpumask{0};
        bool                              poll{false};
        std::string                       backend_lib;
        std::string                       backend_ext_conf;
        std::string                       debug_path;
        bool                              debug_specs;
        bool                              debug_tensors;
        bool                              debug_outputs;
        bool                              debug_qnn;
        std::string                       kv_update_method;
        std::string                       lmhead_weight_dir;
        bool                              graph_switching;
        LoraConfigType                    lora_config_type;
        std::map<std::string, LoraConfig> lora_param;
        std::string                       input_layer_name;
        int32_t                           embedding_length;
        std::string                       embedding_datatype;
        bool                              pooled_output;
        bool                              disable_kv_cache;
        // Parameters for positional encodings
        PositionalEncoding positional_encoding_params;
    };

    const std::filesystem::path model_basedir;
    std::vector<std::string>    model_filelist;
    std::string                 lmhead_weight_dir;
    std::vector<int32_t>        token_history;
    std::map<int32_t, int32_t>  variant_latency;
    std::vector<std::string>    exec_select_graphs;
    bool                        load_select_graphs;

    InputType m_inputType{InputType::UNKNOWN};

    LoraConfigType                     lora_conf;
    std::map<std::string, LoraConfig>  lora_config;
    // QNN specific variables
    const bool                m_sharedBuffer{true};
    std::unique_ptr<QnnApi>   m_qnnApi;
    std::unique_ptr<IOTensor> m_ioTensor{nullptr};
    int64_t                   spill_fill_buffer_size;
    bool                      m_use_mmap{false};
    bool                      m_use_async_Init{true};
    uint64_t                  mmap_budget;
    bool                      graph_switching{false};
    size_t                    n_embd;


    bool m_pooled_output{true};
    bool m_disableKvCache{false};
    // Model parameters
    ModelArchitectureType m_modelArchitectureType;
    int32_t m_ctx_size{-1};
    int32_t m_vocab_size{-1};
    int32_t m_kv_dim{-1};
    int32_t m_embd_size{-1};
    int32_t m_pad_token{-1};

    size_t m_embeddingBufferSize{0};

    QnnUtils::DataType d_input{QNN_DATATYPE_INT_32}, d_kv{QNN_DATATYPE_UFIXED_POINT_8},
            d_attn_map{QNN_DATATYPE_UFIXED_POINT_16}, d_token_type{QNN_DATATYPE_INT_32};

    // int32_t attention_mask_bitwidth{2}, position_id_bitwidth{2};

    // Information regarding model execution settings and last inference
    struct RunInfo {
        int32_t n_tokens;
        size_t  n_processed;

        std::vector<int32_t> tokens;
    } run_info{-1, 0, {}};

    // Model specific variables
    uint32_t m_num_graphs;
    bool     _lora_enabled{false};
    bool     _lmhead_weight_input{false};

    // QnnNspGraph contains all GraphVariants for a specific split (with index=split_idx)
    std::vector<QnnNspGraph> m_nsp_graphs;
    // GraphVariant represents one input size within one split (e.g. KV$_split_1)
    std::vector<GraphVariant> m_variant_list;

    // For ease of usage: Map from graph name to the corresponding GraphVariant
    std::unordered_map<std::string, GraphVariant*> m_graph_map;
    // This map records how many graphs have been loaded for a particular input size
    std::map<int32_t, int32_t> nsp_graph_count;

    bool       _threaded{false};
    uint64_t   _cpumask{0};
    ThreadPool threadpool;

    KVManagerMode _kv_update_method{POINTER_SHIFT};

    int32_t                       _kv_update_count{0};
    std::unique_ptr<KVDispatcher> _kv_dispatcher;

    std::string _backend_lib;
    std::string _backend_ext_conf;

    // Store some pointers for easier access
    QnnUtils::Tensor* t_input_ids{nullptr};
    QnnUtils::Tensor* t_attn_mask{nullptr};
    QnnUtils::Tensor* t_token_type_ids{nullptr};

    // Variables for positional encodings
    PositionalEncoding m_positional_encoding;
    QnnUtils::DataType d_pos{QNN_DATATYPE_UFIXED_POINT_16};
    // PositionalEncodingType::ABSOLUTE OR PositionalEncodingType::ALIBI
    QnnUtils::Tensor* t_position_ids{nullptr};
    // PositionalEncodingType::ROPE variables
    int32_t m_pos_dim{-1};       // Dimension of positional embedding tensor (incl partial_factor)
    void*   rope_sin{nullptr};   // Pre-calculated RoPE sin table of size [ctx_size, m_pos_dim]
    void*   rope_cos{nullptr};   // Pre-calculated RoPE cos table of size [ctx_size, m_pos_dim]

    QnnUtils::Tensor* t_position_ids_sin{nullptr};
    QnnUtils::Tensor* t_position_ids_cos{nullptr};

    // n_past defines number of population of kvcache
    size_t m_nPast{0};

    // Self-Specualtive Decoding
    // This prefix is not for input tokens, but just for speical tokens
    // Only the special tokens from the offset should attend the kv prefix
    int32_t _size_to_skip_kv_prefix{0};
    int32_t _offset_to_apply_kv_prefix{0};

    // Keep track of inference count
    int m_inference_count = 0;

    // Debug mode settings
    bool        _debug_specs{false};
    bool        _debug_tensors{false};
    bool        _debug_outputs{false};
    bool        _debug_qnn{false};
    std::string _debug_path;

    QnnNspModel(Env& env, const Params& params);

    ~QnnNspModel();

    bool initializeModel(void);
    bool validateModel(void);
    bool initializeIOTensors(void);
    bool initializeTensorPointers();
    bool initializeKVManager();
    bool calculate_rope_embeddings(void);
    bool load_lmhead_weight_as_input(void);
    bool flushLoraWeightsBuffers(void);

    template <typename DType>
    bool setupAttentionMask(
            bool                     pad_left,
            int                      n_tokens,
            int                      n_inputs,
            int                      n_past,
            std::span<const int32_t> attention_map,
            size_t                   n_skip_prefix,
            size_t                   n_apply_prefix_offset
    );

    bool setupAttentionMaskFP16(
	   bool                      pad_left,
	   int                       n_tokens,
	   int                       n_inputs,
  	   int                       n_past,
	   std::span<const int32_t>  attention_map,
	   size_t                    n_skip_prefix,
	   size_t                    n_apply_prefix_offset);

    bool setupRopePositionEmbeddingFP16(
            bool                     pad_left,
            int                      n_tokens,
            int                      n_inputs,
            int                      n_past,
            std::span<const int32_t> attention_map,
            size_t                   n_skip_prefix,
            size_t                   n_apply_prefix_offset
    );

    template <typename DType>
    bool setupRopePositionEmbedding(
            bool                     pad_left,
            int                      n_tokens,
            int                      n_inputs,
            int                      n_past,
            std::span<const int32_t> attention_map,
            size_t                   n_skip_prefix,
            size_t                   n_apply_prefix_offset
    );

    template <typename DType>
    bool setupAlibiPositionEmbedding(
            bool pad_left,
            int n_tokens,
            int n_inputs,
            int n_past
    );

    bool setupInputTensors(
            std::span<int32_t>       tokens,
            int32_t                  n_past,
            std::span<const int32_t> attention_map,
            size_t                   n_skip_prefix,
            size_t                   n_apply_prefix_offset
    );

    bool setupInputTensors(
            std::span<uint8_t>       embedding,
            int32_t                  n_past,
            std::span<const int32_t> attention_map,
            size_t                   n_skip_prefix,
            size_t                   n_apply_prefix_offset
    );

    bool quantizeInput(float* in, size_t tensorOffset, size_t length);

    size_t getEmbeddingBufferSize();

    size_t runInference(
            const std::vector<int32_t>& tokens,
            const std::vector<int32_t>& attention_map,
            std::vector<float>&         output,
            bool                        output_all = false
    );

    size_t runInference(
        std::vector<uint8_t>&       embeddings,
        const std::vector<int32_t>& attention_map,
        std::vector<float>&         output,
        bool                        output_all = false
    );
    
    bool cacheEosEmbedding(std::vector<uint8_t>& eosEmbedding);

    bool setKVCacheNPast(size_t n_past, const std::vector<bool>& selected);

    size_t getEmbeddings(std::span<float> embds);

    size_t getDequantLogits(std::span<float> logits, bool logits_all = false);

    bool debugOutputs(QnnUtils::Tensor* outTensor, std::string& outTensorName);

    size_t loadKVCache(const std::string& load_path, bool chooseHigherVariant=false);
    bool   saveKVCache(const std::string& save_path);
    bool   applyLoraStrength(const std::string& alpha_tensor_name, const float alpha_val);
    bool   applyLoraAdapter(const std::string& lora_adapter_name);
    bool   applyBinarySections(std::vector<std::string>& binsection_list);
    bool   applyLoraWeights(const std::string& lora_weights_name);

  protected:
    // Internal functions to separate different runInference logic
    int32_t selectVariantStrategy(int32_t n_inputs, int32_t n_past, int32_t cur_variant);
    bool    runInferenceHelper(bool pipeline, int32_t* total_wait, int32_t* total_exec);

    inline bool  updateTensorPointer(GraphVariant& variant, std::string& key, QnnUtils::Tensor*& t);
    inline void* getBuffer(QnnUtils::Tensor& spec) { return m_ioTensor->getBuffer(spec.tensor); }
    inline void* getBuffer(QnnUtils::Tensor* spec) { return m_ioTensor->getBuffer(spec->tensor); }
    inline size_t getBufferSize(QnnUtils::Tensor& spec) { return spec.dims.getSize(); }
    inline size_t getBufferSize(QnnUtils::Tensor* spec) { return spec->dims.getSize(); }

    void dumpTensorSpecs();
};

} // namespace qualla

#endif
