//==============================================================================
//
//  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
//  All Rights Reserved.
//  Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

#pragma once

#include <utility>

#include "qualla/env.hpp"

#include "QnnApi.hpp"
#include "IOTensor.hpp"
#include "qnn-utils.hpp"
#include "nsp-kvmanager.hpp"

namespace qualla {
/**
 * Tags used as keys of the LayerType -> name map held by GraphVariant
 * (see GraphVariant::m_layerNames).
 *
 * The meaning of each tag is not defined in this header; the group comments
 * below are inferred from the enumerator names only.
 * NOTE(review): confirm against the code that fills the layer-name map.
 *
 * Enumerator order (and therefore the implicit underlying values 0..8) is kept
 * stable; append new entries at the end.
 */
enum class LayerType {
    // Primary graph input / output
    INPUT,
    OUTPUT,

    // Attention mask
    ATTN_MASK,

    // Positional inputs (sin / cos tables, position ids)
    POS_SIN,
    POS_COS,
    POS_IDS,

    // Token type ids
    TOKEN_TYPE_IDS,

    // Pooled / per-sequence outputs
    POOL_OUTPUT,
    SEQ_OUTPUT,
};
/**
 * Description of a single QNN graph together with its input/output tensor
 * specifications.
 *
 * A GraphVariant cannot be default-constructed. It keeps a *reference* to the
 * layer-name map passed to its constructor, so that map must outlive the
 * GraphVariant, and the struct is not copy-assignable.
 */
struct GraphVariant {
    // Presumably the input length (in tokens) this graph accepts - TODO confirm
    // against the constructor. Defaulted so it is never read uninitialized.
    int32_t     n_tokens{0};
    // -1 until a context size has been assigned.
    int32_t     ctx_size{-1};
    std::string graph_name;

    // QNN API specific variables
    GraphInfo_t*        graph_info{nullptr};
    Qnn_ContextHandle_t context_handle{};

    // Tensor specifications, looked up by tensor name.
    QnnUtils::TensorMap input_specs;
    QnnUtils::TensorMap output_specs;

    // Non-owning reference to the LayerType -> name map supplied at construction.
    std::map<LayerType, std::string>& m_layerNames;

    GraphVariant() = delete;
    GraphVariant(GraphInfo_t* g_info, Qnn_ContextHandle_t qnn_ctx, int32_t n_ctx, std::map<LayerType, std::string>& layerNames);

    /**
     * Looks a tensor up by name, searching the inputs first and then the outputs.
     * @param tensor_name  name of the tensor to find
     * @return pointer to the tensor stored in input_specs/output_specs, or
     *         nullptr if neither map has an entry with that name
     */
    QnnUtils::Tensor* getTensor(const std::string& tensor_name) {
        if (QnnUtils::Tensor* input = getInput(tensor_name)) {
            return input;
        }
        return getOutput(tensor_name);
    }

    /**
     * @param tensor_name  name of the input tensor to find
     * @return pointer to the entry stored in input_specs, or nullptr if absent
     */
    QnnUtils::Tensor* getInput(const std::string& tensor_name) {
        return findTensor(input_specs, tensor_name);
    }

    /**
     * @param tensor_name  name of the output tensor to find
     * @return pointer to the entry stored in output_specs, or nullptr if absent
     */
    QnnUtils::Tensor* getOutput(const std::string& tensor_name) {
        return findTensor(output_specs, tensor_name);
    }

    bool refreshTensorQuantParams();

  private:
    size_t determineGraphInputSize();

    // Single-lookup helper shared by getInput()/getOutput(): one find() instead
    // of a contains() followed by at().
    static QnnUtils::Tensor* findTensor(QnnUtils::TensorMap& specs, const std::string& tensor_name) {
        const auto it = specs.find(tensor_name);
        return (it != specs.end()) ? &it->second : nullptr;
    }
};

/**
 * QnnNspGraph represents a group of "common" graphs.
 * For instance, BERT-mode and KV$-mode are the same graph with different input sizes;
 * one QnnNspGraph contains and manages both BERT-split-n and KV$mode-split-n.
 * I/O tensors are mostly shared between these graphs and can be managed collectively.
 *
 * Variants are stored as raw GraphVariant pointers keyed by input size (see `variants`).
 * NOTE(review): ownership of those pointers is decided by the out-of-line
 * destructor / addGraph(), which are not visible in this header - confirm there.
 */
class QnnNspGraph {
  private:
    int  _idx{-1};
    Env& _env;

    // -1 until a context size has been assigned.
    int32_t ctx_size{-1};

    // Useful pointers for graph execution (managed by NSPModel)
    QnnApi*   g_qnn_api{nullptr};
    IOTensor* g_buffer_mgr{nullptr};

    bool                     _threaded{false};
    std::mutex*              _lock{nullptr};    // Locks whenever KV$ is being used or updated
    std::condition_variable* _lock_cv{nullptr}; // Wake up _lock when jobs are complete

    KVManagerMode _kv_update_method{POINTER_SHIFT};

    // Values reported by getExecutionStats(); zero-initialized so they are
    // well-defined even if queried before being written.
    // TODO: Add more stats into a struct
    int32_t run_wait_time{0};
    int32_t run_exec_time{0};

    // Debug mode settings
    bool        _debug_specs{false};
    bool        _debug_tensors{false};
    std::string _debug_path;

  public:
    int32_t          _counter{-1};
    NewNSPKVManager* kvmanager{nullptr};

    // TODO: Remove this reference
    std::map<std::string, std::pair<int, size_t>>* tensor_alloc_info{nullptr};

    // Keys represent input_id size (1<=input_size<=ctx_size)
    // Values are graph description for that input_id size
    std::map<int32_t, GraphVariant*> variants;

    QnnNspGraph(
            int       idx,
            Env&      env,
            int32_t   n_ctx,
            QnnApi*   qnnApi,
            IOTensor* ioTensor,
            bool      threaded
    );
    ~QnnNspGraph();

    bool addGraph(GraphVariant* graph_spec);
    void printAvailableConfigs();
    void registerKVManager(NewNSPKVManager* mgr);

    // Given an input size, picks the correct model among the ones available.
    // This is likely not easy to implement as there are implications on KV$ management.
    // Placeholder: currently ignores both arguments and always returns 0.
    size_t getOptimalModelInputSize(size_t n_past, size_t input_size) { return 0; }

    // Returns the variant registered under key `idx` in `variants`.
    // Uses map::at(), so an unknown key throws std::out_of_range.
    GraphVariant* operator[](int32_t idx) { return variants.at(idx); }

    bool execute(int n_tokens, int n_inference, int32_t wait_count);

    // Returns {run_wait_time, run_exec_time}.
    const std::pair<int32_t, int32_t> getExecutionStats() const {
        return {run_wait_time, run_exec_time};
    }

    // Stores the debug settings. `debug_path` is taken by value and moved into
    // place, so passing a temporary does not incur an extra string copy.
    void setDebugMode(bool debug_specs, bool debug_tensors, std::string debug_path) {
        _debug_specs   = debug_specs;
        _debug_tensors = debug_tensors;
        _debug_path    = std::move(debug_path);
    }
    void dumpTensors(GraphVariant* const variant, bool mode, int n_inference) const;

    // Mutex functions
    void wakeUpLock();
    void waitForLock(std::string requester = "");
    void waitForLock(std::string requester, int32_t wait_counter, bool poll);
    void releaseLock(std::string requester = "");
    bool registerPointerShift(int32_t variant, int32_t ptr_offset);
};

} // namespace qualla
