Qualcomm AI Engine Direct - support multi-context, dlbc (pytorch#2450)
Summary:
bypass-github-export-checks

- refactor the compiler spec so it can serve more backends (see the sketch below)
- support the HTP multi-context and DLBC (deep learning bandwidth compression) features
- add test cases for multi-context
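
At its core, the first bullet replaces the per-field copies QnnManager made from the compiler spec with a single non-owning pointer to the generated QnnExecuTorchOptions, whose lifetime is owned by the compiler specs kept in the ExecuTorch runtime. A minimal sketch of that pattern, using illustrative stand-ins (Options and Manager here are not the actual ExecuTorch types):

#include <string>

// Stand-in for the flatbuffer-generated QnnExecuTorchOptions.
struct Options {
  int log_level;
  std::string library_path;
};

class Manager {
 public:
  // `options` is borrowed, not owned: the compiler specs holding it
  // must outlive this Manager, or every accessor below dangles.
  explicit Manager(const Options* options) : options_(options) {}

  // Fields are read through the pointer on demand instead of being
  // copied into one member per field at construction time.
  int log_level() const { return options_->log_level; }
  const std::string& library_path() const { return options_->library_path; }

 private:
  const Options* options_;  // non-owning
};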

Pull Request resolved: pytorch#2450

Reviewed By: kirklandsign

Differential Revision: D54932332

Pulled By: cccclai

fbshipit-source-id: 5570897c1ab3833fc1c0b45d0331469aece4cf31
haowhsu-quic authored and facebook-github-bot committed Mar 17, 2024
1 parent 246ed45 commit 84cd2bb
Showing 23 changed files with 697 additions and 214 deletions.
1 change: 1 addition & 0 deletions backends/qualcomm/partition/common_defs.py
@@ -10,6 +10,7 @@

 not_supported_operator = [
     exir_ops.edge.aten.arange.start_step,
+    exir_ops.edge.aten.clone.default,
     exir_ops.edge.aten.index.Tensor,
     exir_ops.edge.aten.full.default,
 ]
90 changes: 35 additions & 55 deletions backends/qualcomm/runtime/QnnManager.cpp
@@ -24,66 +24,53 @@ QnnManager::~QnnManager() {
 QnnManager::QnnManager(
     const QnnExecuTorchOptions* options,
     const QnnExecuTorchContextBinary& qnn_executorch_context_binary)
-    : backend_type_(options->backend_type()),
-      library_path_(options->library_path()->c_str()),
-      skel_library_dir_(options->skel_library_dir()->c_str()),
-      tensor_dump_output_path_(options->tensor_dump_output_path()->c_str()),
-      graph_name_(options->graph_name()->c_str()),
-      soc_info_(options->soc_info()),
-      htp_options_(options->htp_options()),
-      log_level_(options->log_level()),
-      qnn_context_blob_(qnn_executorch_context_binary),
-      qnn_loaded_backend_(library_path_),
-      online_prepare_(options->online_prepare()) {
-  if (log_level_ >= QnnExecuTorchLogLevel::kLogLevelInfo) {
+    : qnn_context_blob_(qnn_executorch_context_binary),
+      qnn_loaded_backend_(""),
+      // options' lifetime is decided by the compiler specs, which are
+      // kept by the ExecuTorch runtime framework;
+      // watch for potential segfaults if the specs are released early
+      options_(options) {
+  QnnExecuTorchBackendType backend_type =
+      options->backend_options()->backend_type();
+  std::string library_path = options->library_path()->str();
+
+  if (options->log_level() >= QnnExecuTorchLogLevel::kLogLevelInfo) {
     QNN_EXECUTORCH_LOG_INFO(
-        "backend_type: %s",
-        EnumNameQnnExecuTorchBackendType(options->backend_type()));
-    QNN_EXECUTORCH_LOG_INFO("graph_name: %s", options->graph_name()->c_str());
-    QNN_EXECUTORCH_LOG_INFO(
-        "library_path: %s", options->library_path()->c_str());
+        "soc_model in soc_info: %s",
+        EnumNameQcomChipset(options_->soc_info()->soc_model()));
     QNN_EXECUTORCH_LOG_INFO(
-        "skel_library_dir: %s", options->skel_library_dir()->c_str());
+        "backend_type: %s", EnumNameQnnExecuTorchBackendType(backend_type));
+    QNN_EXECUTORCH_LOG_INFO("graph_name: %s", options_->graph_name()->c_str());
+    QNN_EXECUTORCH_LOG_INFO("library_path: %s", library_path.c_str());
     QNN_EXECUTORCH_LOG_INFO(
         "tensor_dump_output_path: %s",
-        options->tensor_dump_output_path()->c_str());
-    QNN_EXECUTORCH_LOG_INFO(
-        "log_level: %s", EnumNameQnnExecuTorchLogLevel(options->log_level()));
-    QNN_EXECUTORCH_LOG_INFO(
-        "soc_model in soc_info: %s",
-        EnumNameQcomChipset(options->soc_info()->soc_model()));
-    QNN_EXECUTORCH_LOG_INFO(
-        "htp_arch in htp_info: %s",
-        EnumNameHtpArch(options->soc_info()->htp_info()->htp_arch()));
+        options_->tensor_dump_output_path()->c_str());
     QNN_EXECUTORCH_LOG_INFO(
-        "vtcm_size_in_mb in htp_info: %d",
-        options->soc_info()->htp_info()->vtcm_size_in_mb());
+        "log_level: %s", EnumNameQnnExecuTorchLogLevel(options_->log_level()));
     QNN_EXECUTORCH_LOG_INFO(
         "the size of qnn context binary: %d",
         qnn_executorch_context_binary.nbytes);
     QNN_EXECUTORCH_LOG_INFO(
-        "Is on-device graph construction: %d", options->online_prepare());
+        "Is on-device graph construction: %d", options_->online_prepare());
   }
-  if (!skel_library_dir_.empty()) {
-    setenv("ADSP_LIBRARY_PATH", skel_library_dir_.c_str(), /*overwrite=*/1);
-  }
-  if (library_path_.empty()) {
-    switch (backend_type_) {
+
+  if (library_path.empty()) {
+    switch (backend_type) {
       case QnnExecuTorchBackendType::kHtpBackend:
-        library_path_ = htp_library_name_;
+        library_path = htp_library_name_;
         break;
       case QnnExecuTorchBackendType::kDspBackend:
-        library_path_ = dsp_library_name_;
+        library_path = dsp_library_name_;
         break;
       case QnnExecuTorchBackendType::kGpuBackend:
-        library_path_ = gpu_library_name_;
+        library_path = gpu_library_name_;
         break;
       default:
-        QNN_EXECUTORCH_LOG_ERROR("Unknown backend type: %s", backend_type_);
+        QNN_EXECUTORCH_LOG_ERROR("Unknown backend type: %d", backend_type);
        break;
     }
   }
-  qnn_loaded_backend_ = QnnImplementation(library_path_);
+  qnn_loaded_backend_ = QnnImplementation(library_path);
   backend_params_ptr_ = std::make_unique<BackendConfigParameters>();
 }

@@ -96,22 +83,15 @@ Error QnnManager::Init() {
   ET_CHECK_OR_RETURN_ERROR(
       LoadQnnLibrary() == Error::Ok, Internal, "Fail to load Qnn library");
   logger_ = std::make_unique<QnnLogger>(
-      qnn_loaded_backend_, LoggingCallback, log_level_);
+      qnn_loaded_backend_, LoggingCallback, options_->log_level());
   if (backend_params_ptr_->backend_init_state_ ==
       BackendInitializeState::UNINITIALIZED) {
     QNN_EXECUTORCH_LOG_INFO(
         "Initialize Qnn backend "
         "parameters for Qnn executorch backend type %d",
-        backend_type_);
+        options_->backend_options()->backend_type());
     backend_params_ptr_ = QnnBackendFactory().Create(
-        qnn_loaded_backend_,
-        logger_.get(),
-        log_level_,
-        qnn_context_blob_,
-        backend_type_,
-        graph_name_,
-        soc_info_,
-        htp_options_);
+        qnn_loaded_backend_, logger_.get(), qnn_context_blob_, options_);
     ET_CHECK_OR_RETURN_ERROR(
         backend_params_ptr_->qnn_backend_ptr_->Configure() == Error::Ok,
         Internal,
@@ -150,7 +130,7 @@ Error QnnManager::AllocateTensor() {
   for (auto& tensor : output_tensors) {
     std::shared_ptr<TensorWrapper> tensor_wrapper = CreateTensorWrapper(tensor);
     tensor_wrapper->UpdateQnnTensorMeta(tensor);
-    if (!tensor_dump_output_path_.empty()) {
+    if (IsTensorDump()) {
       tensor_wrapper->AllocateDataBuffer();
     }
     output_tensors_.emplace_back(std::move(tensor_wrapper));
@@ -163,7 +143,7 @@ Error QnnManager::AllocateTensor(
     std::vector<std::shared_ptr<TensorWrapper>>& outputs) {
   input_tensors_ = std::move(inputs);
   for (auto& output_tensor : outputs) {
-    if (!tensor_dump_output_path_.empty()) {
+    if (IsTensorDump()) {
       output_tensor->AllocateDataBuffer();
     }
   }
@@ -185,10 +165,10 @@ Error QnnManager::Execute(
     return Error::Internal;
   }
 
-  if (!tensor_dump_output_path_.empty()) {
+  if (IsTensorDump()) {
     // TODO: Need to handle the graph which is partitioned.
     // Maybe we could use graph name.
-    std::string dir = tensor_dump_output_path_ + "/Result/";
+    std::string dir = options_->tensor_dump_output_path()->str() + "/Result/";
     CreateDirectory(dir);
     QNN_EXECUTORCH_LOG_INFO("Dump tensor to the path: %s", dir.c_str());
     for (std::size_t out_idx = 0; out_idx < output_tensor_structs.size();
@@ -227,7 +207,7 @@ bool QnnManager::IsAvailable() {
 }
 
 bool QnnManager::IsOnlinePrepare() {
-  return online_prepare_;
+  return options_->online_prepare();
 }
 
 bool QnnManager::IsNodeSupportedByBackend(
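
The AllocateTensor/Execute changes above follow the same theme: rather than testing a cached tensor_dump_output_path_ string, call sites now ask a small predicate that derives the answer from the options on every call. A sketch of that accessor shape, again with stand-in types rather than the real ExecuTorch ones:

#include <string>

struct Options {
  std::string tensor_dump_output_path;
};

class Manager {
 public:
  explicit Manager(const Options* options) : options_(options) {}

  // Derived on demand from the single source of truth, so it cannot
  // drift out of sync with a separately cached copy of the path.
  bool IsTensorDump() const {
    return !options_->tensor_dump_output_path.empty();
  }

 private:
  const Options* options_;  // non-owning, owned by the compiler specs
};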
12 changes: 2 additions & 10 deletions backends/qualcomm/runtime/QnnManager.h
@@ -43,7 +43,7 @@ class QnnManager {
   bool IsAvailable();
 
   bool IsTensorDump() {
-    return !tensor_dump_output_path_.empty();
+    return options_->tensor_dump_output_path()->size() > 0;
   }
 
   bool IsOnlinePrepare();
@@ -69,21 +69,13 @@ class QnnManager {
   static constexpr const char* gpu_library_name_ = "libQnnGpu.so";
   static constexpr const char* dsp_library_name_ = "libQnnDsp.so";
 
-  QnnExecuTorchBackendType backend_type_;
-  std::string library_path_;
-  std::string skel_library_dir_;
-  std::string tensor_dump_output_path_;
-  std::string graph_name_;
-  const SocInfo* soc_info_;
-  const QnnExecuTorchHtpBackendOptions* htp_options_;
-  QnnExecuTorchLogLevel log_level_;
   QnnExecuTorchContextBinary qnn_context_blob_;
   std::unique_ptr<BackendConfigParameters> backend_params_ptr_;
   QnnImplementation qnn_loaded_backend_;
   std::unique_ptr<QnnLogger> logger_;
+  const QnnExecuTorchOptions* options_;
   std::vector<std::shared_ptr<TensorWrapper>> input_tensors_;
   std::vector<std::shared_ptr<TensorWrapper>> output_tensors_;
-  bool online_prepare_;
 };
 } // namespace qnn
 } // namespace executor
3 changes: 3 additions & 0 deletions backends/qualcomm/runtime/backends/CMakeLists.txt
@@ -68,6 +68,9 @@ target_sources(qnn_context
     ${CMAKE_CURRENT_LIST_DIR}/htpbackend/HtpContext.h
   PRIVATE
     ${CMAKE_CURRENT_LIST_DIR}/QnnContextCommon.cpp
+    ${CMAKE_CURRENT_LIST_DIR}/htpbackend/HtpContext.cpp
+    ${CMAKE_CURRENT_LIST_DIR}/htpbackend/HtpContextCustomConfig.h
+    ${HOST_ARCHITECTURE}/HtpContextCustomConfig.cpp
 )
 
 # qnn_backend_cache
35 changes: 23 additions & 12 deletions backends/qualcomm/runtime/backends/QnnBackendFactory.cpp
@@ -13,16 +13,27 @@ namespace qnn {
 std::unique_ptr<BackendConfigParameters> QnnBackendFactory::Create(
     const QnnImplementation& implementation,
     QnnLogger* logger,
-    const QnnExecuTorchLogLevel& log_level,
     const QnnExecuTorchContextBinary& qnn_context_blob,
-    const QnnExecuTorchBackendType& backend_type,
-    const std::string& graph_name,
-    const SocInfo* soc_info,
-    const QnnExecuTorchHtpBackendOptions* htp_options) {
+    const QnnExecuTorchOptions* options) {
   auto backend_params = std::make_unique<BackendConfigParameters>();
-  switch (backend_type) {
-    case QnnExecuTorchBackendType::kHtpBackend:
-      if (log_level >= QnnExecuTorchLogLevel::kLogLevelInfo) {
+  switch (options->backend_options()->backend_type()) {
+    case QnnExecuTorchBackendType::kHtpBackend: {
+      auto htp_options = options->backend_options()->htp_options();
+      if (options->log_level() >= QnnExecuTorchLogLevel::kLogLevelInfo) {
+        const std::string skel_library_dir =
+            htp_options->skel_library_dir()->str();
+        if (!skel_library_dir.empty()) {
+          setenv(
+              "ADSP_LIBRARY_PATH", skel_library_dir.c_str(), /*overwrite=*/1);
+        }
+        QNN_EXECUTORCH_LOG_INFO(
+            "skel_library_dir: %s", skel_library_dir.c_str());
+        QNN_EXECUTORCH_LOG_INFO(
+            "htp_arch in htp_info: %s",
+            EnumNameHtpArch(options->soc_info()->htp_info()->htp_arch()));
+        QNN_EXECUTORCH_LOG_INFO(
+            "vtcm_size_in_mb in htp_info: %d",
+            options->soc_info()->htp_info()->vtcm_size_in_mb());
         QNN_EXECUTORCH_LOG_INFO(
             "performance_mode in htp_options: %s",
             EnumNameQnnExecuTorchHtpPerformanceMode(
@@ -41,7 +52,7 @@ std::unique_ptr<BackendConfigParameters> QnnBackendFactory::Create(
       backend_params->qnn_backend_ptr_ =
           std::make_unique<HtpBackend>(implementation, logger);
       backend_params->qnn_device_ptr_ = std::make_unique<HtpDevice>(
-          implementation, logger, soc_info, htp_options);
+          implementation, logger, options->soc_info(), htp_options);
 
       backend_params->qnn_context_ptr_ = std::make_unique<HtpContext>(
           implementation,
@@ -53,12 +64,12 @@ std::unique_ptr<BackendConfigParameters> QnnBackendFactory::Create(
       backend_params->qnn_graph_ptr_ = std::make_unique<HtpGraph>(
           implementation,
           backend_params->qnn_context_ptr_.get(),
-          graph_name,
-          soc_info,
+          options->graph_name()->str(),
+          options->soc_info(),
           htp_options);
       backend_params->backend_init_state_ = BackendInitializeState::INITIALIZED;
       return backend_params;
-      break;
+    } break;
     case QnnExecuTorchBackendType::kGpuBackend:
     case QnnExecuTorchBackendType::kDspBackend:
     case QnnExecuTorchBackendType::kUndefinedBackend:
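
Create() now receives the whole options object, collapsing five loose parameters (log level, backend type, graph name, SoC info, HTP options) into one, and backend-specific setup such as exporting ADSP_LIBRARY_PATH moves into the branch that actually needs it. A condensed sketch of this parameter-object factory shape, with hypothetical stand-in types:

#include <memory>

enum class BackendType { kHtp, kGpu, kDsp };

struct HtpOptions { /* skel_library_dir, performance_mode, ... */ };

struct Options {
  BackendType backend_type;
  HtpOptions htp;
};

struct BackendConfig { /* backend/device/context/graph members */ };

std::unique_ptr<BackendConfig> Create(const Options* options) {
  auto params = std::make_unique<BackendConfig>();
  switch (options->backend_type) {
    case BackendType::kHtp: {
      // HTP-only setup stays inside the HTP branch; supporting a new
      // backend means adding a case, not widening the signature.
      const HtpOptions& htp = options->htp;
      (void)htp;  // would configure backend/device/context/graph here
      return params;
    }
    default:
      return nullptr;  // GPU/DSP not wired up in this sketch
  }
}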
6 changes: 1 addition & 5 deletions backends/qualcomm/runtime/backends/QnnBackendFactory.h
@@ -57,12 +57,8 @@ class QnnBackendFactory {
   std::unique_ptr<BackendConfigParameters> Create(
       const QnnImplementation& implementation,
       QnnLogger* logger,
-      const QnnExecuTorchLogLevel& log_level,
       const QnnExecuTorchContextBinary& qnn_context_blob,
-      const QnnExecuTorchBackendType& backend_type,
-      const std::string& graph_name,
-      const SocInfo* soc_info,
-      const QnnExecuTorchHtpBackendOptions* htp_options);
+      const QnnExecuTorchOptions* options);
 };
 } // namespace qnn
 } // namespace executor
2 changes: 1 addition & 1 deletion backends/qualcomm/runtime/backends/QnnContextCommon.cpp
@@ -78,7 +78,7 @@ Error QnnContext::Configure() {
     QNN_EXECUTORCH_LOG_ERROR("QNN context cache is invalid.");
     return Error::Internal;
   }
-  return Error::Ok;
+  return AfterConfigure();
 }
 
 Error QnnContext::GetContextBinary(
5 changes: 4 additions & 1 deletion backends/qualcomm/runtime/backends/QnnContextCommon.h
@@ -33,7 +33,7 @@ class QnnContext {
   virtual ~QnnContext();
   Error Configure();
 
-  Qnn_ContextHandle_t GetHandle() {
+  Qnn_ContextHandle_t GetHandle() const {
     return handle_;
   }
 
@@ -58,6 +58,9 @@
   virtual Error MakeConfig(std::vector<const QnnContext_Config_t*>& config) {
     return Error::Ok;
   };
+  virtual Error AfterConfigure() {
+    return Error::Ok;
+  };
 
  private:
   Qnn_ContextHandle_t handle_;
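
Routing Configure() through a virtual AfterConfigure() hook (default: no-op) is a small template-method pattern: the base class keeps the fixed configuration flow, while a subclass such as HtpContext reacts once the context handle exists. A self-contained sketch of that shape:

enum class Error { Ok, Internal };

class Context {
 public:
  virtual ~Context() = default;

  // Fixed skeleton: create the context, then let the subclass react
  // once the handle is valid.
  Error Configure() {
    handle_ = 1;  // stand-in for real QNN context creation
    return AfterConfigure();
  }

 protected:
  // Default hook is a no-op; overriding it is optional.
  virtual Error AfterConfigure() { return Error::Ok; }
  long handle_ = 0;
};

class HtpLikeContext : public Context {
 protected:
  Error AfterConfigure() override {
    // e.g. publish handle_ as the shared group handle (see below)
    return Error::Ok;
  }
};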
49 changes: 49 additions & 0 deletions backends/qualcomm/runtime/backends/htpbackend/HtpContext.cpp
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) Qualcomm Innovation Center, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/qualcomm/runtime/Logging.h>
+#include <executorch/backends/qualcomm/runtime/backends/htpbackend/HtpContext.h>
+
+#include "HTP/QnnHtpCommon.h"
+#include "Saver/QnnSaverCommon.h"
+
+namespace torch {
+namespace executor {
+namespace qnn {
+
+Error HtpContext::MakeConfig(std::vector<const QnnContext_Config_t*>& config) {
+  const std::vector<QnnContext_CustomConfig_t>& context_custom_config =
+      htp_context_custom_config_->CreateContextCustomConfig();
+
+  uint32_t num_custom_configs = context_custom_config.size();
+  context_config_.resize(num_custom_configs);
+  // +1 for null terminator
+  config.reserve(num_custom_configs + 1);
+
+  for (std::size_t i = 0; i < num_custom_configs; ++i) {
+    context_config_[i].option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
+    context_config_[i].customConfig = context_custom_config[i];
+    config.push_back(&context_config_[i]);
+  }
+
+  config.push_back(nullptr);
+  return Error::Ok;
+}
+
+Error HtpContext::AfterConfigure() {
+  // update sf_handle_ with the first context handle encountered as group handle
+  // TODO: should handle the thread safety if needed
+  if (sf_handle_ == 0x0) {
+    sf_handle_ = GetHandle();
+  }
+  return Error::Ok;
+}
+
+} // namespace qnn
+} // namespace executor
+} // namespace torch
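
AfterConfigure() is where the multi-context support hooks in: the first HTP context to finish configuration publishes its handle (sf_handle_, presumably a spill-fill/group handle shared across contexts), and later contexts observe it instead of claiming their own. A stripped-down sketch of that first-handle-wins idea, deliberately ignoring the thread-safety question flagged in the TODO above:

#include <cstdint>

using ContextHandle = std::uint64_t;

class GroupedContext {
 public:
  // Called once this context's handle exists; the first context seen
  // becomes the group anchor, and all later ones reuse it.
  void AfterConfigure(ContextHandle mine) {
    if (group_handle_ == 0) {
      group_handle_ = mine;
    }
  }

  static ContextHandle group_handle() { return group_handle_; }

 private:
  // Shared across every context in the process (C++17 inline static);
  // the real code defers locking to a TODO, as noted above.
  static inline ContextHandle group_handle_ = 0;
};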