Replace thpp::Tensor with ATen Tensor in autograd csrc (pytorch#2170)
killeent authored and apaszke committed Jul 28, 2017
1 parent f1fd4ac commit c304d04
Showing 33 changed files with 534 additions and 441 deletions.
15 changes: 7 additions & 8 deletions .travis.yml
@@ -16,23 +16,22 @@ cache:
 install:
 - unset CCACHE_DISABLE
 - export CCACHE_DIR=$HOME/.ccache
-- export CC="ccache gcc-4.8"
-- export CXX="ccache g++-4.8"
+- export CC="ccache gcc-5"
+- export CXX="ccache g++-5"
 - ccache --show-stats
 - travis_retry pip install --upgrade pip setuptools wheel
 - travis_retry pip install -r requirements.txt --only-binary=scipy
-- python setup.py install
-
-script:
-- OMP_NUM_THREADS=2 ./test/run_test.sh
+- MAX_JOBS=8 python setup.py install

 addons:
   apt:
     sources:
     - ubuntu-toolchain-r-test
     packages:
-    - gcc-4.8
-    - g++-4.8
+    - g++-5
+
+script:
+- OMP_NUM_THREADS=2 ./test/run_test.sh

 # This reportedly works around an issue downloading packages from pypi on
 # travis. Consider removing this after the underlying issue is fixed.
3 changes: 3 additions & 0 deletions setup.py
@@ -54,6 +54,9 @@ def _single_compile(obj):
         src, ext = build[obj]
         self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts)
     num_jobs = multiprocessing.cpu_count()
+    max_jobs = os.getenv("MAX_JOBS")
+    if max_jobs is not None:
+        num_jobs = min(num_jobs, int(max_jobs))
     multiprocessing.pool.ThreadPool(num_jobs).map(_single_compile, objects)

     return objects
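Note: the new MAX_JOBS environment variable caps the parallel compile job count at min(cpu_count, MAX_JOBS); the .travis.yml change above exercises it as:

    MAX_JOBS=8 python setup.py install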
7 changes: 4 additions & 3 deletions torch/csrc/DynamicTypes.cpp
@@ -213,14 +213,15 @@ at::Tensor createTensorAT(PyObject *data)
 {
   auto tensor_type = pytype_to_attype.at(Py_TYPE(data));
   auto tensor = ((THPVoidTensor *)data)->cdata;
-  return tensor_type->unsafeTensorFromTH(tensor, true);
+  return tensor_type->unsafeTensorFromTH(tensor, true); // Calls retain on underlying TH Tensor
 }
-PyObject* createPyObject(at::Tensor tensor)
+PyObject* createPyObject(at::Tensor& tensor)
 {
   auto type = getPyTypeObject(tensor);
   PyObject *obj = type->tp_alloc(type, 0);
   if (obj) {
-    ((THPVoidTensor*)obj)->cdata = (THVoidTensor *)tensor.detach()->unsafeGetTH(true);
+    // Retain underlying TH Tensor
+    ((THPVoidTensor*)obj)->cdata = (THVoidTensor *)tensor.unsafeGetTH(true);
   }
   return obj;
 }
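For context on the retain convention above, a minimal sketch (assuming a contemporaneous ATen build; the shape and main() scaffolding are illustrative only): at::Tensor is a reference-counted handle, and unsafeGetTH(true) bumps the underlying TH refcount before exposing the raw pointer, so the Python object created by createPyObject owns one reference.

// Sketch: handle semantics of at::Tensor and the unsafeGetTH(true) contract.
#include <ATen/ATen.h>

int main() {
  at::Tensor a = at::CPU(at::kFloat).tensor({2, 3}); // fresh TH tensor, one reference
  at::Tensor b = a;   // cheap handle copy: both share the same TH tensor
  b.fill_(1);         // mutation through `b` is visible through `a`
  // unsafeGetTH(true) retains before handing out the raw pointer, mirroring
  // what createPyObject does before storing cdata in the Python object.
  void* raw = a.unsafeGetTH(true);
  (void)raw;          // the receiver is now responsible for that reference
  return 0;
}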
2 changes: 1 addition & 1 deletion torch/csrc/DynamicTypes.h
@@ -23,7 +23,7 @@ std::unique_ptr<thpp::Tensor> createTensor(PyObject *data);
 // Creates Python tensor object from a Tensor
 PyObject* createPyObject(const thpp::Tensor& tensor);

-PyObject* createPyObject(at::Tensor tensor);
+PyObject* createPyObject(at::Tensor& tensor);
 PyTypeObject* getPyTypeObject(const at::Tensor& tensor);
 //rename to createPyObject when THPP is removed
 at::Tensor createTensorAT(PyObject *data);
4 changes: 2 additions & 2 deletions torch/csrc/autograd/function.h
@@ -9,7 +9,7 @@
 #include <Python.h>
 #include "torch/csrc/autograd/function_hook.h"

-#include <THPP/THPP.h>
+#include <ATen/ATen.h>

 #include <memory>
 #include <vector>
@@ -19,7 +19,7 @@ namespace torch { namespace autograd {
 struct Function;
 struct Variable;

-using tensor_list = std::vector<std::unique_ptr<thpp::Tensor>>;
+using tensor_list = std::vector<at::Tensor>;
 using variable_list = std::vector<std::shared_ptr<Variable>>;
 using function_list = std::vector<std::pair<std::shared_ptr<Function>, int>>;

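The tensor_list change is the heart of this commit: because at::Tensor is a value type with shared ownership, lists of tensors no longer need std::unique_ptr or clone_shallow(). A minimal sketch (same assumptions as above; the names are illustrative):

// Sketch: vectors of at::Tensor copy handles, not data, and a
// default-constructed at::Tensor() replaces the old nullptr entries.
#include <ATen/ATen.h>
#include <vector>

using tensor_list = std::vector<at::Tensor>;

int main() {
  at::Tensor t = at::CPU(at::kFloat).tensor({4});
  tensor_list xs{t, t};    // two handles to one storage
  tensor_list ys = xs;     // copying the vector copies handles only
  xs.emplace_back();       // empty handle, the new stand-in for nullptr
  bool ok = !xs.back().defined() && ys.size() == 2;
  return ok ? 0 : 1;
}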
14 changes: 5 additions & 9 deletions torch/csrc/autograd/functions/accumulate_grad.cpp
@@ -19,15 +19,12 @@ auto AccumulateGrad::acc_inplace(std::shared_ptr<Variable>& grad,
                                  std::shared_ptr<Variable>& new_grad) -> void {
   auto& grad_data = grad->data;
   auto& new_grad_data = new_grad->data;
-  AutoGPU guard(grad_data->getDevice());
+  AutoGPU guard(grad_data.type().isCuda() ? grad_data.get_device() : -1);

   // The grad may need a promotion from a sparse to dense type
-  if (grad_data->isSparse() && !new_grad_data->isSparse()) {
-    std::unique_ptr<thpp::Tensor> result = new_grad_data->newTensor();
-    result->cadd(*new_grad_data, *grad_data);
-    grad->data = std::move(result);
+  if (grad_data.type().isSparse() && !new_grad_data.type().isSparse()) {
+    grad->data = new_grad_data + grad_data;
   } else {
-    grad_data->cadd(*grad_data, *new_grad_data);
+    grad_data += new_grad_data;
   }
 }
@@ -80,8 +77,7 @@ auto AccumulateGrad::apply(const variable_list& grads) -> variable_list {
   } else {
     // Once the grad becomes not volatile, it should stay like that
     if (!var->grad->is_volatile && new_grad->is_volatile) {
-      new_grad = std::make_shared<Variable>(
-          std::unique_ptr<thpp::Tensor>(new_grad->data->clone_shallow()), false, false);
+      new_grad = std::make_shared<Variable>(new_grad->data, false, false);
     }
     var->grad = Add().apply({var->grad, new_grad})[0];
   }
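The rewritten acc_inplace reduces to two ATen expressions: an out-of-place sum when the existing grad is sparse and the incoming grad is dense (dense + sparse promotes to a fresh dense tensor), and an in-place += otherwise. A sketch of just that logic, lifted from the hunk above:

// Sketch of the two accumulation paths in acc_inplace.
#include <ATen/ATen.h>

void accumulate(at::Tensor& grad, const at::Tensor& new_grad) {
  if (grad.type().isSparse() && !new_grad.type().isSparse()) {
    grad = new_grad + grad;   // promote: dense + sparse -> new dense tensor
  } else {
    grad += new_grad;         // same layout: accumulate in place
  }
}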
14 changes: 6 additions & 8 deletions torch/csrc/autograd/functions/basic_ops.cpp
@@ -14,7 +14,7 @@ auto DelayedError::apply(const variable_list& inputs) -> variable_list {
   tensor_list outputs;
   outputs.reserve(inputs.size());
   for (auto& var : inputs) {
-    outputs.emplace_back(var ? var->data->clone_shallow() : nullptr);
+    outputs.emplace_back(var ? var->data : at::Tensor());
   }
   return wrap_outputs(inputs, std::move(outputs), [&](FunctionFlags f) {
     return std::make_shared<Error>(msg, std::move(f));
@@ -25,16 +25,14 @@ auto Add::apply(const variable_list& inputs) -> variable_list {
   check_input_variables("Add", inputs, 2);
   auto& input1 = inputs[0]->data;
   auto& input2 = inputs[1]->data;
-  AutoGPU guard(input1->getDevice());
+  AutoGPU guard(input1.type().isCuda() ? input1.get_device() : -1);

-  bool first_sparse = input1->isSparse();
-  auto output = first_sparse ? input2->newTensor() : input1->newTensor();
-  if (first_sparse) {
-    output->cadd(*input2, *input1);
+  at::Tensor output;
+  if (input1.type().isSparse()) {
+    output = input2 + input1;
   } else {
-    output->cadd(*input1, *input2);
+    output = input1 + input2;
   }
-
   return wrap_outputs(inputs, as_tensor_list(std::move(output)), [&](FunctionFlags f) {
     return std::make_shared<AddBackward>(std::move(f));
   });
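The guard expression AutoGPU guard(t.type().isCuda() ? t.get_device() : -1) recurs throughout this commit: at::Tensor::get_device() is only meaningful for CUDA tensors, so the CPU case passes -1, which tells the guard to leave the current device alone. A hypothetical helper naming the idiom (AutoGPU itself is torch's existing RAII device guard, not defined here):

// Hypothetical helper for the recurring device-selection expression.
#include <ATen/ATen.h>

inline int guard_device(const at::Tensor& t) {
  return t.type().isCuda() ? t.get_device() : -1; // -1: don't switch devices
}
// Usage, as in Add::apply above:
//   AutoGPU guard(guard_device(input1));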
148 changes: 70 additions & 78 deletions torch/csrc/autograd/functions/batch_normalization.cpp
@@ -30,74 +30,69 @@ namespace {

 namespace torch { namespace autograd {

-using thpp::Tensor;
-
 #ifndef CUDNN_BN_MIN_EPSILON
 #define CUDNN_BN_MIN_EPSILON 0
 #endif

 auto BatchNormForward::apply(const variable_list& inputs) -> variable_list {
   check_input_variables("BatchNorm", inputs, 3, 1);

   auto& input = inputs[0];
   auto& weight = inputs[1];
   auto& bias = inputs[2];
-  AutoGPU guard(input->data->getDevice());
-  auto num_features = input->data->rawSizes()[1];
-  check_dims_match_num_input_features("running_mean", num_features, running_mean->numel());
-  check_dims_match_num_input_features("running_var", num_features, running_var->numel());
+  AutoGPU guard(input->data.type().isCuda() ? input->data.get_device() : -1);
+
+  auto num_features = input->data.sizes()[1];
+  check_dims_match_num_input_features("running_mean", num_features, running_mean.numel());
+  check_dims_match_num_input_features("running_var", num_features, running_var.numel());
   if (weight){
-    check_dims_match_num_input_features("weight", num_features, weight->data->numel());
+    check_dims_match_num_input_features("weight", num_features, weight->data.numel());
   }
   if (bias){
-    check_dims_match_num_input_features("bias", num_features, bias->data->numel());
+    check_dims_match_num_input_features("bias", num_features, bias->data.numel());
   }

   bool use_cudnn = false;
 #ifdef WITH_CUDNN
-  use_cudnn = (input->data->isCuda()
-               && input->data->type() != thpp::Type::HALF
+  use_cudnn = (input->data.type().isCuda()
+               && input->data.type().scalarType() != at::kHalf
                && weight && bias
                && cudnn_enabled && CUDNN_VERSION >= 5110L);
 #endif

-  auto output = input->data->newTensor();
-  output->resizeAs(*input->data);
-
-  std::unique_ptr<Tensor> save_mean(output->newTensor());
-  save_mean->resizeAs(*running_mean);
-  std::unique_ptr<Tensor> save_std(output->newTensor());
-  save_std->resizeAs(*running_var);
+  auto output = input->data.type().tensor(input->data.sizes());
+  auto save_mean = running_mean.type().tensor(running_mean.sizes());
+  auto save_std = running_var.type().tensor(running_var.sizes());

   if (use_cudnn && eps >= CUDNN_BN_MIN_EPSILON) {
 #ifdef WITH_CUDNN
     torch::cudnn::cudnn_batch_norm_forward(
         state,
         torch::cudnn::getCudnnHandle(),
-        torch::cudnn::getCudnnDataType(*input->data),
-        (THVoidTensor*)input->data->cdata(),
-        (THVoidTensor*)output->cdata(),
-        (THVoidTensor*)weight->data->cdata(),
-        (THVoidTensor*)bias->data->cdata(),
-        (THVoidTensor*)running_mean->cdata(),
-        (THVoidTensor*)running_var->cdata(),
-        (THVoidTensor*)save_mean->cdata(),
-        (THVoidTensor*)save_std->cdata(),
+        torch::cudnn::getCudnnDataType(input->data),
+        (THVoidTensor*)input->data.unsafeGetTH(false),
+        (THVoidTensor*)output.unsafeGetTH(false),
+        (THVoidTensor*)weight->data.unsafeGetTH(false),
+        (THVoidTensor*)bias->data.unsafeGetTH(false),
+        (THVoidTensor*)running_mean.unsafeGetTH(false),
+        (THVoidTensor*)running_var.unsafeGetTH(false),
+        (THVoidTensor*)save_mean.unsafeGetTH(false),
+        (THVoidTensor*)save_std.unsafeGetTH(false),
         training,
         momentum,
         eps);
 #endif
   } else {
-    torch::nn::BatchNormalization_updateOutput(
-        input->data.get(),
-        output.get(),
-        weight ? weight->data.get() : nullptr,
-        bias ? bias->data.get() : nullptr,
-        running_mean.get(),
-        running_var.get(),
-        save_mean.get(),
-        save_std.get(),
+    at::Tensor nt;
+    at::BatchNormalization_updateOutput(
+        input->data,
+        output,
+        weight ? weight->data : nt,
+        bias ? bias->data : nt,
+        running_mean,
+        running_var,
+        save_mean,
+        save_std,
         training,
         momentum,
         eps);
Expand All @@ -119,77 +114,74 @@ auto BatchNormBackward::apply(const variable_list& grad_outputs) -> variable_lis
auto weight_var = this->weight.unpack();
auto bias_var = this->bias.unpack();

std::unique_ptr<thpp::Tensor> input {input_var->data->clone_shallow()};
std::unique_ptr<thpp::Tensor> weight {weight_var ? weight_var->data->clone_shallow() : nullptr};
std::unique_ptr<thpp::Tensor> bias {bias_var ? bias_var->data->clone_shallow() : nullptr};
auto input = input_var->data;
auto weight = weight_var ? weight_var->data : at::Tensor();
auto bias = bias_var ? bias_var->data : at::Tensor();

AutoGPU guard(input->getDevice());
AutoGPU guard(input.type().isCuda() ? input.get_device() : -1);

bool use_cudnn = false;
#ifdef WITH_CUDNN
use_cudnn = (input->isCuda()
&& input->type() != thpp::Type::HALF
&& weight && bias && training
use_cudnn = (input.type().backend() == at::kCUDA
&& input.type().scalarType() != at::kHalf
&& weight.defined() && bias.defined() && training
&& cudnn_enabled && CUDNN_VERSION >= 5110L);
#endif

std::unique_ptr<Tensor> grad_input;
at::Tensor grad_input;
if (should_compute_output(0) || use_cudnn) {
grad_input = input->newTensor();
grad_input->resizeAs(*input);
grad_input = input.type().tensor(input.sizes());
}

std::unique_ptr<Tensor> grad_weight;
at::Tensor grad_weight;
if (should_compute_output(1) || use_cudnn) {
grad_weight = weight->newTensor();
grad_weight->resizeAs(*weight);
grad_weight = weight.type().tensor(weight.sizes());
if (!use_cudnn) {
grad_weight->zero();
grad_weight.zero_();
}
}

std::unique_ptr<Tensor> grad_bias;
at::Tensor grad_bias;
if (should_compute_output(2) || use_cudnn) {
grad_bias = bias->newTensor();
grad_bias->resizeAs(*bias);
grad_bias = bias.type().tensor(bias.sizes());
if (!use_cudnn) {
grad_bias->zero();
grad_bias.zero_();
}
}

auto grad_output = grad_outputs[0]->data->contiguous();
auto grad_output = grad_outputs[0]->data.contiguous();

if (use_cudnn && eps >= CUDNN_BN_MIN_EPSILON) {
#ifdef WITH_CUDNN
torch::cudnn::cudnn_batch_norm_backward(
state,
torch::cudnn::getCudnnHandle(),
torch::cudnn::getCudnnDataType(*input),
(THVoidTensor*)input->cdata(),
(THVoidTensor*)grad_output->cdata(),
(THVoidTensor*)grad_input->cdata(),
(THVoidTensor*)grad_weight->cdata(),
(THVoidTensor*)grad_bias->cdata(),
(THVoidTensor*)weight->cdata(),
(THVoidTensor*)running_mean->cdata(),
(THVoidTensor*)running_var->cdata(),
(THVoidTensor*)save_mean->cdata(),
(THVoidTensor*)save_std->cdata(),
torch::cudnn::getCudnnDataType(input),
(THVoidTensor*)input.unsafeGetTH(false),
(THVoidTensor*)grad_output.unsafeGetTH(false),
(THVoidTensor*)grad_input.unsafeGetTH(false),
(THVoidTensor*)grad_weight.unsafeGetTH(false),
(THVoidTensor*)grad_bias.unsafeGetTH(false),
(THVoidTensor*)weight.unsafeGetTH(false),
(THVoidTensor*)running_mean.unsafeGetTH(false),
(THVoidTensor*)running_var.unsafeGetTH(false),
(THVoidTensor*)save_mean.unsafeGetTH(false),
(THVoidTensor*)save_std.unsafeGetTH(false),
training,
eps);
#endif
} else {
torch::nn::BatchNormalization_backward(
input.get(),
grad_output.get(),
grad_input.get(),
grad_weight.get(),
grad_bias.get(),
weight.get(),
running_mean.get(),
running_var.get(),
save_mean.get(),
save_std.get(),
at::BatchNormalization_backward(
input,
grad_output,
grad_input,
grad_weight,
grad_bias,
weight,
running_mean,
running_var,
save_mean,
save_std,
training,
1.0,
eps);
Expand Down
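Two allocation idioms above replace their thpp counterparts: t.type().tensor(sizes) allocates an uninitialized tensor of the same type and device in one step, replacing the newTensor()/resizeAs() pair, and zero_() is the in-place fill replacing zero(). A minimal sketch (same assumptions as the earlier sketches; the shapes are illustrative):

// Sketch: ATen allocation and the undefined-tensor stand-in for nullptr.
#include <ATen/ATen.h>

int main() {
  at::Tensor input = at::CPU(at::kFloat).tensor({8, 16});
  at::Tensor grad_input = input.type().tensor(input.sizes()); // same type/shape
  grad_input.zero_();       // in-place zero fill, formerly grad_weight->zero()
  at::Tensor nt;            // undefined handle, passed where thpp used nullptr
  return nt.defined() ? 1 : 0;
}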
[Diffs for the remaining changed files were not loaded.]
