Use ATen implementation of RNNs (pytorch#10761)
Summary:
apaszke recently ported RNNs from Python into ATen, so we can replace the C++ API implementation (written by ebetica) with the ATen implementation, which cleans up a lot of code (+99, -323). Thanks apaszke!

I also added `bidirectional` and `batch_first` to the C++ API RNN options while I was at it.
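
For context, here is a minimal sketch of how the new options could be used from the C++ frontend. The option names (`layers`, `dropout`, `bidirectional`, `batch_first`) and the `forward()` signature follow the diff below; the surrounding program is illustrative only, not part of this change:

```cpp
#include <torch/torch.h>

int main() {
  using namespace torch::nn;

  // A 2-layer LSTM built with the newly added options.
  LSTM lstm(LSTMOptions(/*input_size=*/128, /*hidden_size=*/64)
                .layers(2)
                .dropout(0.2)
                .bidirectional(true)
                .batch_first(true));

  // With batch_first(true), the input is laid out as (batch, sequence, feature).
  auto x = torch::randn({16, 10, 128});
  auto output = lstm->forward(x);

  // The returned state can be passed back in on the next call, as the tests do.
  auto next = lstm->forward(x, output.state);
  return 0;
}
```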

apaszke ebetica
Pull Request resolved: pytorch#10761

Differential Revision: D9443885

Pulled By: goldsborough

fbshipit-source-id: b6ef7566b9ced2b2f0b2e1f46c295b6f250c65a8
goldsborough authored and facebook-github-bot committed Aug 23, 2018
1 parent a4c59a9 commit 9403e0c
Showing 4 changed files with 266 additions and 375 deletions.
144 changes: 73 additions & 71 deletions test/cpp/api/rnn.cpp
@@ -8,6 +8,8 @@

#include <test/cpp/api/util.h>

using Catch::StartsWith;

using namespace torch::nn;
using namespace torch::test;

@@ -84,99 +86,99 @@ void check_lstm_sizes(RNNOutput output) {
REQUIRE(output.state.norm().toCFloat() > 0);
}

TEST_CASE("rnn") {
TEST_CASE("RNN/CheckOutputSizes") {
torch::manual_seed(0);
SECTION("sizes") {
LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
auto x = torch::randn({10, 16, 128}, torch::requires_grad());
auto output = model->forward(x);
auto y = x.mean();
LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
// Input size is: sequence length, batch size, input size
auto x = torch::randn({10, 16, 128}, torch::requires_grad());
auto output = model->forward(x);
auto y = x.mean();

y.backward();
check_lstm_sizes(output);
y.backward();
check_lstm_sizes(output);

auto next = model->forward(x, output.state);
auto next = model->forward(x, output.state);

check_lstm_sizes(next);
check_lstm_sizes(next);

torch::Tensor diff = next.state - output.state;
torch::Tensor diff = next.state - output.state;

// Hiddens changed
REQUIRE(diff.abs().sum().toCFloat() > 1e-3);
}

SECTION("outputs") {
// Make sure the outputs match pytorch outputs
LSTM model(2, 2);
for (auto& v : model->parameters()) {
float size = v->numel();
auto p = static_cast<float*>(v->storage().data());
for (size_t i = 0; i < size; i++) {
p[i] = i / size;
}
}
// Hiddens changed
REQUIRE(diff.abs().sum().toCFloat() > 1e-3);
}

auto x = torch::empty({3, 4, 2}, torch::requires_grad());
float size = x.numel();
auto p = static_cast<float*>(x.storage().data());
TEST_CASE("RNN/CheckOutputValuesMatchPyTorch") {
torch::manual_seed(0);
// Make sure the outputs match pytorch outputs
LSTM model(2, 2);
for (auto& v : model->parameters()) {
float size = v->numel();
auto p = static_cast<float*>(v->storage().data());
for (size_t i = 0; i < size; i++) {
p[i] = (size - i) / size;
p[i] = i / size;
}
}

auto out = model->forward(x);
REQUIRE(out.output.ndimension() == 3);
REQUIRE(out.output.size(0) == 3);
REQUIRE(out.output.size(1) == 4);
REQUIRE(out.output.size(2) == 2);

auto flat = out.output.view(3 * 4 * 2);
float c_out[] = {0.4391, 0.5402, 0.4330, 0.5324, 0.4261, 0.5239,
0.4183, 0.5147, 0.6822, 0.8064, 0.6726, 0.7968,
0.6620, 0.7860, 0.6501, 0.7741, 0.7889, 0.9003,
0.7769, 0.8905, 0.7635, 0.8794, 0.7484, 0.8666};
for (size_t i = 0; i < 3 * 4 * 2; i++) {
REQUIRE(std::abs(flat[i].toCFloat() - c_out[i]) < 1e-3);
}
auto x = torch::empty({3, 4, 2}, torch::requires_grad());
float size = x.numel();
auto p = static_cast<float*>(x.storage().data());
for (size_t i = 0; i < size; i++) {
p[i] = (size - i) / size;
}

REQUIRE(out.state.ndimension() == 4); // (hx, cx) x layers x B x 2
REQUIRE(out.state.size(0) == 2);
REQUIRE(out.state.size(1) == 1);
REQUIRE(out.state.size(2) == 4);
REQUIRE(out.state.size(3) == 2);
flat = out.state.view(16);
float h_out[] = {0.7889,
0.9003,
0.7769,
0.8905,
0.7635,
0.8794,
0.7484,
0.8666,
1.1647,
1.6106,
1.1425,
1.5726,
1.1187,
1.5329,
1.0931,
1.4911};
for (size_t i = 0; i < 16; i++) {
REQUIRE(std::abs(flat[i].toCFloat() - h_out[i]) < 1e-3);
}
auto out = model->forward(x);
REQUIRE(out.output.ndimension() == 3);
REQUIRE(out.output.size(0) == 3);
REQUIRE(out.output.size(1) == 4);
REQUIRE(out.output.size(2) == 2);

auto flat = out.output.view(3 * 4 * 2);
float c_out[] = {0.4391, 0.5402, 0.4330, 0.5324, 0.4261, 0.5239,
0.4183, 0.5147, 0.6822, 0.8064, 0.6726, 0.7968,
0.6620, 0.7860, 0.6501, 0.7741, 0.7889, 0.9003,
0.7769, 0.8905, 0.7635, 0.8794, 0.7484, 0.8666};
for (size_t i = 0; i < 3 * 4 * 2; i++) {
REQUIRE(std::abs(flat[i].toCFloat() - c_out[i]) < 1e-3);
}

REQUIRE(out.state.ndimension() == 4); // (hx, cx) x layers x B x 2
REQUIRE(out.state.size(0) == 2);
REQUIRE(out.state.size(1) == 1);
REQUIRE(out.state.size(2) == 4);
REQUIRE(out.state.size(3) == 2);
flat = out.state.view(16);
float h_out[] = {0.7889,
0.9003,
0.7769,
0.8905,
0.7635,
0.8794,
0.7484,
0.8666,
1.1647,
1.6106,
1.1425,
1.5726,
1.1187,
1.5329,
1.0931,
1.4911};
for (size_t i = 0; i < 16; i++) {
REQUIRE(std::abs(flat[i].toCFloat() - h_out[i]) < 1e-3);
}
}

TEST_CASE("rnn/integration/LSTM") {
TEST_CASE("RNN/integration/LSTM") {
REQUIRE(test_RNN_xor<LSTM>(
[](int s) { return LSTM(LSTMOptions(s, s).layers(2)); }));
}

TEST_CASE("rnn/integration/GRU") {
TEST_CASE("RNN/integration/GRU") {
REQUIRE(
test_RNN_xor<GRU>([](int s) { return GRU(GRUOptions(s, s).layers(2)); }));
}

TEST_CASE("rnn/integration/RNN") {
TEST_CASE("RNN/integration/RNN") {
SECTION("relu") {
REQUIRE(test_RNN_xor<RNN>(
[](int s) { return RNN(RNNOptions(s, s).relu().layers(2)); }));
109 changes: 66 additions & 43 deletions torch/csrc/api/include/torch/nn/modules/rnn.h
@@ -31,6 +31,8 @@ struct RNNOptionsBase {
TORCH_ARG(int64_t, layers) = 1;
TORCH_ARG(bool, with_bias) = true;
TORCH_ARG(double, dropout) = 0.0;
TORCH_ARG(bool, bidirectional) = false;
TORCH_ARG(bool, batch_first) = false;
};

template <typename Derived>
@@ -40,69 +42,83 @@ class RNNImplBase : public torch::nn::Cloneable<Derived> {
// https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t
enum class CuDNNMode { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 };

RNNImplBase(
explicit RNNImplBase(
RNNOptionsBase options_,
at::optional<CuDNNMode> cudnn_mode = at::nullopt,
int64_t number_of_gates = 1,
bool has_cell_state = false);

RNNOutput forward(Tensor input, Tensor state = {});
int64_t number_of_gates = 1);

/// Initializes the parameters of the RNN module.
void reset() override;

/// Recursively casts all parameters to the given device and dtype.
/// Overrides `nn::Module::to()` to call `flatten_parameters()` after the
/// original operation.
void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false)
override;

/// Recursively casts all parameters to the given dtype.
void to(torch::Dtype dtype, bool non_blocking = false) override;

/// Recursively moves all parameters to the given device.
void to(torch::Device device, bool non_blocking = false) override;

/// Fills the internal flattened parameter buffers passed to cuDNN. Call this
/// method if you mess around with the variable storages and want to use
/// cuDNN.
void flatten_parameters_for_cudnn();

/// Modifies the internal storage of weights for optimization purposes.
///
/// On CPU, this method should be called if any of the weight or bias vectors
/// are changed (i.e. weights are added or removed). On GPU, it should be
/// called __any time the storage of any parameter is modified__, e.g. any
/// time a parameter is assigned a new value. This allows using the fast path
/// in cuDNN implementations of respective RNN `forward()` methods. It is
/// called once upon construction, inside `reset()`.
void flatten_parameters();

/// The RNN's options.
RNNOptionsBase options;

/// The weights for `input x hidden` gates.
std::vector<Tensor> w_ih;
/// The weights for `hidden x hidden` gates.
std::vector<Tensor> w_hh;
/// The biases for `input x hidden` gates.
std::vector<Tensor> b_ih;
/// The biases for `hidden x hidden` gates.
std::vector<Tensor> b_hh;

Dropout dropout;

protected:
virtual Tensor cell_forward(Tensor input, Tensor state, int64_t layer) = 0;

RNNOutput CUDNN_forward(Tensor input, Tensor state);
RNNOutput autograd_forward(Tensor input, Tensor state);

/// The function signature of `at::rnn_relu`, `at::rnn_tanh` and `at::gru`.
using RNNFunctionSignature = std::tuple<Tensor, Tensor>(
/*input=*/const Tensor&,
/*state=*/const Tensor&,
/*params=*/TensorList,
/*has_biases=*/bool,
/*layers=*/int64_t,
/*dropout=*/double,
/*train=*/bool,
/*bidirectional=*/bool,
/*batch_first=*/bool);

/// A generic `forward()` used for RNN and GRU (but not LSTM!). Takes the ATen
/// RNN function as first argument.
RNNOutput generic_forward(
std::function<RNNFunctionSignature> function,
Tensor input,
Tensor state);

/// Returns a flat vector of all weights, with layer weights following each
/// other sequentially in (w_ih, w_hh, b_ih, b_hh) order.
std::vector<Tensor> flat_weights() const;
bool use_cudnn(Tensor sample) const;
Tensor create_dropout_state(Tensor input) const;

/// Very simple check if any of the parameters (weights, biases) are the same.
bool any_parameters_alias() const;

/// The number of gate weights/biases required by the RNN subclass.
int64_t number_of_gates_;
bool has_cell_state_;

/// The cuDNN RNN mode, if this RNN subclass has any.
at::optional<CuDNNMode> cudnn_mode_;

// This is copied from pytorch, to determine whether weights are flat for the
// fast CUDNN route. Otherwise, we have to use non flattened weights, which
// are much slower.
// https://github.com/pytorch/pytorch/blob/1848cad10802db9fa0aa066d9de195958120d863/torch/nn/modules/rnn.py#L159-L165
// TODO Actually since we are in C++ we can probably just actually check if
// the parameters are flat, instead of relying on data pointers and stuff.
std::vector<void*> data_ptrs_;
Tensor flat_weights_;
/// The cached result of the latest `flat_weights()` call.
std::vector<Tensor> flat_weights_;
};
} // namespace detail

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// TODO: Replace this with passing an activation module.

enum class RNNActivation { ReLU, Tanh };

struct RNNOptions {
@@ -116,6 +132,8 @@ struct RNNOptions {
TORCH_ARG(int64_t, layers) = 1;
TORCH_ARG(bool, with_bias) = true;
TORCH_ARG(double, dropout) = 0.0;
TORCH_ARG(bool, bidirectional) = false;
TORCH_ARG(bool, batch_first) = false;
TORCH_ARG(RNNActivation, activation) = RNNActivation::ReLU;
};

@@ -125,13 +143,14 @@ class RNNImpl : public detail::RNNImplBase<RNNImpl> {
: RNNImpl(RNNOptions(input_size, hidden_size)) {}
explicit RNNImpl(RNNOptions options);

RNNOptions options;
RNNOutput forward(Tensor input, Tensor state = {});

private:
Tensor cell_forward(Tensor input, Tensor state, int64_t layer) override;
std::function<Tensor(Tensor)> activation_function_;
RNNOptions options;
};

/// A multi-layer Elman RNN module with Tanh or ReLU activation.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.RNN for more
/// documentation.
TORCH_MODULE(RNN);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -144,10 +163,12 @@ class LSTMImpl : public detail::RNNImplBase<LSTMImpl> {
: LSTMImpl(LSTMOptions(input_size, hidden_size)) {}
explicit LSTMImpl(LSTMOptions options);

private:
Tensor cell_forward(Tensor input, Tensor state, int64_t layer) override;
RNNOutput forward(Tensor input, Tensor state = {});
};

/// A multi-layer long-short-term-memory (LSTM) module.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.LSTM for more
/// documentation.
TORCH_MODULE(LSTM);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -160,10 +181,12 @@ class GRUImpl : public detail::RNNImplBase<GRUImpl> {
: GRUImpl(GRUOptions(input_size, hidden_size)) {}
explicit GRUImpl(GRUOptions options);

private:
Tensor cell_forward(Tensor input, Tensor state, int64_t layer) override;
RNNOutput forward(Tensor input, Tensor state = {});
};

/// A multi-layer gated recurrent unit (GRU) module.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.GRU for more
/// documentation.
TORCH_MODULE(GRU);

} // namespace nn
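
As a usage note for the `flatten_parameters()` contract documented in `rnn.h` above, here is a hedged sketch. The parameter-initialization loop mirrors the updated tests in `rnn.cpp`; everything else is an illustrative assumption, not part of this change:

```cpp
#include <torch/torch.h>

int main() {
  torch::nn::LSTM model(2, 2);

  // Write directly into the parameter storages, as the tests do.
  for (auto& v : model->parameters()) {
    float size = v->numel();
    auto p = static_cast<float*>(v->storage().data());
    for (size_t i = 0; i < size; i++) {
      p[i] = i / size;
    }
  }

  // Per the doc comment, re-flatten after touching parameter storage so the
  // cuDNN fast path in forward() remains available.
  model->flatten_parameters();

  // Forward pass with input shaped (sequence, batch, feature), as in the tests.
  auto out = model->forward(torch::ones({3, 4, 2}));
  return 0;
}
```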
