//==============================================================================
//
//  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
//  All Rights Reserved.
//  Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

#pragma once

#include "BackendExtensions.hpp"
#include "QnnConfig.hpp"
#include "QnnHtpPerfInfrastructure.h"
#include "QnnHtpDevice.h"
#include "qnn-utils.hpp"
#include "IOTensor.hpp"

#include <cstdarg>
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Compile-time IO-tensor debug switch (currently off).
// NOTE(review): presumably consumed via `#if QNN_IO_TENSOR_DEBUG` elsewhere — confirm.
#define QNN_IO_TENSOR_DEBUG 0

// How the KV cache is updated; stored in QnnApi::m_kvUpdateMethod via setKVUpdateMethod().
// Kept as an unscoped enum because callers name the enumerators unqualified.
enum KVManagerMode {
    POINTER_SHIFT = 0x0,
    SHIFT_CONCAT  = 0x1,
};

// Makes QuantParam usable unqualified in this header (e.g. getTensorQuantParams()).
// Note: a namespace-scope using-declaration in a header also exposes the name to every includer.
using qualla::QnnUtils::QuantParam;

// QNN SDK API version packed into one integer: MAJOR*10000 + MINOR*100 + PATCH
// (e.g. 2.17.0 -> 21700). Must remain a macro: it is tested in `#if` directives.
#define QUALLA_QNN_API_VERSION \
    ((QNN_API_VERSION_MAJOR) * 10000 + (QNN_API_VERSION_MINOR) * 100 + (QNN_API_VERSION_PATCH))

// Element width, in bytes, of each QNN tensor data type that this code knows how to size.
// Types without an entry here are not listed on purpose; QnnApi::getDataTypeSize reads this table.
// NOTE(review): `static` at namespace scope in a header gives every translation unit that
// includes it its own private copy of this table.
static std::map<Qnn_DataType_t, size_t> g_qnnDataTypeToSize{
        // 1-byte element types
        {QNN_DATATYPE_INT_8, 1},
        {QNN_DATATYPE_UINT_8, 1},
        {QNN_DATATYPE_SFIXED_POINT_8, 1},
        {QNN_DATATYPE_UFIXED_POINT_8, 1},
        {QNN_DATATYPE_BOOL_8, 1},
        // 2-byte element types
        {QNN_DATATYPE_INT_16, 2},
        {QNN_DATATYPE_UINT_16, 2},
        {QNN_DATATYPE_FLOAT_16, 2},
        {QNN_DATATYPE_SFIXED_POINT_16, 2},
        {QNN_DATATYPE_UFIXED_POINT_16, 2},
        // 4-byte element types
        {QNN_DATATYPE_INT_32, 4},
        {QNN_DATATYPE_UINT_32, 4},
        {QNN_DATATYPE_FLOAT_32, 4},
        {QNN_DATATYPE_SFIXED_POINT_32, 4},
        {QNN_DATATYPE_UFIXED_POINT_32, 4},
        // 8-byte element types
        {QNN_DATATYPE_INT_64, 8},
        {QNN_DATATYPE_UINT_64, 8},
};

/**
 * @brief Wrapper around a QNN backend and the models / context binaries loaded into it.
 *
 * Holds the QNN interface tables, the backend/device/log/profile handles, the created
 * contexts, per-graph bookkeeping (names, indices, owning contexts), IO-tensor allocation
 * info and HTP performance-infrastructure state.
 *
 * Most member functions are defined out of line; only the small accessors are implemented
 * in this header. A static getInstance() accessor is declared.
 * NOTE(review): presumably used as a process-wide singleton — confirm in the implementation.
 *
 * The object holds raw handles and a std::mutex. The mutex already makes it implicitly
 * non-copyable and non-movable; the copy operations are deleted explicitly to document that.
 */
class QnnApi {
  private:
    const uint32_t s_graphConfigsReserveCount = 16;

