Allow use of ReversePackedSegs operator in CUDA context
Summary: Register the ReversePackedSegs operator for CUDA. The "lengths" input (static integers) is required to be in CPU memory.

Differential Revision: D4661281

fbshipit-source-id: c800c316c34015ba8e732dcbcaa8c4edaffdfeab
jhcross authored and facebook-github-bot committed Mar 9, 2017
1 parent 31aa217 commit 83b76f7
Showing 4 changed files with 29 additions and 8 deletions.
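For orientation: the operator reverses each sequence in a packed (max_length, batch_size, dim) tensor along the time axis, up to that sequence's length from the "lengths" input, leaving padding in place. A minimal NumPy sketch of those semantics, assuming that layout (the helper name is made up for illustration and is not code from this commit):

import numpy as np

def reverse_packed_segs_ref(data, lengths):
    # data:    (max_length, batch_size, dim) packed sequences
    # lengths: (batch_size,) integers, one length per sequence
    reversed_data = data.copy()
    for b, seg_length in enumerate(lengths):
        # Reverse only the first seg_length steps of sequence b;
        # entries beyond the length (padding) are left untouched.
        reversed_data[:seg_length, b, :] = data[:seg_length, b, :][::-1]
    return reversed_data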
9 changes: 9 additions & 0 deletions caffe2/operators/reverse_packed_segs_op.cu
@@ -0,0 +1,9 @@
#include "caffe2/core/context_gpu.h"
#include "reverse_packed_segs_op.h"

namespace caffe2 {
namespace {
REGISTER_CUDA_OPERATOR(ReversePackedSegs, ReversePackedSegsOp<CUDAContext>);

} // namespace
} // namespace caffe2
6 changes: 4 additions & 2 deletions caffe2/operators/reverse_packed_segs_op.h
@@ -20,7 +20,8 @@ class ReversePackedSegsOp final : public Operator<Context> {

template <typename T>
bool DoRunWithType() {
- if (Input(LENGTHS).template IsType<int>()) {
+ if (OperatorBase::Input<Tensor<CPUContext>>(LENGTHS)
+         .template IsType<int>()) {
DoRunWithLengthType<T, int>();
} else {
DoRunWithLengthType<T, long>();
@@ -34,7 +35,7 @@ class ReversePackedSegsOp final : public Operator<Context> {
template <typename T, typename LengthType>
void DoRunWithLengthType() {
const auto& data = Input(DATA);
- const auto& lengths = Input(LENGTHS);
+ const auto& lengths = OperatorBase::Input<Tensor<CPUContext>>(LENGTHS);

CAFFE_ENFORCE(
data.ndim() == 3,
@@ -56,6 +57,7 @@ class ReversePackedSegsOp final : public Operator<Context> {

const T* data_ptr = data.template data<T>();
const LengthType* lengths_ptr = lengths.template data<LengthType>();

T* rev_data_ptr = output->template mutable_data<T>();
for (TIndex i = 0; i < batch_size; i++) {
const auto& seg_length = lengths_ptr[i];
14 changes: 11 additions & 3 deletions caffe2/python/hypothesis_test_util.py
@@ -244,9 +244,11 @@ def runOpBenchmark(
device_option,
op,
inputs,
- input_device_options={},
+ input_device_options=None,
iterations=10,
):
+ if input_device_options is None:
+     input_device_options = {}
op = copy.deepcopy(op)
op.device_option.CopyFrom(device_option)
net = caffe2_pb2.NetDef()
@@ -445,7 +447,7 @@ def assertReferenceChecks(
op,
inputs,
reference,
- input_device_options={},
+ input_device_options=None,
threshold=1e-4,
output_to_grad=None,
grad_reference=None,
Expand Down Expand Up @@ -473,6 +475,9 @@ def softsign(X):
self.assertReferenceChecks(gc, op, [X], softsign)
"""
+ if input_device_options is None:
+     input_device_options = {}

op = copy.deepcopy(op)
op.device_option.CopyFrom(device_option)

@@ -483,6 +488,7 @@ def softsign(X):
b,
device_option=input_device_options.get(n, device_option)
)
print("Input", n, input_device_options.get(n, device_option))
net = core.Net("opnet")
net.Proto().op.extend([op])
test_shape_inference = False
Expand Down Expand Up @@ -538,9 +544,11 @@ def assertValidationChecks(
op,
inputs,
validator,
- input_device_options={},
+ input_device_options=None,
as_kwargs=True
):
+ if input_device_options is None:
+     input_device_options = {}
if as_kwargs:
assert len(set(list(op.input) + list(op.output))) == \
len(op.input) + len(op.output), \
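The hypothesis_test_util.py edits above also replace the mutable default argument input_device_options={} with None plus an explicit guard. A small self-contained sketch of why the original pattern is a hazard (function names are illustrative only):

def count_calls_bad(options={}):
    # The default dict is created once and shared across calls,
    # so state from one call leaks into the next.
    options["calls"] = options.get("calls", 0) + 1
    return options["calls"]

def count_calls_good(options=None):
    # Each call that omits the argument gets a fresh dict.
    if options is None:
        options = {}
    options["calls"] = options.get("calls", 0) + 1
    return options["calls"]

print(count_calls_bad(), count_calls_bad())    # 1 2 -- shared default dict
print(count_calls_good(), count_calls_good())  # 1 1 -- fresh dict per call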
8 changes: 5 additions & 3 deletions caffe2/python/operator_test/sequence_ops_test.py
@@ -5,6 +5,7 @@
from caffe2.python import core
from hypothesis import given
import caffe2.python.hypothesis_test_util as hu
+ from caffe2.proto import caffe2_pb2
import hypothesis.strategies as st
import numpy as np
from functools import partial
@@ -84,11 +85,11 @@ def _gather_padding_ref(start_pad_width, end_pad_width, data, lengths):
pad_width = start_pad_width + end_pad_width
ptr = 0
for length in lengths:
- for i in range(start_pad_width):
+ for _ in range(start_pad_width):
start_padding += data[ptr]
ptr += 1
ptr += length - pad_width
- for i in range(end_pad_width):
+ for _ in range(end_pad_width):
end_padding += data[ptr]
ptr += 1
return (start_padding, end_padding)
@@ -190,7 +191,7 @@ def test_gather_padding(self, start_pad_width, end_pad_width, args):
elements=st.floats(min_value=-np.inf,
max_value=np.inf),
min_value=1, max_value=10),
- **hu.gcs_cpu_only)
+ **hu.gcs)
def test_reverse_packed_segs(self, data, gc, dc):
max_length = data.shape[0]
batch_size = data.shape[1]
@@ -217,6 +218,7 @@ def op_grad_ref(grad_out, outputs, inputs):
op=op,
inputs=[data, lengths],
reference=op_ref,
input_device_options={"lengths": core.DeviceOption(caffe2_pb2.CPU)},
output_to_grad='reversed_data',
grad_reference=op_grad_ref)

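As the test change above shows, when the operator runs under CUDA the "lengths" blob is still fed from CPU memory. A hedged end-to-end sketch of what that looks like from the Python API (blob names are made up; assumes a CUDA-enabled Caffe2 build):

import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace

gpu = core.DeviceOption(caffe2_pb2.CUDA, 0)
cpu = core.DeviceOption(caffe2_pb2.CPU)

data = np.random.rand(5, 3, 4).astype(np.float32)  # (max_length, batch, dim)
lengths = np.array([5, 3, 2], dtype=np.int32)       # one length per batch item

# The data tensor lives on the GPU; lengths must stay in CPU memory.
workspace.FeedBlob("data", data, device_option=gpu)
workspace.FeedBlob("lengths", lengths, device_option=cpu)

op = core.CreateOperator(
    "ReversePackedSegs",
    ["data", "lengths"],
    ["reversed_data"],
    device_option=gpu,
)
workspace.RunOperatorOnce(op)
reversed_data = workspace.FetchBlob("reversed_data")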
