[Feature]: Add custom operators support for onnxruntime in mmcv (open-mmlab#612)

* add onnx support to roi_align and roi_pool
* add softnms ort support
* fix for lint
* format cpp code with clang-format:google
* add new empty line to the end of head files in onnxruntime
* update to pytorch1.7
* add test of softnms to onnxruntime
* fix for lint
* remove print in ops/info.py
* change import order, fix for flake8
* fix include
* add assert torch>=1.7.0
* [doc]: add document for onnxruntime custom operator
* update onnxruntime version to v1.5.1 for softnms
* remove doc menu
* resolve lint for markdown
* resolve naming style in onnxruntime_op.md
* use old cpp apis, optimize test_onnx.py
* fix strings in tests/test_ops/test_onnx.py
* code format with yapf
* fix soft_nms parrot
* add import in onnxruntime setup, avoid conflict
* fix doc and add assert
* change cpp guard

Co-authored-by: maningsheng <[email protected]>
1 parent 8b4e5de · commit 94810f2 · Showing 13 changed files with 607 additions and 13 deletions.
onnxruntime_op.md
@@ -0,0 +1,119 @@
# Custom operators for ONNX Runtime in MMCV

## Introduction of ONNX Runtime

**ONNX Runtime** is a cross-platform inference and training accelerator compatible with many popular ML/DNN frameworks. Check its [github](https://github.com/microsoft/onnxruntime) for more information.

## Introduction of ONNX

**ONNX** stands for **Open Neural Network Exchange**, which acts as an *Intermediate Representation (IR)* for ML/DNN models from many frameworks. Check its [github](https://github.com/onnx/onnx) for more information.

## Why include custom operators for ONNX Runtime in MMCV

- To verify the correctness of exported ONNX models in ONNX Runtime.
- To ease the deployment of ONNX models with custom operators from `mmcv.ops` in ONNX Runtime.

## List of operators for ONNX Runtime supported in MMCV

| Operator | CPU | GPU | Note |
| :------: | :-: | :-: | :--: |
| SoftNMS  |  Y  |  N  | None |

## How to build custom operators for ONNX Runtime

*Note that only **onnxruntime>=1.5.1** (CPU version) on the Linux platform has been tested so far.*

### Prerequisite

- Clone the repository

```bash
git clone https://github.com/open-mmlab/mmcv.git
```

- Download `onnxruntime-linux-x64-1.5.1.tgz` from the ONNX Runtime [releases](https://github.com/microsoft/onnxruntime/releases/tag/v1.5.1), extract it, expose `ONNXRUNTIME_DIR` and finally add the lib path to `LD_LIBRARY_PATH` as below:

```bash
wget https://github.com/microsoft/onnxruntime/releases/download/v1.5.1/onnxruntime-linux-x64-1.5.1.tgz

tar -zxvf onnxruntime-linux-x64-1.5.1.tgz
cd onnxruntime-linux-x64-1.5.1
export ONNXRUNTIME_DIR=$(pwd)
export LD_LIBRARY_PATH=$ONNXRUNTIME_DIR/lib:$LD_LIBRARY_PATH
```

### Build on Linux

```bash
cd mmcv  # to MMCV root directory
MMCV_WITH_OPS=1 MMCV_WITH_ORT=1 pip install -e .
```
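
A quick way to check that the build actually produced the custom-op library is to print its path with `get_onnxruntime_op_path` from `mmcv.ops`, the same helper used in the inference demo below:

```bash
# prints the path of the built custom-op shared library; an error or a
# missing file means the install above did not pick up MMCV_WITH_ORT=1
python -c "from mmcv.ops import get_onnxruntime_op_path; print(get_onnxruntime_op_path())"
```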

## How to do inference using exported ONNX models with custom operators in ONNX Runtime in Python

Install ONNX Runtime with `pip`

```bash
pip install onnxruntime==1.5.1
```

Inference Demo

```python
import os

import numpy as np
import onnxruntime as ort

from mmcv.ops import get_onnxruntime_op_path

ort_custom_op_path = get_onnxruntime_op_path()
assert os.path.exists(ort_custom_op_path)
session_options = ort.SessionOptions()
session_options.register_custom_ops_library(ort_custom_op_path)
# exported ONNX model with custom operators
onnx_file = 'sample.onnx'
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
sess = ort.InferenceSession(onnx_file, session_options)
onnx_results = sess.run(None, {'input': input_data})
```
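
Note that `register_custom_ops_library` must be called on the `SessionOptions` *before* the `InferenceSession` is constructed: custom ops in the `mmcv` domain are resolved at session-creation time, so a library registered afterwards is never seen.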

## How to add a new custom operator for ONNX Runtime in MMCV

### Reminder

- The custom operator should not already be included in the [supported operator list](https://github.com/microsoft/onnxruntime/blob/master/docs/OperatorKernels.md) of ONNX Runtime.
- The custom operator should be exportable to ONNX (see the sketch below).
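
To make an operator exportable, one common pattern is a `torch.autograd.Function` whose `symbolic` method emits an op in the `mmcv` domain (the domain name matches `c_MMCVOpDomain` in `onnxruntime_register.cpp` later in this commit). The following is a minimal sketch, not MMCV's actual export code: the op name `SoftNonMaxSuppression` and the placeholder `forward` are assumptions for illustration, while the attribute names mirror those read by `SoftNmsKernel` in `soft_nms.cpp`.

```python
import torch
from torch.autograd import Function


class SoftNmsFunction(Function):
    """Illustrative export wrapper; the real numerics live in mmcv.ops."""

    @staticmethod
    def forward(ctx, boxes, scores, iou_threshold=0.3, sigma=0.5,
                min_score=1e-3, method=1, offset=0):
        # Placeholder outputs with plausible shapes, used only for tracing;
        # at inference time ONNX Runtime dispatches to the custom kernel.
        dets = torch.cat([boxes, scores.unsqueeze(1)], dim=1)
        inds = torch.arange(boxes.size(0), dtype=torch.long)
        return dets, inds

    @staticmethod
    def symbolic(g, boxes, scores, iou_threshold=0.3, sigma=0.5,
                 min_score=1e-3, method=1, offset=0):
        # 'SoftNonMaxSuppression' is an assumed op name; the 'mmcv' domain
        # matches c_MMCVOpDomain, and the attributes match SoftNmsKernel's.
        return g.op(
            'mmcv::SoftNonMaxSuppression', boxes, scores,
            iou_threshold_f=float(iou_threshold),
            sigma_f=float(sigma),
            min_score_f=float(min_score),
            method_i=int(method),
            offset_i=int(offset),
            outputs=2)
```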
### Main procedures

Take the custom operator `soft_nms` as an example.

1. Add the header `soft_nms.h` to the ONNX Runtime include directory `mmcv/ops/csrc/onnxruntime/` (a sketch of such a header follows the snippet in step 3)
2. Add the source `soft_nms.cpp` to the ONNX Runtime source directory `mmcv/ops/csrc/onnxruntime/cpu/`
3. Register the `soft_nms` operator in [onnxruntime_register.cpp](../mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp)

```c++
#include "soft_nms.h"

SoftNmsOp c_SoftNmsOp;

if (auto status = ortApi->CustomOpDomain_Add(domain, &c_SoftNmsOp)) {
  return status;
}
```
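
For step 1, the header declares both the kernel and the custom-op descriptor that `onnxruntime_register.cpp` instantiates. The block below is a hedged reconstruction inferred from the kernel implementation in `soft_nms.cpp` further down this commit, not the verbatim file; the member names match that implementation, but the exact `Ort::CustomOpBase` boilerplate may differ:

```c++
#ifndef ONNXRUNTIME_SOFT_NMS_H
#define ONNXRUNTIME_SOFT_NMS_H

#include <onnxruntime_cxx_api.h>

struct SoftNmsKernel {
  SoftNmsKernel(OrtApi api, const OrtKernelInfo *info);

  void Compute(OrtKernelContext *context);

 protected:
  OrtApi api_;
  Ort::CustomOpApi ort_;
  const OrtKernelInfo *info_;
  Ort::AllocatorWithDefaultOptions allocator_;

  float iou_threshold_;
  float sigma_;
  float min_score_;
  int64_t method_;
  int64_t offset_;
};

struct SoftNmsOp : Ort::CustomOpBase<SoftNmsOp, SoftNmsKernel> {
  void *CreateKernel(OrtApi api, const OrtKernelInfo *info) const {
    return new SoftNmsKernel(api, info);
  }

  // the op name must match what the exporter emits in the 'mmcv' domain
  const char *GetName() const { return "SoftNonMaxSuppression"; }

  size_t GetInputTypeCount() const { return 2; }  // boxes, scores
  ONNXTensorElementDataType GetInputType(size_t) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

  size_t GetOutputTypeCount() const { return 2; }  // dets, inds
  ONNXTensorElementDataType GetOutputType(size_t index) const {
    return index == 1 ? ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64
                      : ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }
};

#endif  // ONNXRUNTIME_SOFT_NMS_H
```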

4. Add a unit test to `tests/test_ops/test_onnx.py`
   Check [here](../tests/test_ops/test_onnx.py) for examples; a minimal test sketch follows below.
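
As a rough illustration of step 4 — a hedged sketch, not the test shipped in `tests/test_ops/test_onnx.py` — the idea is to export `mmcv.ops.soft_nms` (assuming it is ONNX-exportable as described above) and compare ONNX Runtime results against the PyTorch output:

```python
import numpy as np
import onnxruntime as ort
import torch

from mmcv.ops import get_onnxruntime_op_path, soft_nms


class SoftNmsWrapper(torch.nn.Module):
    """Thin wrapper so torch.onnx.export can trace soft_nms."""

    def forward(self, boxes, scores):
        return soft_nms(boxes, scores)


def test_soft_nms_onnxruntime(tmp_path):
    boxes = torch.tensor([[10., 10., 60., 60.],
                          [12., 12., 64., 64.]], dtype=torch.float32)
    scores = torch.tensor([0.9, 0.8], dtype=torch.float32)

    model = SoftNmsWrapper()
    torch_dets, torch_inds = model(boxes, scores)

    onnx_file = str(tmp_path / 'soft_nms.onnx')
    torch.onnx.export(model, (boxes, scores), onnx_file,
                      input_names=['boxes', 'scores'],
                      output_names=['dets', 'inds'],
                      opset_version=11)

    # the custom-op library must be registered before the session is created
    session_options = ort.SessionOptions()
    session_options.register_custom_ops_library(get_onnxruntime_op_path())
    sess = ort.InferenceSession(onnx_file, session_options)
    ort_dets, ort_inds = sess.run(None, {'boxes': boxes.numpy(),
                                         'scores': scores.numpy()})

    np.testing.assert_allclose(torch_dets.numpy(), ort_dets, rtol=1e-4)
    np.testing.assert_array_equal(torch_inds.numpy(), ort_inds)
```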

**Finally, PRs adding custom operators for ONNX Runtime in MMCV are welcome.** :nerd_face:

## Known Issues

- None

## References

- [How to export Pytorch model with custom op to ONNX and run it in ONNX Runtime](https://github.com/onnx/tutorials/blob/master/PyTorchCustomOperator/README.md)
- [How to add a custom operator/kernel in ONNX Runtime](https://github.com/microsoft/onnxruntime/blob/master/docs/AddingCustomOp.md)
mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp
@@ -0,0 +1,23 @@
#include "onnxruntime_register.h" | ||
|
||
#include "ort_mmcv_utils.h" | ||
#include "soft_nms.h" | ||
|
||
const char *c_MMCVOpDomain = "mmcv"; | ||
SoftNmsOp c_SoftNmsOp; | ||
|
||
OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options, | ||
const OrtApiBase *api) { | ||
OrtCustomOpDomain *domain = nullptr; | ||
const OrtApi *ortApi = api->GetApi(ORT_API_VERSION); | ||
|
||
if (auto status = ortApi->CreateCustomOpDomain(c_MMCVOpDomain, &domain)) { | ||
return status; | ||
} | ||
|
||
if (auto status = ortApi->CustomOpDomain_Add(domain, &c_SoftNmsOp)) { | ||
return status; | ||
} | ||
|
||
return ortApi->AddCustomOpDomain(options, domain); | ||
} |
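
This `RegisterCustomOps` function is the entry point that `session_options.register_custom_ops_library(...)` in the Python demo invokes after loading the shared library, so every operator added to the `mmcv` domain here becomes available to the session.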
mmcv/ops/csrc/onnxruntime/cpu/soft_nms.cpp
@@ -0,0 +1,155 @@
#include "soft_nms.h" | ||
|
||
#include <assert.h> | ||
|
||
#include <algorithm> | ||
#include <cmath> | ||
|
||
#include "../ort_mmcv_utils.h" | ||
|
||
SoftNmsKernel::SoftNmsKernel(OrtApi api, const OrtKernelInfo *info) | ||
: api_(api), ort_(api_), info_(info) { | ||
iou_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "iou_threshold"); | ||
sigma_ = ort_.KernelInfoGetAttribute<float>(info, "sigma"); | ||
min_score_ = ort_.KernelInfoGetAttribute<float>(info, "min_score"); | ||
method_ = ort_.KernelInfoGetAttribute<int64_t>(info, "method"); | ||
offset_ = ort_.KernelInfoGetAttribute<int64_t>(info, "offset"); | ||
|
||
// create allocator | ||
allocator_ = Ort::AllocatorWithDefaultOptions(); | ||
} | ||
|
||
void SoftNmsKernel::Compute(OrtKernelContext *context) { | ||
typedef float T; | ||
|
||
const T iou_threshold = T(iou_threshold_); | ||
const T sigma = T(sigma_); | ||
const T min_score = T(min_score_); | ||
const int method = int(method_); | ||
const T offset = T(offset_); | ||
|
||
const OrtValue *boxes = ort_.KernelContext_GetInput(context, 0); | ||
const T *boxes_data = | ||
reinterpret_cast<const float *>(ort_.GetTensorData<T>(boxes)); | ||
const OrtValue *scores = ort_.KernelContext_GetInput(context, 1); | ||
const T *scores_data = | ||
reinterpret_cast<const float *>(ort_.GetTensorData<T>(scores)); | ||
|
||
OrtTensorDimensions boxes_dim(ort_, boxes); | ||
OrtTensorDimensions scores_dim(ort_, scores); | ||
|
||
int64_t nboxes = boxes_dim[0]; | ||
assert(boxes_dim[1] == 4); | ||
|
||
// allocate tmp memory | ||
T *tmp_boxes = (T *)allocator_.Alloc(sizeof(T) * nboxes * 4); | ||
T *x1 = tmp_boxes; | ||
T *y1 = tmp_boxes + 1; | ||
T *x2 = tmp_boxes + 2; | ||
T *y2 = tmp_boxes + 3; | ||
T *sc = (T *)allocator_.Alloc(sizeof(T) * nboxes); | ||
T *areas = (T *)allocator_.Alloc(sizeof(T) * nboxes); | ||
T *de = (T *)allocator_.Alloc(sizeof(T) * nboxes * 5); | ||
int64_t *inds = (int64_t *)allocator_.Alloc(sizeof(int64_t) * nboxes); | ||
|
||
memcpy(tmp_boxes, boxes_data, sizeof(T) * nboxes * 4); | ||
memcpy(sc, scores_data, sizeof(T) * nboxes); | ||
|
||
// init inds as arange(nboxes) | ||
std::generate(inds, inds + nboxes, [n = 0]() mutable { return n++; }); | ||
|
||
// area = (x2-x1+offset)*(y2-y1+offset) | ||
for (int64_t i = 0; i < nboxes; i++) { | ||
areas[i] = | ||
(x2[i * 4] - x1[i * 4] + offset) * (y2[i * 4] - y1[i * 4] + offset); | ||
} | ||
|
||
int64_t pos = 0; | ||
|
||
for (int64_t i = 0; i < nboxes; i++) { | ||
auto max_score = sc[i]; | ||
auto max_pos = i; | ||
|
||
pos = i + 1; | ||
// get max box | ||
while (pos < nboxes) { | ||
if (max_score < sc[pos]) { | ||
max_score = sc[pos]; | ||
max_pos = pos; | ||
} | ||
pos = pos + 1; | ||
} | ||
// swap | ||
auto ix1 = de[i * 5 + 0] = x1[max_pos * 4]; | ||
auto iy1 = de[i * 5 + 1] = y1[max_pos * 4]; | ||
auto ix2 = de[i * 5 + 2] = x2[max_pos * 4]; | ||
auto iy2 = de[i * 5 + 3] = y2[max_pos * 4]; | ||
auto iscore = de[i * 5 + 4] = sc[max_pos]; | ||
auto iarea = areas[max_pos]; | ||
auto iind = inds[max_pos]; | ||
x1[max_pos * 4] = x1[i * 4]; | ||
y1[max_pos * 4] = y1[i * 4]; | ||
x2[max_pos * 4] = x2[i * 4]; | ||
y2[max_pos * 4] = y2[i * 4]; | ||
sc[max_pos] = sc[i]; | ||
areas[max_pos] = areas[i]; | ||
inds[max_pos] = inds[i]; | ||
x1[i * 4] = ix1; | ||
y1[i * 4] = iy1; | ||
x2[i * 4] = ix2; | ||
y2[i * 4] = iy2; | ||
sc[i] = iscore; | ||
areas[i] = iarea; | ||
inds[i] = iind; | ||
|
||
pos = i + 1; | ||
while (pos < nboxes) { | ||
auto xx1 = std::max(ix1, x1[pos * 4]); | ||
auto yy1 = std::max(iy1, y1[pos * 4]); | ||
auto xx2 = std::min(ix2, x2[pos * 4]); | ||
auto yy2 = std::min(iy2, y2[pos * 4]); | ||
|
||
auto w = std::max(0.f, xx2 - xx1 + offset); | ||
auto h = std::max(0.f, yy2 - yy1 + offset); | ||
auto inter = w * h; | ||
auto ovr = inter / (iarea + areas[pos] - inter); | ||
|
||
float weight = 1.; | ||
if (method == 0) { | ||
if (ovr >= iou_threshold) weight = 0; | ||
} else if (method == 1) { | ||
if (ovr >= iou_threshold) weight = 1 - ovr; | ||
} else if (method == 2) { | ||
weight = std::exp(-(ovr * ovr) / sigma); | ||
} | ||
sc[pos] *= weight; | ||
// if box score falls below threshold, discard the box by | ||
// swapping with last box update N | ||
if (sc[pos] < min_score) { | ||
x1[pos * 4] = x1[(nboxes - 1) * 4]; | ||
y1[pos * 4] = y1[(nboxes - 1) * 4]; | ||
x2[pos * 4] = x2[(nboxes - 1) * 4]; | ||
y2[pos * 4] = y2[(nboxes - 1) * 4]; | ||
sc[pos] = sc[nboxes - 1]; | ||
areas[pos] = areas[nboxes - 1]; | ||
inds[pos] = inds[nboxes - 1]; | ||
nboxes = nboxes - 1; | ||
pos = pos - 1; | ||
} | ||
pos = pos + 1; | ||
} | ||
} | ||
|
||
std::vector<int64_t> dets_dim({nboxes, 5}); | ||
OrtValue *dets = ort_.KernelContext_GetOutput(context, 0, dets_dim.data(), | ||
dets_dim.size()); | ||
T *dets_data = ort_.GetTensorMutableData<T>(dets); | ||
|
||
std::vector<int64_t> inds_dim({nboxes}); | ||
OrtValue *inds_ov = ort_.KernelContext_GetOutput(context, 1, inds_dim.data(), | ||
inds_dim.size()); | ||
int64_t *inds_data = ort_.GetTensorMutableData<int64_t>(inds_ov); | ||
|
||
memcpy(dets_data, de, sizeof(T) * nboxes * 5); | ||
memcpy(inds_data, inds, sizeof(int64_t) * nboxes); | ||
} |
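
Two details worth noting: `nboxes` shrinks inside the loop as low-scoring boxes are swapped out, so the output tensors are allocated with the final, possibly smaller count; and the greedy max-score selection makes the kernel O(N²) in the number of boxes, as in the standard Soft-NMS formulation.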
mmcv/ops/csrc/onnxruntime/onnxruntime_register.h
@@ -0,0 +1,15 @@
#ifndef ONNXRUNTIME_REGISTER_H
#define ONNXRUNTIME_REGISTER_H
#include <onnxruntime_c_api.h>

// C linkage so ONNX Runtime can resolve RegisterCustomOps by symbol name
// when the library is loaded at runtime
#ifdef __cplusplus
extern "C" {
#endif

OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
                                          const OrtApiBase *api);

#ifdef __cplusplus
}
#endif
#endif  // ONNXRUNTIME_REGISTER_H
mmcv/ops/csrc/onnxruntime/onnxruntime_session_options_config_keys.h (44 additions & 0 deletions)
@@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifndef ONNXRUNTIME_SESSION_OPTIONS_CONFIG_KEYS_H
#define ONNXRUNTIME_SESSION_OPTIONS_CONFIG_KEYS_H

/*
 * This file defines SessionOptions Config Keys and format of the Config Values.
 *
 * The Naming Convention for a SessionOptions Config Key,
 * "[Area][.[SubArea1].[SubArea2]...].[Keyname]"
 * Such as "ep.cuda.use_arena"
 * The Config Key cannot be empty
 * The maximum length of the Config Key is 128
 *
 * The string format of a SessionOptions Config Value is defined individually
 * for each Config. The maximum length of the Config Value is 1024
 */

// Key for disabling PrePacking.
// If the config value is set to "1" then prepacking is disabled; otherwise
// prepacking is enabled (default value)
static const char* const kOrtSessionOptionsConfigDisablePrepacking =
    "session.disable_prepacking";

// A value of "1" means allocators registered in the env will be used. "0" means
// the allocators created in the session will be used. Use this to override the
// usage of env allocators on a per session level.
static const char* const kOrtSessionOptionsConfigUseEnvAllocators =
    "session.use_env_allocators";

// Set to 'ORT' (case sensitive) to load an ORT format model.
// If unset, model type will default to ONNX unless inferred from filename
// ('.ort' == ORT format) or bytes to be ORT
static const char* const kOrtSessionOptionsConfigLoadModelFormat =
    "session.load_model_format";

// Set to 'ORT' (case sensitive) to save optimized model in ORT format when
// SessionOptions.optimized_model_path is set. If unset, format will default to
// ONNX unless optimized_model_filepath ends in '.ort'.
static const char* const kOrtSessionOptionsConfigSaveModelFormat =
    "session.save_model_format";

#endif  // ONNXRUNTIME_SESSION_OPTIONS_CONFIG_KEYS_H