    // Model vars
    typedef Qnn_ErrorHandle_t (*QnnInterfaceGetProvidersFn_t)(
            const QnnInterface_t*** providerList,
            uint32_t*               numProviders
    );
    typedef Qnn_ErrorHandle_t (*QnnSystemInterfaceGetProvidersFn_t)(
            const QnnSystemInterface_t*** providerList,
            uint32_t*                     numProviders
    );

    // Graph Related Function Handle Types
    typedef ModelError_t (*ComposeGraphsFnHandleType_t)(
            Qnn_BackendHandle_t,
            QNN_INTERFACE_VER_TYPE,
            Qnn_ContextHandle_t,
            const GraphConfigInfo_t**,
            const uint32_t,
            GraphInfo_t***,
            uint32_t*,
            bool,
            QnnLog_Callback_t,
            QnnLog_Level_t
    );

    typedef ModelError_t (*GenAIComposeGraphsFnHandleType_t)(
            Qnn_BackendHandle_t,
            QNN_INTERFACE_VER_TYPE,
            Qnn_ContextHandle_t,
            const GraphConfigInfo_t**,
            const uint32_t,
            uint32_t*    inputDim,
            uint32_t     inputRank,
            uint32_t*    outputDim,
            uint32_t     outputRank,
            uint32_t*    kvDim,
            uint32_t     kvRank,
            Qnn_Param_t* params,
            uint32_t     numParam,
            GraphInfo_t***,
            uint32_t*,
            bool,
            QnnLog_Callback_t,
            QnnLog_Level_t
    );

    typedef ModelError_t (*FreeGraphInfoFnHandleType_t)(GraphInfo_t***, uint32_t);

    void* m_libModelHandle{nullptr};
    void* m_backendHandle{nullptr};
    void* m_backendLibraryHandle{nullptr};

    QNN_INTERFACE_VER_TYPE             m_qnnInterface{nullptr};
    QNN_SYSTEM_INTERFACE_VER_TYPE      m_qnnSystemInterface{nullptr};
    std::unique_ptr<BackendExtensions> m_backendExtensions{nullptr};
    ComposeGraphsFnHandleType_t        m_composeGraphsFnHandle{nullptr};
    GenAIComposeGraphsFnHandleType_t   m_genaiComposeGraphsFnHandle{nullptr};
    FreeGraphInfoFnHandleType_t        m_freeGraphInfoFnHandle{nullptr};
    uint32_t                           m_backendId{0};
    Qnn_LogHandle_t                    m_logHandle{nullptr};
    Qnn_DeviceHandle_t                 m_deviceHandle{nullptr};

    Qnn_ProfileHandle_t m_profileBackendHandle{nullptr};

    std::vector<Qnn_ContextHandle_t>                    m_contextVec;
    std::unordered_map<GraphInfo*, Qnn_ContextHandle_t> m_contextMap;
    uint32_t                                            m_graphsCount{0};
    int32_t                                             graphCountPerContext{-1};
    // Was left uninitialized by the constructor; start as "no graphs".
    GraphInfo_t**                                       m_graphsInfo{nullptr};
    std::unordered_map<std::string, uint32_t>           m_graphNameToIndex;
    std::unordered_map<std::string, GraphInfo*>         m_graphNameToInfo;
    std::unordered_map<std::string, uint32_t>           m_graphNameToContextIdx;
    std::unordered_map<uint32_t, Qnn_ContextHandle_t>   m_contextIdtoHandle;
    // Serializes the writes performed by updateContext() and
    // updateQnnApiGraphsandContextsInfo().
    std::mutex                                          m_updateCallBackMutex;

