Program Listing for File QnnModel.hpp

Return to documentation for file (share/converter/jni/QnnModel.hpp)

//==============================================================================
//
//  Copyright (c) 2019-2023 Qualcomm Technologies, Inc.
//  All Rights Reserved.
//  Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================
#ifndef QNNMODEL_H  // Include guard: prevents this header from being included twice
#define QNNMODEL_H
#pragma once  // Alone it is not sufficient when compiling ARM64X with MSVC, hence the guard above

#include <limits>
#include <map>
#include <string>
#include <vector>

#include "QnnInterface.h"
#include "QnnLog.h"
#include "QnnModelPal.hpp"
#include "QnnWrapperUtils.hpp"

namespace qnn_wrapper_api {

// Thin wrapper around a single QNN graph: owns the graph handle, tracks the
// tensors registered on it, and forwards node/tensor creation through the
// backend interface supplied at initialize() time.
class QnnModel {
 public:
  ~QnnModel() = default;

  /// @brief Binds this model to a backend/context and creates the named graph.
  ///        Must be called before any other member function.
  /// @param backendHandle     backend the graph belongs to (stored in this object)
  /// @param qnnInterface      QNN interface function table (copied into this object)
  /// @param context           context in which the graph is created
  /// @param graphName         name to create the graph under
  /// @param debug             if true, graph is created in debug mode (intermediate
  ///                          tensors remain accessible to the client)
  /// @param doNodeValidations nonzero => every addNode call is validated
  ///                          (stored in the bool member m_doNodeValidations)
  /// @param graphConfigs      optional null-terminated array of graph configs
  /// @return MODEL_NO_ERROR on success, otherwise a ModelError_t code
  ModelError_t initialize(const Qnn_BackendHandle_t& backendHandle,
                          const QNN_INTERFACE_VER_TYPE& qnnInterface,
                          const Qnn_ContextHandle_t& context,
                          const char* graphName,
                          bool debug,
                          uint8_t doNodeValidations              = 1,
                          const QnnGraph_Config_t** graphConfigs = nullptr);

  /// @brief Creates a tensor on the graph for node @p nodeName.
  /// @param saveTensor if true, the tensor is also cached so later addNode
  ///                   calls can reference it by name
  ModelError_t addTensor(const char* nodeName, Qnn_Tensor_t* tensor, bool saveTensor = true);

  /// @brief Overload of addTensor taking the tensor by value.
  ModelError_t addTensor(const char* nodeName, Qnn_Tensor_t tensor, bool saveTensor = true);

  /// @brief Looks up a previously cached tensor by name and copies it into @p tensor.
  ModelError_t getQnnTensor(const char*& nodeName, const char*& tensorName, Qnn_Tensor_t& tensor);

  /// @brief Adds an op node to the graph, wiring inputs by cached tensor name
  ///        and registering @p outputTensors on the graph.
  /// @param inputNames    array of @p numOfInputs names of tensors previously
  ///                      added via addTensor
  /// @param outputTensors array of @p numOfOutputs tensors created for this node
  ModelError_t addNode(Qnn_OpConfigVersion_t version,
                       const char* name,
                       const char* packageName,
                       const char* type,
                       Qnn_Param_t* params,
                       uint32_t numOfParams,
                       const char** inputNames,
                       uint32_t numOfInputs,
                       Qnn_Tensor_t* outputTensors,
                       uint32_t numOfOutputs);

  /// @return the underlying graph handle (nullptr until initialize() succeeds).
  Qnn_GraphHandle_t getQnnGraph() const { return m_graph; }

  /// @return the name the graph was created with.
  std::string getQnnGraphName() const { return m_graphName; }

  /// @return copies of the graph's input tensor descriptors.
  std::vector<Qnn_Tensor_t> getGraphInputTensors() const { return m_modelInputTensors; }

  /// @return copies of the graph's output tensor descriptors.
  std::vector<Qnn_Tensor_t> getGraphOutputTensors() const { return m_modelOutputTensors; }

  /// @return mapping from node name to the names of its output tensors.
  std::map<std::string, std::vector<std::string>> getOutputTensorMap() const {
    return m_modelOutputTensorMap;
  }

  /// @brief Finalizes the graph so it can be executed.
  ModelError_t finalize(Qnn_ProfileHandle_t profile = nullptr, Qnn_SignalHandle_t signal = nullptr);

  /// @brief Releases memory held by the cached tensor bookkeeping.
  ModelError_t freeCachedTensors();

 private:
  Qnn_GraphHandle_t m_graph = nullptr;
  std::string m_graphName;
  bool m_debug = false;  // flag to indicate if requested graph is to be run in debug mode
  // (i.e. all intermediate tensors will be accessible to client)
  // flag to indicate whether all addNode calls need to be validated
  bool m_doNodeValidations = true;

  std::vector<Qnn_Tensor_t> m_modelInputTensors;
  std::vector<Qnn_Tensor_t> m_modelOutputTensors;
  // keeps track of graph tensors to enable creating Qnn nodes from tensor names
  std::map<std::string, Qnn_Tensor_t> m_modelTensorsMap;
  std::map<std::string, std::vector<std::string>> m_modelOutputTensorMap;

  // Qnn Backend Interface Api — zero-initialized so a call before initialize()
  // fails on a null handle/table instead of reading indeterminate values.
  QNN_INTERFACE_VER_TYPE m_qnnInterface{};
  Qnn_BackendHandle_t m_backendHandle{};

};  // QNN_MODEL_CLASS

/// @brief Builds a GraphInfo array describing @p numModels models and returns it
///        through @p graphsInfo. Counterpart of freeGraphsInfo below.
/// @param models     array of @p numModels initialized QnnModel objects
/// @param numModels  number of entries in @p models
/// @param graphsInfo out-parameter receiving the allocated graph-info array
///                   (presumably heap-allocated by this call and released via
///                   freeGraphsInfo — implementation not visible here; confirm)
ModelError_t getGraphInfoFromModels(QnnModel* models,
                                    uint32_t numModels,
                                    GraphInfoPtr_t** graphsInfo);

/// @brief Releases a graph-info array of @p numGraphs entries previously
///        produced by getGraphInfoFromModels.
ModelError_t freeGraphsInfo(GraphInfoPtr_t** graphsInfo, uint32_t numGraphs);
}  // namespace qnn_wrapper_api
#endif