Skip to content

Commit

Permalink
Added methods to fill arbitrarily shaped tensors from java runtime en…
Browse files Browse the repository at this point in the history
…vironment
  • Loading branch information
soyers committed Oct 23, 2016
1 parent ec7f37e commit a775c53
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,14 @@ public TensorFlowInferenceInterface() {
public native void close();

// Methods for creating a native Tensor and filling it with values.
public native void fillNodeFloatWithDimensions(String inputName, int[] dims, float[] values);
public native void fillNodeIntWithDimensions(String inputName, int[] dims, int[] values);
public native void fillNodeDoubleWithDimensions(String inputName, int[] dims, double[] values);
@Deprecated
public native void fillNodeFloat(String inputName, int x, int y, int z, int d, float[] values);
@Deprecated
public native void fillNodeInt(String inputName, int x, int y, int z, int d, int[] values);
@Deprecated
public native void fillNodeDouble(String inputName, int x, int y, int z, int d, double[] values);

public native void readNodeFloat(String outputName, float[] values);
Expand Down
24 changes: 22 additions & 2 deletions tensorflow/contrib/android/jni/tensorflow_inference_jni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
Expand Down Expand Up @@ -221,11 +222,30 @@ JNIEXPORT jint JNICALL TENSORFLOW_METHOD(close)(JNIEnv* env, jobject thiz) {
}

// TODO(andrewharp): Use memcpy to fill/read nodes.
#define FILL_NODE_METHOD_DEPRECATED(DTYPE, JAVA_DTYPE, TENSOR_DTYPE) \
FILL_NODE_SIGNATURE_DEPRECATED(DTYPE, JAVA_DTYPE) { \
jintArray dimArray = env->NewIntArray(4); \
jint *dimArray_ptr = env->GetIntArrayElements(dimArray, NULL); \
dimArray_ptr[0] = x; dimArray_ptr[1] = y; \
dimArray_ptr[2] = z; dimArray_ptr[3] = d; \
TENSORFLOW_METHOD(fillNode##DTYPE##WithDimensions)( \
env, thiz, node_name, dimArray, arr); \
}

#define FILL_NODE_METHOD(DTYPE, JAVA_DTYPE, TENSOR_DTYPE) \
FILL_NODE_SIGNATURE(DTYPE, JAVA_DTYPE) { \
jint *nativeDims = env->GetIntArrayElements(dims, 0); \
jsize size = env->GetArrayLength(dims); \
int64 castDims[size]; \
for (int i = 0; i < size; ++i) { \
castDims[i] = static_cast<int64>(nativeDims[i]); \
} \
SessionVariables* vars = GetSessionVars(env, thiz); \
tensorflow::Tensor input_tensor(TENSOR_DTYPE, \
tensorflow::TensorShape({x, y, z, d})); \
tensorflow::Tensor input_tensor( \
TENSOR_DTYPE, \
tensorflow::TensorShape( \
gtl::ArraySlice<int64>(castDims, \
size))); \
auto tensor_mapped = input_tensor.flat<JAVA_DTYPE>(); \
jboolean iCopied = JNI_FALSE; \
j##JAVA_DTYPE* values = env->Get##DTYPE##ArrayElements(arr, &iCopied); \
Expand Down
12 changes: 10 additions & 2 deletions tensorflow/contrib/android/jni/tensorflow_inference_jni.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,15 @@ extern "C" {
#define TENSORFLOW_METHOD(METHOD_NAME) \
Java_org_tensorflow_contrib_android_TensorFlowInferenceInterface_##METHOD_NAME // NOLINT

#define FILL_NODE_SIGNATURE(DTYPE, JAVA_DTYPE) \
#define FILL_NODE_SIGNATURE_DEPRECATED(DTYPE, JAVA_DTYPE) \
JNIEXPORT void TENSORFLOW_METHOD(fillNode##DTYPE)( \
JNIEnv * env, jobject thiz, jstring node_name, jint x, jint y, jint z, \
jint d, j##JAVA_DTYPE##Array arr)
jint d, j##JAVA_DTYPE##Array arr) \

#define FILL_NODE_SIGNATURE(DTYPE, JAVA_DTYPE) \
JNIEXPORT void TENSORFLOW_METHOD(fillNode##DTYPE##WithDimensions)( \
JNIEnv * env, jobject thiz, jstring node_name, jintArray dims, \
j##JAVA_DTYPE##Array arr)

#define READ_NODE_SIGNATURE(DTYPE, JAVA_DTYPE) \
JNIEXPORT jint TENSORFLOW_METHOD(readNode##DTYPE)( \
Expand All @@ -48,6 +53,9 @@ JNIEXPORT jint JNICALL TENSORFLOW_METHOD(runInference)(

JNIEXPORT jint JNICALL TENSORFLOW_METHOD(close)(JNIEnv* env, jobject thiz);

FILL_NODE_SIGNATURE_DEPRECATED(Float, float);
FILL_NODE_SIGNATURE_DEPRECATED(Int, int);
FILL_NODE_SIGNATURE_DEPRECATED(Double, double);
FILL_NODE_SIGNATURE(Float, float);
FILL_NODE_SIGNATURE(Int, int);
FILL_NODE_SIGNATURE(Double, double);
Expand Down

0 comments on commit a775c53

Please sign in to comment.