    // Useful structures for IO estimation
    std::unordered_map<int, qualla::QnnUtils::TensorMap> m_graphtoIOMap; // stores {GraphId -> IOTensorMap}
    using CtxBitVector = int;
    std::map<CtxBitVector, std::map<std::string, size_t>> m_contextAllocMap; // stores {Translated ContextId -> {Tensor name, size}}
    std::map<std::string, std::pair<int, size_t>> m_tensorAllocInfo; // stores {Tensor name -> (fd of RPC buffer, offset)}
    std::unordered_map<uint32_t, uint32_t> m_graphIdxToContextIdx; // stores {Graph Idx -> Context Idx}
    std::unordered_map<std::string, std::shared_ptr<uint8_t>> m_adapterNameToBuffer;

    uint32_t              m_backendConfigCount{0};
    QnnBackend_Config_t** m_backendConfigs{nullptr};

    QnnHtpDevice_PerfInfrastructure_t* m_perfInfra{nullptr};
    uint32_t                           m_powerConfigId = 1;

    // Useful structures for IO estimation
    IOTensor*             m_ioBufferMgr{nullptr};
    int32_t               m_ctxSize{-1};
    int32_t               m_kvDim{-1};
    bool                  m_loraWeightEnabled{false};
    bool                  m_lmHeadWeightInput{false};
    KVManagerMode         m_kvUpdateMethod{POINTER_SHIFT};

    bool m_isLogInitialized{false};
    bool m_isBackendInitialized{false};
    bool m_isContextCreated{false};

    // Variable to keep track of debug mode (was left uninitialized by the constructor)
    bool m_DebugModeRequested{false};
    bool m_debugQnn{false};

    // Whether to mmap context bins or read them into memory (was left uninitialized)
    bool m_mmapContextBins{false};
    bool m_isDeviceCreated = false;

    std::vector<std::pair<uint8_t*, uint64_t>> m_contextBinBuffersToBeCleared;

    /// Records whether a QNN device has been created (read back by getDeviceStatus()).
    void setDeviceStatus(bool status) { m_isDeviceCreated = status; }
    /// @return the flag last stored by setDeviceStatus(); false until then.
    bool getDeviceStatus() { return m_isDeviceCreated; }

    // ---- Out-of-line helpers (definitions not visible in this header) ----
    bool getContextConfigs(
            QnnContext_Config_t***          configs,
            uint32_t&                       contextConfigCount,
            Qnn_Priority_t                  contextPriority,
            bool                            graphSwitching   = false,
            const std::vector<std::string>& execSelectGraphs = {},
            bool                            loadSelectGraphs = false
    );
    bool mergeAllContextConfigs(
            QnnContext_Config_t*** allCustomContextConfigs,
            QnnContext_Config_t**  customConfigs,
            QnnContext_Config_t**  contextConfigs,
            uint32_t               customConfigCount,
            uint32_t               contextConfigCount
    );
    bool freeContextConfigs(QnnContext_Config_t** contextConfigs, uint32_t contextConfigCount);
    bool setGraphConfigsBeforeExecute(
            Qnn_GraphHandle_t   graphHandle,
            QnnGraph_Config_t** graphConfigs,
            uint32_t            configCount
    );

    bool getQnnInterface(std::string backendPath);
    bool getQnnSystemInterface(std::string systemLibraryPath);
    bool loadModel(std::string model_path);
    bool initializeLogging(const QnnLog_Level_t& logLevel, bool debug_qnn);
    void terminateLog();
    bool initializeBackendExtensions(
            BackendExtensionsConfigs backendExtensionsConfig,
            PerfProfile              parsedPerfProfile,
            bool                     debug_qnn
    );
    bool initializeBackend();
    bool terminateBackend();
    bool createDevice();
    bool freeDevice();
    bool createContext(ContextConfigs contextConfig);
    bool freeContext();
    bool composeGraphs(std::vector<GraphConfigs> graphConfigs);
    bool composeGraphs(
            std::vector<GraphConfigs> graphConfigs,
            uint32_t*                 inputDim,
            uint32_t                  inputRank,
            uint32_t*                 outputDim,
            uint32_t                  outputRank,
            uint32_t*                 kvDim,
            uint32_t                  kvRank,
            Qnn_Param_t*              params,
            uint32_t                  numParams
    );
    bool mapAndGetContextBinaryInfo(
            const bool                            use_mmap,
            std::shared_ptr<uint8_t>&             buffer,
            const std::string                     binaryPath,
            const uint64_t                        bufferSize,
            const size_t                          contextIdx,
            const bool                            graphSwitching,
            QnnSystemContext_Handle_t             sysCtxHandle,
            const QnnSystemContext_BinaryInfo_t** binaryInfo
    );

