Skip to content

Commit 52a2adb

Browse files
IvanKobzarevfacebook-github-bot
authored andcommitted
[android] test_app example linking to pytorch_android aar content (pytorch#39587)
Summary: Pull Request resolved: pytorch#39587 Example of using direct linking to pytorch_jni library from aar and updating android/README.md with the tutorial how to do it. Adding `nativeBuild` dimension to `test_app`, using direct aar dependencies, as headers packaging is not landed yet, excluding `nativeBuild` from building by default for CI. Additional change to `scripts/build_pytorch_android.sh`: Skipping clean task here as android gradle plugin 3.3.2 exteralNativeBuild has problems with it when abiFilters are specified. Will be returned back in the following diffs with upgrading of gradle and android gradle plugin versions. Test Plan: Imported from OSS Differential Revision: D22118945 Pulled By: IvanKobzarev fbshipit-source-id: 31c54b49b1f262cbe5f540461d3406f74851db6c
1 parent 954a59a commit 52a2adb

File tree

7 files changed

+332
-4
lines changed

7 files changed

+332
-4
lines changed

android/README.md

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,116 @@ As `pytorch_android` [depends](https://github.com/pytorch/pytorch/blob/master/an
106106

107107
You can check out [test app example](https://github.com/pytorch/pytorch/blob/master/android/test_app/app/build.gradle) that uses aars directly.
108108

109-
## More Details
109+
## Linking to prebuilt libtorch library from gradle dependency
110+
111+
In some cases, you may want to use libtorch from your android native build.
112+
You can do it without building libtorch android, using native libraries from PyTorch android gradle dependency.
113+
For that, you will need to add the next lines to your gradle build.
114+
```
115+
android {
116+
...
117+
configurations {
118+
extractForNativeBuild
119+
}
120+
...
121+
compileOptions {
122+
externalNativeBuild {
123+
cmake {
124+
arguments "-DANDROID_STL=c++_shared"
125+
}
126+
}
127+
}
128+
...
129+
externalNativeBuild {
130+
cmake {
131+
path "CMakeLists.txt"
132+
}
133+
}
134+
}
135+
136+
dependencies {
137+
extractForNativeBuild('org.pytorch:pytorch_android:1.6.0')
138+
}
139+
140+
task extractAARForNativeBuild {
141+
doLast {
142+
configurations.extractForNativeBuild.files.each {
143+
def file = it.absoluteFile
144+
copy {
145+
from zipTree(file)
146+
into "$buildDir/$file.name"
147+
include "headers/**"
148+
include "jni/**"
149+
}
150+
}
151+
}
152+
}
153+
154+
tasks.whenTaskAdded { task ->
155+
if (task.name.contains('externalNativeBuild')) {
156+
task.dependsOn(extractAARForNativeBuild)
157+
}
158+
}
159+
```
160+
161+
pytorch_android aar contains headers to link in `headers` folder and native libraries in `jni/$ANDROID_ABI/`.
162+
As PyTorch native libraries use `ANDROID_STL` - we should use `ANDROID_STL=c++_shared` to have only one loaded binary of STL.
163+
164+
The added task will unpack them to gradle build directory.
165+
166+
In your native build you can link to them adding these lines to your CMakeLists.txt:
167+
168+
169+
```
170+
# Relative path of gradle build directory to CMakeLists.txt
171+
set(build_DIR ${CMAKE_SOURCE_DIR}/build)
172+
173+
file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers")
174+
file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}")
175+
176+
set(BUILD_SUBDIR ${ANDROID_ABI})
177+
target_include_directories(${PROJECT_NAME} PRIVATE
178+
${PYTORCH_INCLUDE_DIRS}
179+
)
180+
181+
find_library(PYTORCH_LIBRARY pytorch_jni
182+
PATHS ${PYTORCH_LINK_DIRS}
183+
NO_CMAKE_FIND_ROOT_PATH)
184+
185+
target_link_libraries(${PROJECT_NAME}
186+
${PYTORCH_LIBRARY})
187+
188+
```
189+
If your CMakeLists.txt file is located in the same directory as your build.gradle, `set(build_DIR ${CMAKE_SOURCE_DIR}/build)` should work for you. But if you have another location of it, you may need to change it.
190+
191+
After that, you can use libtorch C++ API from your native code.
192+
```
193+
#include <string>
194+
#include <ATen/NativeFunctions.h>
195+
#include <torch/script.h>
196+
namespace pytorch_testapp_jni {
197+
namespace {
198+
struct JITCallGuard {
199+
torch::autograd::AutoGradMode no_autograd_guard{false};
200+
torch::AutoNonVariableTypeMode non_var_guard{true};
201+
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
202+
};
203+
}
204+
205+
void loadAndForwardModel(const std::string& modelPath) {
206+
JITCallGuard guard;
207+
torch::jit::Module module = torch::jit::load(modelPath);
208+
module.eval();
209+
torch::Tensor t = torch::randn({1, 3, 224, 224});
210+
c10::IValue t_out = module.forward({t});
211+
}
212+
}
213+
```
214+
215+
To load torchscript model for mobile we need some special setup which is placed in `struct JITCallGuard` in this example. It may change in future, you can track the latest changes keeping an eye in our [pytorch android jni code]([https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp#L28)
216+
217+
[Example of linking to libtorch from aar](https://github.com/pytorch/pytorch/tree/master/android/test_app)
218+
219+
## PyTorch Android API Javadoc
110220

111221
You can find more details about the PyTorch Android API in the [Javadoc](https://pytorch.org/docs/stable/packages.html).
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
cmake_minimum_required(VERSION 3.4.1)
2+
set(PROJECT_NAME pytorch_testapp_jni)
3+
project(${PROJECT_NAME} CXX)
4+
set(CMAKE_CXX_STANDARD 14)
5+
set(CMAKE_VERBOSE_MAKEFILE ON)
6+
7+
set(build_DIR ${CMAKE_SOURCE_DIR}/build)
8+
9+
set(pytorch_testapp_cpp_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp)
10+
message(STATUS "ANDROID_STL:${ANDROID_STL}")
11+
file(GLOB pytorch_testapp_SOURCES
12+
${pytorch_testapp_cpp_DIR}/pytorch_testapp_jni.cpp
13+
)
14+
15+
add_library(${PROJECT_NAME} SHARED
16+
${pytorch_testapp_SOURCES}
17+
)
18+
19+
file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers")
20+
file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}")
21+
22+
target_compile_options(${PROJECT_NAME} PRIVATE
23+
-fexceptions
24+
)
25+
26+
set(BUILD_SUBDIR ${ANDROID_ABI})
27+
28+
target_include_directories(${PROJECT_NAME} PRIVATE
29+
${PYTORCH_INCLUDE_DIRS}
30+
)
31+
32+
find_library(PYTORCH_LIBRARY pytorch_jni
33+
PATHS ${PYTORCH_LINK_DIRS}
34+
NO_CMAKE_FIND_ROOT_PATH)
35+
36+
target_link_libraries(${PROJECT_NAME}
37+
${PYTORCH_LIBRARY}
38+
log)

android/test_app/app/build.gradle

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ repositories {
1111
}
1212

1313
android {
14+
configurations {
15+
extractForNativeBuild
16+
}
1417
compileOptions {
1518
sourceCompatibility 1.8
1619
targetCompatibility 1.8
@@ -28,17 +31,28 @@ android {
2831
}
2932
externalNativeBuild {
3033
cmake {
34+
abiFilters ABI_FILTERS.split(",")
3135
arguments "-DANDROID_STL=c++_shared"
3236
}
3337
}
3438
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet2q.pt\"")
3539
buildConfigField("String", "LOGCAT_TAG", "@string/app_name")
3640
buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{1, 3, 224, 224}")
41+
buildConfigField("boolean", "NATIVE_BUILD", 'false')
3742
addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"])
3843
}
3944
buildTypes {
4045
debug {
4146
minifyEnabled false
47+
debuggable true
48+
}
49+
release {
50+
minifyEnabled false
51+
}
52+
}
53+
externalNativeBuild {
54+
cmake {
55+
path "CMakeLists.txt"
4256
}
4357
}
4458
flavorDimensions "model", "build", "activity"
@@ -66,6 +80,10 @@ android {
6680
aar {
6781
dimension "build"
6882
}
83+
nativeBuild {
84+
dimension "build"
85+
buildConfigField("boolean", "NATIVE_BUILD", "true")
86+
}
6987
camera {
7088
dimension "activity"
7189
addManifestPlaceholders([MAIN_ACTIVITY: "org.pytorch.testapp.CameraActivity"])
@@ -79,7 +97,6 @@ android {
7997
}
8098
}
8199
}
82-
83100
}
84101
}
85102
packagingOptions {
@@ -92,7 +109,8 @@ android {
92109
def names = variant.flavors*.name
93110
if (names.contains("nightly")
94111
|| names.contains("camera")
95-
|| names.contains("aar")) {
112+
|| names.contains("aar")
113+
|| names.contains("nativeBuild")) {
96114
setIgnore(true)
97115
}
98116
}
@@ -101,9 +119,16 @@ android {
101119

102120
dependencies {
103121
implementation 'com.android.support:appcompat-v7:28.0.0'
122+
implementation 'com.facebook.soloader:nativeloader:0.8.0'
104123

105124
localImplementation project(':pytorch_android')
106125
localImplementation project(':pytorch_android_torchvision')
126+
127+
nativeBuildImplementation(name: 'pytorch_android-release', ext: 'aar')
128+
nativeBuildImplementation(name: 'pytorch_android_torchvision-release', ext: 'aar')
129+
// Commented due to dependency on local copy of pytorch_android aar to aars folder
130+
//extractForNativeBuild(name: 'pytorch_android-release', ext: 'aar')
131+
107132
nightlyImplementation 'org.pytorch:pytorch_android:1.6.0-SNAPSHOT'
108133
nightlyImplementation 'org.pytorch:pytorch_android_torchvision:1.6.0-SNAPSHOT'
109134

@@ -117,3 +142,23 @@ dependencies {
117142
cameraImplementation "androidx.camera:camera-camera2:$camerax_version"
118143
cameraImplementation 'com.google.android.material:material:1.0.0-beta01'
119144
}
145+
146+
task extractAARForNativeBuild {
147+
doLast {
148+
configurations.extractForNativeBuild.files.each {
149+
def file = it.absoluteFile
150+
copy {
151+
from zipTree(file)
152+
into "$buildDir/$file.name"
153+
include "headers/**"
154+
include "jni/**"
155+
}
156+
}
157+
}
158+
}
159+
160+
tasks.whenTaskAdded { task ->
161+
if (task.name.contains('externalNativeBuild')) {
162+
task.dependsOn(extractAARForNativeBuild)
163+
}
164+
}
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
#include <android/log.h>
2+
#include <pthread.h>
3+
#include <unistd.h>
4+
#include <cassert>
5+
#include <cmath>
6+
#include <vector>
7+
#define ALOGI(...) \
8+
__android_log_print(ANDROID_LOG_INFO, "PyTorchTestAppJni", __VA_ARGS__)
9+
#define ALOGE(...) \
10+
__android_log_print(ANDROID_LOG_ERROR, "PyTorchTestAppJni", __VA_ARGS__)
11+
12+
#include "jni.h"
13+
14+
#include <torch/script.h>
15+
16+
namespace pytorch_testapp_jni {
17+
namespace {
18+
19+
template <typename T>
20+
void log(const char* m, T t) {
21+
std::ostringstream os;
22+
os << t << std::endl;
23+
ALOGI("%s %s", m, os.str().c_str());
24+
}
25+
26+
struct JITCallGuard {
27+
torch::autograd::AutoGradMode no_autograd_guard{false};
28+
torch::AutoNonVariableTypeMode non_var_guard{true};
29+
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
30+
};
31+
} // namespace
32+
33+
static void loadAndForwardModel(JNIEnv* env, jclass, jstring jModelPath) {
34+
const char* modelPath = env->GetStringUTFChars(jModelPath, 0);
35+
assert(modelPath);
36+
37+
// To load torchscript model for mobile we need set these guards,
38+
// because mobile build doesn't support features like autograd for smaller
39+
// build size which is placed in `struct JITCallGuard` in this example. It may
40+
// change in future, you can track the latest changes keeping an eye in
41+
// android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp
42+
JITCallGuard guard;
43+
torch::jit::Module module = torch::jit::load(modelPath);
44+
module.eval();
45+
torch::Tensor t = torch::randn({1, 3, 224, 224});
46+
log("input tensor:", t);
47+
c10::IValue t_out = module.forward({t});
48+
log("output tensor:", t_out);
49+
env->ReleaseStringUTFChars(jModelPath, modelPath);
50+
}
51+
} // namespace pytorch_testapp_jni
52+
53+
JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) {
54+
JNIEnv* env;
55+
if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) != JNI_OK) {
56+
return JNI_ERR;
57+
}
58+
59+
jclass c =
60+
env->FindClass("org/pytorch/testapp/LibtorchNativeClient$NativePeer");
61+
if (c == nullptr) {
62+
return JNI_ERR;
63+
}
64+
65+
static const JNINativeMethod methods[] = {
66+
{"loadAndForwardModel",
67+
"(Ljava/lang/String;)V",
68+
(void*)pytorch_testapp_jni::loadAndForwardModel},
69+
};
70+
int rc = env->RegisterNatives(
71+
c, methods, sizeof(methods) / sizeof(JNINativeMethod));
72+
73+
if (rc != JNI_OK) {
74+
return rc;
75+
}
76+
77+
return JNI_VERSION_1_6;
78+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package org.pytorch.testapp;
2+
import com.facebook.soloader.nativeloader.NativeLoader;
3+
import com.facebook.soloader.nativeloader.SystemDelegate;
4+
5+
public final class LibtorchNativeClient {
6+
7+
public static void loadAndForwardModel(final String modelPath) {
8+
NativePeer.loadAndForwardModel(modelPath);
9+
}
10+
11+
private static class NativePeer {
12+
static {
13+
if (!NativeLoader.isInitialized()) {
14+
NativeLoader.init(new SystemDelegate());
15+
}
16+
NativeLoader.loadLibrary("pytorch_testapp_jni");
17+
}
18+
19+
private static native void loadAndForwardModel(final String modelPath);
20+
}
21+
}

0 commit comments

Comments
 (0)