Enable Detectron model inference for CPU and MKL-DNN paths (pytorch#10157)

Summary:
1. Support the ops needed for Faster-RCNN/Mask-RCNN inference in Detectron, mostly as direct CPU fallbacks.
2. Use the CPU device to hold 0-dim tensors and integer tensors in both the fallback op and the blob feeder, as required by Detectron models.
3. Ignore 0-dim tensors in the MKL-DNN concat operator.
4. Generate a dynamic library of the Detectron module for the CPU device.
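
Taken together, these changes let a Detectron-style net keep a single IDEEP device option: ops with MKL-DNN kernels run natively, while the rest fall back to their CPU implementations. A minimal sketch of that path, assuming a Caffe2 build with IDEEP enabled (blob names and op choices here are illustrative, not taken from this PR):

import numpy as np
from caffe2.python import core, workspace
from caffe2.proto import caffe2_pb2

ideep = core.DeviceOption(caffe2_pb2.IDEEP)
workspace.FeedBlob('X', np.random.rand(1, 3, 8, 8).astype(np.float32), ideep)
workspace.FeedBlob('w', np.random.rand(4, 3, 3, 3).astype(np.float32), ideep)
# Conv has a native MKL-DNN kernel, so it runs on the IDEEP device...
workspace.RunOperatorOnce(core.CreateOperator(
    'Conv', ['X', 'w'], ['Y'], kernel=3, device_option=ideep))
# ...while StopGradient is one of the Detectron ops this PR routes
# through IDEEPFallbackOp onto the existing CPU kernel.
workspace.RunOperatorOnce(core.CreateOperator(
    'StopGradient', ['Y'], ['Y_out'], device_option=ideep))
print(workspace.FetchBlob('Y_out').shape)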

This PR obsoletes pytorch#9164.
Pull Request resolved: pytorch#10157

Differential Revision: D9276837

Pulled By: yinghai

fbshipit-source-id: dc364932ae4a2e7fcefdee70b5fce3c0cee91b6f
jgong5 authored and facebook-github-bot committed Aug 29, 2018
1 parent 89834df commit c755616
Showing 9 changed files with 299 additions and 93 deletions.
12 changes: 10 additions & 2 deletions caffe2/ideep/operators/concat_split_op.cc

@@ -25,13 +25,21 @@ class IDEEPConcatOp final : public IDEEPOperator {
   virtual ~IDEEPConcatOp() {}
 
   bool RunOnDevice() override {
     const auto& input_zero = Input(INPUT0);
     auto* output = Output(OUTPUT);
     TensorCPU* axis_info = OperatorBase::Output<TensorCPU>(AXIS_INFO, CPU);
 
     vector<itensor> inputs;
     for (int i = 0; i < InputSize(); ++i) {
-      inputs.emplace_back(Input(i));
+      if (OperatorBase::InputBlob(i).template IsType<itensor>()) {
+        inputs.emplace_back(Input(i));
+      } else {
+        CAFFE_ENFORCE(OperatorBase::InputBlob(i).IsType<Tensor>(CPU),
+            "Expect cpu tensor if not itensor");
+        auto& tensor_cpu = OperatorBase::Input<Tensor>(i, CPU);
+        CAFFE_ENFORCE(tensor_cpu.dims().size() == 0 ||
+            tensor_cpu.size_from_dim(0) == 0,
+            "Expect zero dim tensor");
+      }
     }
 
     auto axis_vdata = ideep::concat::compute(inputs, axis_, add_axis_, *output);
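
This hunk implements summary point 3: a Concat input that is not an itensor must be an empty (or 0-dim) CPU tensor, and it is simply skipped. In Detectron this arises, for example, when an FPN level receives no proposals and its RoI blob is empty. A hypothetical repro, assuming an IDEEP-enabled build (blob names invented for illustration):

import numpy as np
from caffe2.python import core, workspace
from caffe2.proto import caffe2_pb2

ideep = core.DeviceOption(caffe2_pb2.IDEEP)
cpu = core.DeviceOption(caffe2_pb2.CPU)
workspace.FeedBlob('A', np.random.rand(2, 3, 4, 4).astype(np.float32), ideep)
# Zero-element tensors are kept on the CPU device (summary point 2) and are
# now ignored by the IDEEP Concat instead of tripping a type check.
workspace.FeedBlob('B', np.zeros((0,), dtype=np.float32), cpu)
workspace.RunOperatorOnce(core.CreateOperator(
    'Concat', ['A', 'B'], ['C', 'split_info'], axis=1, device_option=ideep))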
8 changes: 8 additions & 0 deletions caffe2/ideep/operators/operator_fallback_ideep.cc

@@ -32,6 +32,8 @@
 #include <caffe2/operators/tanh_op.h>
 #include <caffe2/operators/transpose_op.h>
 #include <caffe2/operators/utility_ops.h>
+#include <caffe2/operators/affine_channel_op.h>
+#include <caffe2/operators/stop_gradient.h>
 #include <caffe2/sgd/adam_op.h>
 #include <caffe2/sgd/iter_op.h>
 #include <caffe2/sgd/learning_rate_op.h>

@@ -116,6 +118,12 @@
 REGISTER_IDEEP_OPERATOR(
     BBoxTransform,
     IDEEPFallbackOp<BBoxTransformOp<float, CPUContext>>);
+REGISTER_IDEEP_OPERATOR(
+    AffineChannel,
+    IDEEPFallbackOp<AffineChannelOp<float, CPUContext>>);
+REGISTER_IDEEP_OPERATOR(
+    StopGradient,
+    IDEEPFallbackOp<StopGradientOp<CPUContext>>);
 
 REGISTER_IDEEP_OPERATOR(
     PadImage,
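
The two new registrations follow the pattern of the existing BBoxTransform fallback: the CPU operator class is wrapped in IDEEPFallbackOp so that nets built for the IDEEP device can use these ops unchanged. A small usage sketch, assuming an IDEEP-enabled build (shapes are illustrative; the new test at the bottom of this commit exercises the same path):

import numpy as np
from caffe2.python import core, workspace
from caffe2.proto import caffe2_pb2

ideep = core.DeviceOption(caffe2_pb2.IDEEP)
workspace.FeedBlob('Y', np.random.rand(1, 4, 8, 8).astype(np.float32), ideep)
workspace.FeedBlob('scale', np.ones(4, dtype=np.float32), ideep)
workspace.FeedBlob('bias', np.zeros(4, dtype=np.float32), ideep)
# AffineChannel has no MKL-DNN kernel; with this registration it runs via
# IDEEPFallbackOp<AffineChannelOp<float, CPUContext>> on the CPU.
workspace.RunOperatorOnce(core.CreateOperator(
    'AffineChannel', ['Y', 'scale', 'bias'], ['Y'],
    is_learnable=False, device_option=ideep))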
63 changes: 35 additions & 28 deletions caffe2/ideep/operators/operator_fallback_ideep.h

@@ -53,54 +53,59 @@ class IDEEPFallbackOp final : public IDEEPOperator {
     // then forward output blobs to local workspace.
     std::unordered_map<string, string> forwarded_output_blobs;
     for (int i = 0; i < base_def_.output_size(); i++) {
+      // For in-place case, the in/output tensor for local_ws must be
+      // re-created, instead of forwarding from current workspace.
       string parent_name(base_def_.output(i));
+      if (!SkipOutputCopy::Contains(i)) {
         parent_name += "_cpu_output_blob_" + base_def_.type();
+      }
       local_output_blobs_.push_back(ws->CreateBlob(parent_name));
       CHECK_NOTNULL(local_output_blobs_.back());
       forwarded_output_blobs[base_def_.output(i)] = parent_name;
+      output_inplace_.push_back(false);
+      for (const string &input_name : base_def_.input()) {
+        if (input_name == base_def_.output(i)) {
+          output_inplace_[i] = true;
+          break;
+        }
+      }
     }
     local_ws_.reset(new Workspace(ws, forwarded_output_blobs));
     // Set up the symbols for the local workspace.
     for (const string& name : base_def_.input()) {
       local_input_blobs_.push_back(local_ws_->CreateBlob(name));
       CHECK_NOTNULL(local_input_blobs_.back());
     }
+    input_share_.resize(local_input_blobs_.size(), false);
     base_op_.reset(new CPUOp(base_def_, local_ws_.get()));
   }
 
   bool RunOnDevice() override {
     for (int i = 0; i < InputSize(); ++i) {
-      if (InputIsType<itensor>(i) && Input(i).get_data_type() == itensor::data_type::f32) {
+      if (InputIsType<itensor>(i) &&
+          Input(i).get_data_type() == itensor::data_type::f32) {
         auto& input = Input(i);
-        auto dtensor = local_input_blobs_[i]->GetMutableTensor(CPU);
-        dtensor->Resize(input.get_dims());
-        if (input.is_public_format()) {
-          dtensor->ShareExternalPointer(static_cast<float*>(input.get_data_handle()));
-        } else {
-          input.reorder_to(dtensor->template mutable_data<float>());
+        if (input_share_[i]) {
+          local_input_blobs_[i]->Reset();
         }
-      } else if (
-          InputIsType<itensor>(i) &&
-          Input(i).get_data_type() == itensor::data_type::s32) {
-        auto& input = Input(i);
+        input_share_[i] = false;
         auto dtensor = local_input_blobs_[i]->GetMutableTensor(CPU);
         dtensor->Resize(input.get_dims());
         if (input.is_public_format()) {
           dtensor->ShareExternalPointer(
-              static_cast<long*>(input.get_data_handle()));
+              static_cast<float*>(input.get_data_handle()));
         } else {
-          input.reorder_to(dtensor->template mutable_data<long>());
+          input.reorder_to(dtensor->template mutable_data<float>());
         }
       } else {
         VLOG(1) << "Input " << i << " is not ideep::tensor. Skipping copy.";
         // Note(jiayq): This removes a const but conceptually
         // local_input_blobs will only be used as const blob input for the
         // base op so we are still fine.
         local_input_blobs_[i]->ShareExternal(
-            const_cast<void*>(OperatorBase::Inputs()[i]->GetRaw()),
+            const_cast<void *>(OperatorBase::Inputs()[i]->GetRaw()),
             OperatorBase::Inputs()[i]->meta());
+        input_share_[i] = true;
       }
     }
 
@@ -120,21 +125,16 @@ class IDEEPFallbackOp final : public IDEEPOperator {
           "IDEEP fallback op currently does not support non-TensorCPU "
           "output type who needs copying.");
       const auto& src = local_output_blobs_[i]->template Get<TensorCPU>();
-
-      auto src_dims = src.dims();
-      if (src.ndim() == 0) {
-        VLOG(1) << "Copy output: index " << i << " skipped.";
+      if (src.template IsType<float>() &&
+          src.dims().size() != 0 && src.size_from_dim(0) != 0 &&
+          base_op_->type() != "Python") {
         Blob* dst = OperatorBase::OutputBlob(i);
-        dst->Reset(new Tensor(CPU));
-        auto dtensor = dst->GetMutableTensor(CPU);
-        dtensor->Resize(src_dims);
-        dtensor->ShareData(src);
-        continue;
-      }
-
-      if (src.template IsType<float>()) {
-        Blob* dst = OperatorBase::OutputBlob(i);
-        if (!dst->template IsType<itensor>()) {
+        // The output tensor must be ideep tensor with public format.
+        // If reusing ideep tensor with non-public format, the tensor buffer
+        // will be interpreted incorrectly.
+        if (!dst->template IsType<itensor>() ||
+            !dst->template Get<itensor>().is_public_format()) {
           dst->Reset(new itensor());
         }
 
@@ -143,7 +143,12 @@ class IDEEPFallbackOp final : public IDEEPOperator {
       if (dtensor->get_dims() != dst_dims) {
         dtensor->resize(dst_dims, itensor::data_type::f32);
       }
-      dtensor->set_data_handle(const_cast<void*>(src.raw_data()));
+      if (output_inplace_[i]) {
+        dtensor->reorder_from(dst_dims, itensor::data_type::f32,
+            const_cast<void*>(src.raw_data()));
+      } else {
+        dtensor->set_data_handle(const_cast<void *>(src.raw_data()));
+      }
     } else {
       VLOG(2) << "Output " << base_def_.output(i) << " as CPUTensor";
       Blob* dst = OperatorBase::OutputBlob(i);
 
@@ -159,6 +164,8 @@ class IDEEPFallbackOp final : public IDEEPOperator {
 protected:
   vector<Blob*> local_input_blobs_;
   vector<Blob*> local_output_blobs_;
+  vector<bool> output_inplace_;
+  vector<bool> input_share_;
   std::unique_ptr<CPUOp> base_op_;
   std::unique_ptr<Workspace> local_ws_;
   OperatorDef base_def_;
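
The in-place machinery is the heart of this header change: at construction time an output is flagged as in-place when its name also appears among the inputs, and at run time such an output is copied back into the ideep tensor with reorder_from rather than aliased with set_data_handle, presumably because, as the comment in the constructor notes, the local CPU blob backing an in-place output is re-created rather than forwarded, so its buffer cannot safely be shared. A rough Python analogue of the detection step (not part of this PR, for illustration only):

from caffe2.proto import caffe2_pb2

def output_inplace_flags(op_def):
    # Mirrors the constructor loop: output i is in-place iff its name
    # matches some input name of the wrapped OperatorDef.
    inputs = set(op_def.input)
    return [out in inputs for out in op_def.output]

op = caffe2_pb2.OperatorDef()
op.type = 'AffineChannel'
op.input.extend(['Y', 'scale', 'bias'])
op.output.extend(['Y'])
print(output_inplace_flags(op))  # [True]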
99 changes: 99 additions & 0 deletions caffe2/python/ideep/operator_fallback_op_test.py

@@ -0,0 +1,99 @@ (new file; every line below is an addition)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import unittest
import hypothesis.strategies as st
from hypothesis import given
import numpy as np
from caffe2.python import core, workspace
from caffe2.proto import caffe2_pb2
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.ideep_test_util as mu


@unittest.skipIf(not workspace.C.use_ideep, "No IDEEP support.")
class TestFallbackOps(hu.HypothesisTestCase):
@given(stride=st.integers(1, 3),
pad=st.integers(0, 3),
kernel=st.integers(3, 5),
size=st.integers(8, 10),
input_channels=st.integers(1, 3),
output_channels=st.integers(1, 5),
batch_size=st.integers(1, 3),
use_bias=st.booleans(),
**mu.gcs)
def test_in_place(self, stride, pad, kernel, size,
input_channels, output_channels,
batch_size, use_bias, gc, dc):
# To expose fallback in-place potential issue, the fallback op
# following ideep op must be run at least two iterations.
conv = core.CreateOperator(
"Conv",
["X", "w", "b"] if use_bias else ["X", "w"],
["Y"],
stride=stride,
pad=pad,
kernel=kernel,
device_option=dc[0]
)
X = np.random.rand(
batch_size, input_channels, size, size).astype(np.float32) - 0.5
w = np.random.rand(output_channels, input_channels, kernel, kernel) \
.astype(np.float32) - 0.5
b = np.random.rand(output_channels).astype(np.float32) - 0.5

old_ws_name = workspace.CurrentWorkspace()
workspace.SwitchWorkspace("_device_check_", True)
workspace.FeedBlob('X', X, dc[0])
workspace.FeedBlob('w', w, dc[0])
workspace.FeedBlob('b', b, dc[0])
workspace.RunOperatorOnce(conv)
Y = workspace.FetchBlob('Y')

scale = np.random.randn(Y.shape[1]).astype(np.float32)
bias = np.random.randn(Y.shape[1]).astype(np.float32)
ac = core.CreateOperator(
"AffineChannel",
["Y", "scale", "bias"],
["Y"],
is_learnable=False,
device_option=dc[0]
)
workspace.FeedBlob('scale', scale, dc[0])
workspace.FeedBlob('bias', bias, dc[0])
workspace.RunOperatorOnce(ac)
workspace.RunOperatorOnce(conv)
workspace.RunOperatorOnce(ac)
Y0 = workspace.FetchBlob('Y')

workspace.ResetWorkspace()
dev_net = caffe2_pb2.NetDef()
conv_dev = caffe2_pb2.OperatorDef()
conv_dev.CopyFrom(conv)
conv_dev.device_option.CopyFrom(dc[1])
ac_dev = caffe2_pb2.OperatorDef()
ac_dev.CopyFrom(ac)
ac_dev.device_option.CopyFrom(dc[1])
dev_net.op.extend([conv_dev, ac_dev])
workspace.FeedBlob('X', X, dc[1])
workspace.FeedBlob('w', w, dc[1])
workspace.FeedBlob('b', b, dc[1])
workspace.FeedBlob('scale', scale, dc[1])
workspace.FeedBlob('bias', bias, dc[1])
workspace.RunNetOnce(dev_net)
workspace.RunNetOnce(dev_net)
Y1 = workspace.FetchBlob('Y')

if not np.allclose(Y0, Y1, atol=0.01, rtol=0.01):
print(Y1.flatten())
print(Y0.flatten())
print(np.max(np.abs(Y1 - Y0)))
self.assertTrue(False)

workspace.SwitchWorkspace(old_ws_name)


if __name__ == "__main__":
unittest.main()
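
Assuming a Caffe2 build with IDEEP support (e.g. USE_IDEEP=1), the test can be run directly with "python caffe2/python/ideep/operator_fallback_op_test.py". Note that it deliberately runs the Conv/AffineChannel pair twice on each device: AffineChannel writes Y in place through the fallback path, so a second iteration exposes stale-buffer bugs that a single run would hide, and the IDEEP result is compared against the CPU reference within atol = rtol = 0.01.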