    bool parseIOTensorsAndAccumulate();
    bool registerTensorsWithBackend(uint32_t& graphIdx);

    bool finalizeGraphs();
    bool initializePerformance();
    bool destroyPerformance();
    bool boostPerformance();
    bool resetPerformance();
    bool checkCapabilityOfCreateAsync(bool& propRet);

    bool initProfiling();
    bool extractBackendProfilingInfo(
            Qnn_ProfileHandle_t                                 profileHandle,
            std::map<std::string, std::pair<double, uint16_t>>& timeLogs,
            std::string                                         graphName
    );
    bool extractProfilingSubEvents(
            QnnProfile_EventId_t                                profileEventId,
            std::map<std::string, std::pair<double, uint16_t>>& timeLogs,
            std::string                                         graphName
    );
    bool extractProfilingEvent(
            QnnProfile_EventId_t                                profileEventId,
            std::map<std::string, std::pair<double, uint16_t>>& timeLogs,
            std::string                                         graphName
    );
    bool extractBackendProfilingInfo(Qnn_ProfileHandle_t profileHandle);
    bool extractProfilingSubEvents(QnnProfile_EventId_t profileEventId);
    bool extractProfilingEvent(QnnProfile_EventId_t profileEventId);

    /**
     * @brief Looks up the context handle registered for @p contextId by updateContext().
     * @return The handle, or a value-initialized (null) handle if none was registered.
     *
     * Uses find() instead of operator[] so that a miss does not insert a null entry
     * into m_contextIdtoHandle.
     * NOTE(review): this read does not take m_updateCallBackMutex; it presumably runs only
     * after all updateContext() calls have finished — confirm against the callers.
     */
    Qnn_ContextHandle_t getContextWithId(uint32_t contextId) {
        const auto it = m_contextIdtoHandle.find(contextId);
        return it != m_contextIdtoHandle.end() ? it->second : Qnn_ContextHandle_t{};
    }

  public:
    QnnApi() = default;
    ~QnnApi();

    // Holds raw handles and a std::mutex: copying (and therefore moving) is not supported.
    QnnApi(const QnnApi&)            = delete;
    QnnApi& operator=(const QnnApi&) = delete;

    bool           freeGraphs();
    static QnnApi& getInstance();
#if QUALLA_QNN_API_VERSION >= 21700
    static void contextNotifyFn(
            Qnn_ContextHandle_t                          context,
            Qnn_GraphHandle_t                            graph,
            const char*                                  graph_name,
            QnnContext_createFromBinaryAsyncNotifyType_t completeType,
            void*                                        notifyParam,
            Qnn_ErrorHandle_t                            status
    );
#endif
    bool createFromBinary(
            std::vector<std::string>        cachedBinariesPathVec,
            ContextConfigs                  contextConfig,
            int64_t                         spill_fill_buffer_size = 0,
            uint64_t                        mmap_budget            = 0,
            bool                            graphSwitching         = false,
            const std::vector<std::string>& execSelectGraphs       = {},
            bool                            loadSelectGraphs       = false
    );
#if QUALLA_QNN_API_VERSION >= 21700
    bool createFromBinaryListAsync(
            std::vector<std::string>        cachedBinariesPathVec,
            ContextConfigs                  contextConfig,
            int64_t                         spill_fill_buffer_size = 0,
            uint64_t                        mmap_budget            = 0,
            bool                            graphSwitching         = false,
            const std::vector<std::string>& execSelectGraphs       = {},
            bool                            loadSelectGraphs       = false
    );
#endif
    bool initialize(
            std::string                     backendPath,
            std::vector<std::string>        modelPathOrCachedBinaryPathVec,
            BackendExtensionsConfigs        backendExtensionsConfig,
            PerfProfile                     parsedPerfProfile      = PerfProfile::BURST,
            ContextConfigs                  contextConfig          = ContextConfigs(),
            std::vector<GraphConfigs>       graphConfigs           = {},
            bool                            loadFromCachedBinary   = false,
            std::string                     systemLibraryPath      = "",
            bool                            debugModeRequested     = false,
            int64_t                         spill_fill_buffer_size = 0,
            bool                            mmapContextBins        = false,
            bool                            asyncInit              = true,
            uint64_t                        mmap_budget            = 0,
            bool                            debug_qnn              = false,
            bool                            graphSwitching         = false,
            const std::vector<std::string>& execSelectGraphs       = {},
            bool                            loadSelectGraphs       = false
    );

    bool registerOpPackage(std::string opPackagePath);

    /// Stores @p ioBufferMgr as-is.
    /// NOTE(review): presumably non-owning — confirm the destructor does not free it.
    void setIOTensorBufferMgr(IOTensor* ioBufferMgr) { m_ioBufferMgr = ioBufferMgr; }

    /// Stores the KV dimension in m_kvDim (-1 until set).
    void setKVDim(int32_t kvDim) { m_kvDim = kvDim; }

    /// Stores the context size in m_ctxSize (-1 until set).
    void setContextSize(int32_t ctxSize) { m_ctxSize = ctxSize; }

    /// Selects the KV update strategy (POINTER_SHIFT until set).
    void setKVUpdateMethod(KVManagerMode kvUpdateMethod) { m_kvUpdateMethod = kvUpdateMethod; }

    /// @return Pointer to the internal {tensor name -> (RPC buffer fd, offset)} map;
    ///         valid for the lifetime of this object.
    std::map<std::string, std::pair<int, size_t>>* getTensorAllocInfo() {
        return &m_tensorAllocInfo;
    }

    /// @return Current value of m_lmHeadWeightInput (false until set out of line).
    bool getLmHeadWeightInputEnabled() { return m_lmHeadWeightInput; }

    /// @return Current value of m_loraWeightEnabled (false until set out of line).
    bool getLoraWeightEnabled() { return m_loraWeightEnabled; }

    // Initialize with an op package
    bool initialize(
            std::string               backendPath,
            std::string               modelPath,
            std::string               opPackage,
            ContextConfigs            contextConfig,
            std::vector<GraphConfigs> graphConfigs,
            uint32_t*                 inputDim,
            uint32_t                  inputRank,
            uint32_t*                 outputDim,
            uint32_t                  outputRank,
            uint32_t*                 kvDim,
            uint32_t                  kvRank,
            Qnn_Param_t*              params,
            uint32_t                  numParams,
            bool                      debugModeRequested
    );

    bool graphExecute(
            Qnn_Tensor_t*                                       input,
            Qnn_Tensor_t*                                       output,
            std::string                                         graphName,
            std::map<std::string, std::pair<double, uint16_t>>& timeLogs
    );

    bool applyBinarySection(uint32_t binIndex, std::string binSectionPath, bool useMmap, bool graphSwitch);

    bool applyBinarySection(uint32_t graphId, std::string binSectionPath);

    QNN_INTERFACE_VER_TYPE*           getQnnInterfaceVer() { return &m_qnnInterface; }
    GraphInfo_t**&                    getGraphsInfo() { return m_graphsInfo; }
    uint32_t                          getGraphsCount() { return m_graphsCount; }
    int32_t                           getGraphCountPerContext() { return graphCountPerContext; }
    std::vector<Qnn_ContextHandle_t>& getContexts() { return m_contextVec; }

    /// @return The context recorded for @p graph in m_contextMap.
    /// @throws std::out_of_range if @p graph has no recorded context (unordered_map::at).
    Qnn_ContextHandle_t getContexts(GraphInfo_t* const graph) { return m_contextMap.at(graph); }

    /**
     * @brief Registers a newly created context under @p contextId.
     *
     * Appends @p context to m_contextVec and maps @p contextId to it, holding
     * m_updateCallBackMutex for the whole update.
     */
    void updateContext(Qnn_ContextHandle_t context, uint32_t contextId) {
        std::lock_guard<std::mutex> lock(m_updateCallBackMutex);
        m_contextVec.push_back(context);
        m_contextIdtoHandle[contextId] = context;
    }

    /**
     * @brief Binds a graph handle to its GraphInfo and records which context owns it.
     *
     * Under m_updateCallBackMutex: stores @p graph into the GraphInfo registered for
     * @p graphName, maps @p graphName to @p contextId, and increments m_graphsCount.
     * If @p graphName has no (non-null) GraphInfo entry, nothing is changed. Previously,
     * operator[] would insert and then dereference a null pointer in that case.
     */
    void updateQnnApiGraphsandContextsInfo(
            std::string       graphName,
            Qnn_GraphHandle_t graph,
            uint32_t          contextId
    ) {
        std::lock_guard<std::mutex> lock(m_updateCallBackMutex);
        const auto infoIt = m_graphNameToInfo.find(graphName);
        if (infoIt == m_graphNameToInfo.end() || infoIt->second == nullptr) {
            // Unknown graph: leave the bookkeeping untouched so m_graphsCount
            // only counts graphs that were actually bound.
            return;
        }
        infoIt->second->graph              = graph;
        m_graphNameToContextIdx[graphName] = contextId;
        m_graphsCount++;
    }

    /**
     * @brief Byte width of one element of @p datatype, from g_qnnDataTypeToSize.
     * @return The width, or 0 for types not listed in the table.
     *
     * Uses find() so an unknown type does not insert into the shared global table
     * (operator[] would mutate it, and this static function takes no lock).
     */
    static inline size_t getDataTypeSize(const Qnn_DataType_t& datatype) {
        const auto it = g_qnnDataTypeToSize.find(datatype);
        return it != g_qnnDataTypeToSize.end() ? it->second : size_t{0};
    }
    /// @return The tensor's name, as given by GET_TENSOR_WRAPPER_NAME.
    static inline std::string getTensorName(const TensorWrapper& tensorWrapper) {
        return GET_TENSOR_WRAPPER_NAME(tensorWrapper);
    }
    static bool getTensorQuantParams(
            const Qnn_Tensor_t*      tensor,
            std::vector<QuantParam>& quantParamsVec
    );
    static bool getTensorShape(std::vector<size_t>& tensorDims, const TensorWrapper& tensorWrapper);
    /// @return The tensor's data type, as given by QNN_TENSOR_GET_DATA_TYPE.
    static inline Qnn_DataType_t getTensorDtype(const Qnn_Tensor_t* tensor) {
        return QNN_TENSOR_GET_DATA_TYPE(tensor);
    }

    bool getTensorNameAndShape(
            std::string&         tensorName,
            std::vector<size_t>& tensorDims,
            TensorWrapper&       tensorWrapper
    );
    static void qnnLogCallback(
            const char*    fmt,
            QnnLog_Level_t level,
            uint64_t       timestamp,
            va_list        args
    );
    bool updateIOEncodings(
            std::shared_ptr<uint8_t>& buffer,
            uint64_t                  bufferSize,
            uint32_t                  graphIndex
    );

    bool createFromBinary(std::vector<std::string> cachedBinariesPathVec);

    bool initialize(
            std::string              backendPath,
            std::vector<std::string> modelPathOrCachedBinaryPath
    );
};
