From 9403e0cac07ebaad463a992e0874baeeb86e355a Mon Sep 17 00:00:00 2001
From: Peter Goldsborough
Date: Thu, 23 Aug 2018 16:01:03 -0700
Subject: [PATCH] Use ATen implementation of RNNs (#10761)

Summary:
apaszke recently ported RNNs from Python into ATen, which means we can replace
our implementation in the C++ API (written by ebetica) with the ATen
implementation, which cleans up a lot of code (+99, -323). Thanks apaszke!

I also added the `bidirectional` and `batch_first` options to the C++ API RNN
options, just because why not. (A brief usage sketch of the resulting API is
appended after the diff below.)

apaszke ebetica

Pull Request resolved: https://github.com/pytorch/pytorch/pull/10761

Differential Revision: D9443885

Pulled By: goldsborough

fbshipit-source-id: b6ef7566b9ced2b2f0b2e1f46c295b6f250c65a8
---
 test/cpp/api/rnn.cpp                          | 144 +++----
 torch/csrc/api/include/torch/nn/modules/rnn.h | 109 +++--
 torch/csrc/api/src/nn/modules/rnn.cpp         | 386 ++++++------------
 torch/nn/modules/rnn.py                       |   2 +-
 4 files changed, 266 insertions(+), 375 deletions(-)

diff --git a/test/cpp/api/rnn.cpp b/test/cpp/api/rnn.cpp
index 92067bada0f737..96685728484a39 100644
--- a/test/cpp/api/rnn.cpp
+++ b/test/cpp/api/rnn.cpp
@@ -8,6 +8,8 @@
 
 #include
 
+using Catch::StartsWith;
+
 using namespace torch::nn;
 using namespace torch::test;
 
@@ -84,99 +86,99 @@ void check_lstm_sizes(RNNOutput output) {
   REQUIRE(output.state.norm().toCFloat() > 0);
 }
 
-TEST_CASE("rnn") {
+TEST_CASE("RNN/CheckOutputSizes") {
   torch::manual_seed(0);
-  SECTION("sizes") {
-    LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
-    auto x = torch::randn({10, 16, 128}, torch::requires_grad());
-    auto output = model->forward(x);
-    auto y = x.mean();
+  LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
+  // Input size is: sequence length, batch size, input size
+  auto x = torch::randn({10, 16, 128}, torch::requires_grad());
+  auto output = model->forward(x);
+  auto y = x.mean();
 
-    y.backward();
-    check_lstm_sizes(output);
+  y.backward();
+  check_lstm_sizes(output);
 
-    auto next = model->forward(x, output.state);
+  auto next = model->forward(x, output.state);
 
-    check_lstm_sizes(next);
+  check_lstm_sizes(next);
 
-    torch::Tensor diff = next.state - output.state;
+  torch::Tensor diff = next.state - output.state;
 
-    // Hiddens changed
-    REQUIRE(diff.abs().sum().toCFloat() > 1e-3);
-  }
-
-  SECTION("outputs") {
-    // Make sure the outputs match pytorch outputs
-    LSTM model(2, 2);
-    for (auto& v : model->parameters()) {
-      float size = v->numel();
-      auto p = static_cast<float*>(v->storage().data());
-      for (size_t i = 0; i < size; i++) {
-        p[i] = i / size;
-      }
-    }
+  // Hiddens changed
+  REQUIRE(diff.abs().sum().toCFloat() > 1e-3);
+}
 
-    auto x = torch::empty({3, 4, 2}, torch::requires_grad());
-    float size = x.numel();
-    auto p = static_cast<float*>(x.storage().data());
+TEST_CASE("RNN/CheckOutputValuesMatchPyTorch") {
+  torch::manual_seed(0);
+  // Make sure the outputs match pytorch outputs
+  LSTM model(2, 2);
+  for (auto& v : model->parameters()) {
+    float size = v->numel();
+    auto p = static_cast<float*>(v->storage().data());
     for (size_t i = 0; i < size; i++) {
-      p[i] = (size - i) / size;
+      p[i] = i / size;
     }
+  }
 
-    auto out = model->forward(x);
-    REQUIRE(out.output.ndimension() == 3);
-    REQUIRE(out.output.size(0) == 3);
-    REQUIRE(out.output.size(1) == 4);
-    REQUIRE(out.output.size(2) == 2);
-
-    auto flat = out.output.view(3 * 4 * 2);
-    float c_out[] = {0.4391, 0.5402, 0.4330, 0.5324, 0.4261, 0.5239,
-                     0.4183, 0.5147, 0.6822, 0.8064, 0.6726, 0.7968,
-                     0.6620, 0.7860, 0.6501, 0.7741, 0.7889, 0.9003,
-                     0.7769, 0.8905, 0.7635, 0.8794, 0.7484, 0.8666};
-    for (size_t i = 0; i < 3 * 4 * 2; i++) {
-      REQUIRE(std::abs(flat[i].toCFloat() - c_out[i]) < 1e-3);
-    }
+  auto x = torch::empty({3, 4, 2}, torch::requires_grad());
+  float size = x.numel();
+  auto p = static_cast<float*>(x.storage().data());
+  for (size_t i = 0; i < size; i++) {
+    p[i] = (size - i) / size;
+  }
 
-    REQUIRE(out.state.ndimension() == 4); // (hx, cx) x layers x B x 2
-    REQUIRE(out.state.size(0) == 2);
-    REQUIRE(out.state.size(1) == 1);
-    REQUIRE(out.state.size(2) == 4);
-    REQUIRE(out.state.size(3) == 2);
-    flat = out.state.view(16);
-    float h_out[] = {0.7889,
-                     0.9003,
-                     0.7769,
-                     0.8905,
-                     0.7635,
-                     0.8794,
-                     0.7484,
-                     0.8666,
-                     1.1647,
-                     1.6106,
-                     1.1425,
-                     1.5726,
-                     1.1187,
-                     1.5329,
-                     1.0931,
-                     1.4911};
-    for (size_t i = 0; i < 16; i++) {
-      REQUIRE(std::abs(flat[i].toCFloat() - h_out[i]) < 1e-3);
-    }
+  auto out = model->forward(x);
+  REQUIRE(out.output.ndimension() == 3);
+  REQUIRE(out.output.size(0) == 3);
+  REQUIRE(out.output.size(1) == 4);
+  REQUIRE(out.output.size(2) == 2);
+
+  auto flat = out.output.view(3 * 4 * 2);
+  float c_out[] = {0.4391, 0.5402, 0.4330, 0.5324, 0.4261, 0.5239,
+                   0.4183, 0.5147, 0.6822, 0.8064, 0.6726, 0.7968,
+                   0.6620, 0.7860, 0.6501, 0.7741, 0.7889, 0.9003,
+                   0.7769, 0.8905, 0.7635, 0.8794, 0.7484, 0.8666};
+  for (size_t i = 0; i < 3 * 4 * 2; i++) {
+    REQUIRE(std::abs(flat[i].toCFloat() - c_out[i]) < 1e-3);
+  }
+
+  REQUIRE(out.state.ndimension() == 4); // (hx, cx) x layers x B x 2
+  REQUIRE(out.state.size(0) == 2);
+  REQUIRE(out.state.size(1) == 1);
+  REQUIRE(out.state.size(2) == 4);
+  REQUIRE(out.state.size(3) == 2);
+  flat = out.state.view(16);
+  float h_out[] = {0.7889,
+                   0.9003,
+                   0.7769,
+                   0.8905,
+                   0.7635,
+                   0.8794,
+                   0.7484,
+                   0.8666,
+                   1.1647,
+                   1.6106,
+                   1.1425,
+                   1.5726,
+                   1.1187,
+                   1.5329,
+                   1.0931,
+                   1.4911};
+  for (size_t i = 0; i < 16; i++) {
+    REQUIRE(std::abs(flat[i].toCFloat() - h_out[i]) < 1e-3);
   }
 }
 
-TEST_CASE("rnn/integration/LSTM") {
+TEST_CASE("RNN/integration/LSTM") {
   REQUIRE(test_RNN_xor(
       [](int s) { return LSTM(LSTMOptions(s, s).layers(2)); }));
 }
 
-TEST_CASE("rnn/integration/GRU") {
+TEST_CASE("RNN/integration/GRU") {
   REQUIRE(
       test_RNN_xor([](int s) { return GRU(GRUOptions(s, s).layers(2)); }));
 }
 
-TEST_CASE("rnn/integration/RNN") {
+TEST_CASE("RNN/integration/RNN") {
   SECTION("relu") {
     REQUIRE(test_RNN_xor(
         [](int s) { return RNN(RNNOptions(s, s).relu().layers(2)); }));
diff --git a/torch/csrc/api/include/torch/nn/modules/rnn.h b/torch/csrc/api/include/torch/nn/modules/rnn.h
index 9011de91b028ee..17ff80b66197c8 100644
--- a/torch/csrc/api/include/torch/nn/modules/rnn.h
+++ b/torch/csrc/api/include/torch/nn/modules/rnn.h
@@ -31,6 +31,8 @@ struct RNNOptionsBase {
   TORCH_ARG(int64_t, layers) = 1;
   TORCH_ARG(bool, with_bias) = true;
   TORCH_ARG(double, dropout) = 0.0;
+  TORCH_ARG(bool, bidirectional) = false;
+  TORCH_ARG(bool, batch_first) = false;
 };
 
 template <typename Derived>
@@ -40,69 +42,83 @@ class RNNImplBase : public torch::nn::Cloneable<Derived> {
   // https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t
   enum class CuDNNMode { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 };
 
-  RNNImplBase(
+  explicit RNNImplBase(
       RNNOptionsBase options_,
       at::optional<CuDNNMode> cudnn_mode = at::nullopt,
-      int64_t number_of_gates = 1,
-      bool has_cell_state = false);
-
-  RNNOutput forward(Tensor input, Tensor state = {});
+      int64_t number_of_gates = 1);
 
+  /// Initializes the parameters of the RNN module.
   void reset() override;
 
-  /// Recursively casts all parameters to the given device and dtype.
+  /// Overrides `nn::Module::to()` to call `flatten_parameters()` after the
+  /// original operation.
   void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false) override;
-
-  /// Recursively casts all parameters to the given dtype.
   void to(torch::Dtype dtype, bool non_blocking = false) override;
-
-  /// Recursively moves all parameters to the given device.
   void to(torch::Device device, bool non_blocking = false) override;
 
-  /// Fills the internal flattened parameter buffers passed to cuDNN. Call this
-  /// method if you mess around with the variable storages and want to use
-  /// cuDNN.
-  void flatten_parameters_for_cudnn();
-
+  /// Modifies the internal storage of weights for optimization purposes.
+  ///
+  /// On CPU, this method should be called if any of the weight or bias vectors
+  /// are changed (i.e. weights are added or removed). On GPU, it should be
+  /// called __any time the storage of any parameter is modified__, e.g. any
+  /// time a parameter is assigned a new value. This allows using the fast path
+  /// in cuDNN implementations of respective RNN `forward()` methods. It is
+  /// called once upon construction, inside `reset()`.
+  void flatten_parameters();
+
+  /// The RNN's options.
   RNNOptionsBase options;
 
+  /// The weights for `input x hidden` gates.
   std::vector<Tensor> w_ih;
+  /// The weights for `hidden x hidden` gates.
   std::vector<Tensor> w_hh;
+  /// The biases for `input x hidden` gates.
   std::vector<Tensor> b_ih;
+  /// The biases for `hidden x hidden` gates.
   std::vector<Tensor> b_hh;
 
-  Dropout dropout;
-
  protected:
-  virtual Tensor cell_forward(Tensor input, Tensor state, int64_t layer) = 0;
-
-  RNNOutput CUDNN_forward(Tensor input, Tensor state);
-  RNNOutput autograd_forward(Tensor input, Tensor state);
-
+  /// The function signature of `at::rnn_relu`, `at::rnn_tanh` and `at::gru`.
+  using RNNFunctionSignature = std::tuple<Tensor, Tensor>(
+      /*input=*/const Tensor&,
+      /*state=*/const Tensor&,
+      /*params=*/TensorList,
+      /*has_biases=*/bool,
+      /*layers=*/int64_t,
+      /*dropout=*/double,
+      /*train=*/bool,
+      /*bidirectional=*/bool,
+      /*batch_first=*/bool);
+
+  /// A generic `forward()` used for RNN and GRU (but not LSTM!). Takes the ATen
+  /// RNN function as first argument.
+  RNNOutput generic_forward(
+      std::function<RNNFunctionSignature> function,
+      Tensor input,
+      Tensor state);
+
+  /// Returns a flat vector of all weights, with layer weights following each
+  /// other sequentially in (w_ih, w_hh, b_ih, b_hh) order.
   std::vector<Tensor> flat_weights() const;
-  bool use_cudnn(Tensor sample) const;
-  Tensor create_dropout_state(Tensor input) const;
 
+  /// Very simple check if any of the parameters (weights, biases) are the same.
+  bool any_parameters_alias() const;
+
+  /// The number of gate weights/biases required by the RNN subclass.
   int64_t number_of_gates_;
-  bool has_cell_state_;
+
+  /// The cuDNN RNN mode, if this RNN subclass has any.
   at::optional<CuDNNMode> cudnn_mode_;
 
-  // This is copied from pytorch, to determine whether weights are flat for the
-  // fast CUDNN route. Otherwise, we have to use non flattened weights, which
-  // are much slower.
-  // https://github.com/pytorch/pytorch/blob/1848cad10802db9fa0aa066d9de195958120d863/torch/nn/modules/rnn.py#L159-L165
-  // TODO Actually since we are in C++ we can probably just actually check if
-  // the parameters are flat, instead of relying on data pointers and stuff.
-  std::vector<void*> data_ptrs_;
-  Tensor flat_weights_;
+  /// The cached result of the latest `flat_weights()` call.
+  std::vector<Tensor> flat_weights_;
 };
 } // namespace detail
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-// TODO: Replace this with passing an activation module.
-
 enum class RNNActivation { ReLU, Tanh };
 
 struct RNNOptions {
@@ -116,6 +132,8 @@ struct RNNOptions {
   TORCH_ARG(int64_t, layers) = 1;
   TORCH_ARG(bool, with_bias) = true;
   TORCH_ARG(double, dropout) = 0.0;
+  TORCH_ARG(bool, bidirectional) = false;
+  TORCH_ARG(bool, batch_first) = false;
   TORCH_ARG(RNNActivation, activation) = RNNActivation::ReLU;
 };
 
@@ -125,13 +143,14 @@ class RNNImpl : public detail::RNNImplBase<RNNImpl> {
       : RNNImpl(RNNOptions(input_size, hidden_size)) {}
   explicit RNNImpl(RNNOptions options);
 
-  RNNOptions options;
+  RNNOutput forward(Tensor input, Tensor state = {});
 
- private:
-  Tensor cell_forward(Tensor input, Tensor state, int64_t layer) override;
-  std::function<Tensor(Tensor)> activation_function_;
+  RNNOptions options;
 };
 
+/// A multi-layer Elman RNN module with Tanh or ReLU activation.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.RNN for more
+/// documentation.
 TORCH_MODULE(RNN);
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -144,10 +163,12 @@ class LSTMImpl : public detail::RNNImplBase<LSTMImpl> {
       : LSTMImpl(LSTMOptions(input_size, hidden_size)) {}
   explicit LSTMImpl(LSTMOptions options);
 
- private:
-  Tensor cell_forward(Tensor input, Tensor state, int64_t layer) override;
+  RNNOutput forward(Tensor input, Tensor state = {});
 };
 
+/// A multi-layer long-short-term-memory (LSTM) module.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.LSTM for more
+/// documentation.
 TORCH_MODULE(LSTM);
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -160,10 +181,12 @@ class GRUImpl : public detail::RNNImplBase<GRUImpl> {
      : GRUImpl(GRUOptions(input_size, hidden_size)) {}
   explicit GRUImpl(GRUOptions options);
 
- private:
-  Tensor cell_forward(Tensor input, Tensor state, int64_t layer) override;
+  RNNOutput forward(Tensor input, Tensor state = {});
 };
 
+/// A multi-layer gated recurrent unit (GRU) module.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.GRU for more
+/// documentation.
 TORCH_MODULE(GRU);
 
 } // namespace nn
diff --git a/torch/csrc/api/src/nn/modules/rnn.cpp b/torch/csrc/api/src/nn/modules/rnn.cpp
index 3bc1ae5fa6faf4..e0f6507f4c8c22 100644
--- a/torch/csrc/api/src/nn/modules/rnn.cpp
+++ b/torch/csrc/api/src/nn/modules/rnn.cpp
@@ -20,22 +20,6 @@ namespace torch {
 namespace nn {
 
-namespace {
-Tensor linear(Tensor x, Tensor w, Tensor b) {
-  if (x.ndimension() == 2 && b.defined()) {
-    // Fused op is marginally faster
-    assert(x.size(1) == w.size(1));
-    return torch::addmm(b, x, w.t());
-  }
-
-  auto output = x.matmul(w.t());
-  if (b.defined()) {
-    output += b;
-  }
-  return output;
-}
-} // namespace
-
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNOptionsBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 namespace detail {
@@ -48,22 +32,15 @@ template <typename Derived>
 RNNImplBase<Derived>::RNNImplBase(
     RNNOptionsBase options_,
     at::optional<CuDNNMode> cudnn_mode,
-    int64_t number_of_gates,
-    bool has_cell_state)
+    int64_t number_of_gates)
     : options(options_),
-      dropout(nullptr),
       number_of_gates_(number_of_gates),
-      has_cell_state_(has_cell_state),
      cudnn_mode_(cudnn_mode) {
   reset();
 }
 
 template <typename Derived>
 void RNNImplBase<Derived>::reset() {
-  if (options.dropout_ > 0.0) {
-    dropout = Dropout(options.dropout_);
-  }
-
   w_ih.resize(options.layers_);
   w_hh.resize(options.layers_);
   b_ih.resize(options.layers_);
@@ -89,209 +66,113 @@ void RNNImplBase<Derived>::reset() {
     }
   }
 
-  const auto stdv = 1.0 / std::sqrt(options.hidden_size_);
-  NoGradGuard no_grad;;
-  for (auto& p : this->parameters()) {
-    p->uniform_(-stdv, stdv);
+  {
+    NoGradGuard no_grad;
+    const auto stdv = 1.0 / std::sqrt(options.hidden_size_);
+    for (auto& p : this->parameters()) {
+      p->uniform_(-stdv, stdv);
+    }
   }
-}
 
-template <typename Derived>
-RNNOutput RNNImplBase<Derived>::forward(Tensor input, Tensor state) {
-  if (use_cudnn(/*sample=*/input)) {
-    return CUDNN_forward(input, state);
-  } else {
-    return autograd_forward(input, state);
-  }
+  flatten_parameters();
 }
 
 template <typename Derived>
-std::vector<Tensor> RNNImplBase<Derived>::flat_weights() const {
-  std::vector<Tensor> flat;
-  for (int64_t layer = 0; layer < options.layers_; layer++) {
-    flat.push_back(w_ih[layer]);
-    flat.push_back(w_hh[layer]);
-    if (options.with_bias_) {
-      flat.push_back(b_ih[layer]);
-      flat.push_back(b_hh[layer]);
-    }
-  }
-  return flat;
+void RNNImplBase<Derived>::to(
+    torch::Device device,
+    torch::Dtype dtype,
+    bool non_blocking) {
+  nn::Module::to(device, dtype, non_blocking);
+  flatten_parameters();
 }
 
 template <typename Derived>
-bool RNNImplBase<Derived>::use_cudnn(Tensor sample) const {
-  return cudnn_mode_.has_value() && sample.is_cuda() &&
-      torch::cudnn_is_acceptable(sample);
+void RNNImplBase<Derived>::to(torch::Dtype dtype, bool non_blocking) {
+  nn::Module::to(dtype, non_blocking);
+  flatten_parameters();
 }
 
 template <typename Derived>
-Tensor RNNImplBase<Derived>::create_dropout_state(Tensor input) const {
-  static const int64_t dropout_seed =
-      torch::empty({}, torch::kInt64).random_().toCLong();
-  if (options.dropout_ > 0) {
-    torch::DeviceGuard guard(input.device());
-    return torch::_cudnn_init_dropout_state(
-        input.type().toScalarType(torch::kUInt8),
-        options.dropout_,
-        this->is_training(),
-        dropout_seed);
-  }
-  return torch::empty({}, input.options());
+void RNNImplBase<Derived>::to(torch::Device device, bool non_blocking) {
+  nn::Module::to(device, non_blocking);
+  flatten_parameters();
 }
 
 template <typename Derived>
-RNNOutput RNNImplBase<Derived>::autograd_forward(Tensor input, Tensor state) {
-  std::vector<Tensor> new_state;
-  auto has_hidden = state.defined();
-  auto layer_dimension = has_hidden ? state.ndimension() - 3 : -1;
-  for (int64_t layer = 0; layer < options.layers_; layer++) {
-    new_state.push_back(
-        has_hidden ? state.select(layer_dimension, layer) : Tensor());
-  }
+void RNNImplBase<Derived>::flatten_parameters() {
+  // Cache the flattened weight and bias vector.
+  flat_weights_ = flat_weights();
 
-  auto output = torch::zeros(
-      {input.size(0), input.size(1), options.hidden_size_}, input.options());
-  for (int64_t t = 0; t < input.size(0); t++) {
-    auto x = input.select(0, t);
-    for (int64_t i = 0; i < options.layers_; i++) {
-      // cell_forward() returns a stacked tensor of one or more cell states.
-      auto layer_output = cell_forward(x, new_state[i], i);
-      // If there are multiple cell states, keep all. If there is only one,
-      // the first dimension will be 1, so `.squeeze(0)` will unpack it.
-      new_state[i] = layer_output.squeeze(0);
-      // x should always be the hidden cell state h, assumed to be the zero-th.
-      x = layer_output[0];
-      output.select(0, t).copy_(x);
-      if (options.dropout_ > 0 && i != options.layers_ - 1) {
-        x = dropout->forward(x);
-      }
-    }
+  if (!cudnn_mode_ || !torch::cudnn_is_acceptable(/*sample=*/w_ih.at(0))) {
+    return;
   }
 
-  auto state_output = torch::stack(new_state);
-  if (has_cell_state_) {
-    state_output.transpose_(0, 1);
-  }
-  return {output, state_output};
+  NoGradGuard no_grad;
+  torch::_cudnn_rnn_flatten_weight(
+      flat_weights_,
+      /*weight_stride=*/options.with_bias_ ? 4 : 2,
+      options.input_size_,
+      static_cast<int64_t>(*cudnn_mode_),
+      options.hidden_size_,
+      options.layers_,
+      /*batch_first=*/options.batch_first_,
+      /*bidirectional=*/options.bidirectional_);
 }
 
 template <typename Derived>
-void RNNImplBase<Derived>::flatten_parameters_for_cudnn() {
-  data_ptrs_.clear();
-  const auto any_parameter = w_ih.at(0);
-  if (!use_cudnn(/*sample=*/w_ih.at(0))) {
-    return;
-  }
-  std::unordered_set<void*> unique_data_ptrs;
-  auto params = this->parameters();
-  for (auto& p : params) {
-    unique_data_ptrs.insert(p->data_ptr());
-  }
-  // TODO PyTorch says: If any parameters alias, we fall back to the slower,
-  // copying code path. This is a sufficient check, because overlapping
-  // parameter buffers that don't completely alias would break the assumptions
-  // of the uniqueness check in Module.named_parameters(). But I'm not sure if
-  // this is the case for us
-  if (unique_data_ptrs.size() != params.size()) {
-    return;
-  }
-
-  {
-    NoGradGuard no_grad;;
-    flat_weights_ = torch::_cudnn_rnn_flatten_weight(
-        flat_weights(),
-        /*weight_stride=*/options.with_bias_ ? 4 : 2,
-        options.input_size_,
-        static_cast<int64_t>(*cudnn_mode_),
-        options.hidden_size_,
-        options.layers_,
-        /*batch_first=*/false,
-        /*bidirectional=*/false);
-  }
-  for (auto& p : params) {
-    data_ptrs_.emplace_back(p->data_ptr());
+RNNOutput RNNImplBase<Derived>::generic_forward(
+    std::function<RNNFunctionSignature> function,
+    Tensor input,
+    Tensor state) {
+  if (!state.defined()) {
+    // #layers, batch size, state size
+    const auto batch_size = input.size(options.batch_first_ ? 0 : 1);
+    state = torch::zeros(
+        {options.layers_, batch_size, options.hidden_size_}, input.options());
   }
+  Tensor output, new_state;
+  std::tie(output, new_state) = function(
+      std::move(input),
+      std::move(state),
+      flat_weights_,
+      options.with_bias_,
+      options.layers_,
+      options.dropout_,
+      this->is_training(),
+      options.bidirectional_,
+      options.batch_first_);
+  return {output, new_state};
 }
 
 template <typename Derived>
-RNNOutput RNNImplBase<Derived>::CUDNN_forward(Tensor input, Tensor state) {
-  Tensor hx, cx;
-  if (state.defined()) {
-    if (has_cell_state_) {
-      hx = state[0];
-      cx = state[1];
-    } else {
-      hx = state;
-    }
-  } else {
-    hx = torch::zeros(
-        {options.layers_, input.size(1), options.hidden_size_},
-        input.options());
-    if (has_cell_state_) {
-      cx = torch::zeros(
-          {options.layers_, input.size(1), options.hidden_size_},
-          input.options());
+std::vector<Tensor> RNNImplBase<Derived>::flat_weights() const {
+  // Organize all weights in a flat vector in the order
+  // (w_ih, w_hh, b_ih, b_hh), repeated for each layer (next to each other).
+  std::vector<Tensor> flat;
+  for (int64_t layer = 0; layer < options.layers_; layer++) {
+    flat.push_back(w_ih[layer]);
+    flat.push_back(w_hh[layer]);
+    if (options.with_bias_) {
+      flat.push_back(b_ih[layer]);
+      flat.push_back(b_hh[layer]);
     }
   }
-  std::vector<void*> weight_data_ptrs;
-  for (auto& p : this->parameters()) {
-    weight_data_ptrs.emplace_back(p->data_ptr());
-  }
-
-  AT_CHECK(
-      weight_data_ptrs == data_ptrs_,
-      "Parameters are unflattened! Code path might be super slow. "
-      "Please call flatten_parameters_for_cudnn() when you muck "
-      "around with storages!")
-
-  // cudnn_output = std::tuple<output, hy, cy, reserve, new_weight_buf>
-  auto cudnn_output = torch::_cudnn_rnn(
-      /*input=*/input,
-      /*weight=*/flat_weights(),
-      /*weight_stride0=*/options.with_bias_ ? 4 : 2,
-      /*weight_buf=*/flat_weights_,
-      /*hx=*/hx,
-      /*cx=*/cx,
-      /*mode=*/static_cast<int64_t>(*cudnn_mode_),
-      /*hidden_size=*/options.hidden_size_,
-      /*num_layers=*/options.layers_,
-      /*batch_first=*/false,
-      /*dropout=*/options.dropout_,
-      /*train=*/this->is_training(),
-      /*bidirectional=*/false,
-      /*batch_sizes=*/{},
-      /*dropout_state=*/create_dropout_state(input));
-
-  Tensor hidden_output = std::get<1>(cudnn_output);
-  if (has_cell_state_) {
-    auto cy = std::get<2>(cudnn_output);
-    hidden_output = torch::stack({hidden_output, cy});
-  }
-
-  Tensor output = std::get<0>(cudnn_output);
-  return {output, hidden_output};
-}
-
-template <typename Derived>
-void RNNImplBase<Derived>::to(
-    torch::Device device,
-    torch::Dtype dtype,
-    bool non_blocking) {
-  nn::Module::to(device, dtype, non_blocking);
-  flatten_parameters_for_cudnn();
+  return flat;
 }
 
 template <typename Derived>
-void RNNImplBase<Derived>::to(torch::Dtype dtype, bool non_blocking) {
-  nn::Module::to(dtype, non_blocking);
-  flatten_parameters_for_cudnn();
-}
+bool RNNImplBase<Derived>::any_parameters_alias() const {
+  // If any parameters alias, we fall back to the slower, copying code path.
+  // This is a sufficient check, because overlapping parameter buffers that
+  // don't completely alias would break the assumptions of the uniqueness check
+  // in Module.named_parameters().
+  std::unordered_set<void*> unique_data_ptrs;
+  const auto params = this->parameters();
+  params.map(
+      std::inserter(unique_data_ptrs, unique_data_ptrs.end()),
+      [](Tensor p) { return p.data_ptr(); });
 
-template <typename Derived>
-void RNNImplBase<Derived>::to(torch::Device device, bool non_blocking) {
-  nn::Module::to(device, non_blocking);
-  flatten_parameters_for_cudnn();
+  return unique_data_ptrs.size() != params.size();
 }
 
 template class RNNImplBase;
@@ -317,61 +198,59 @@ RNNImpl::RNNImpl(RNNOptions options)
           detail::RNNOptionsBase(options.input_size_, options.hidden_size_)
               .layers(options.layers_)
              .with_bias(options.with_bias_)
-              .dropout(options.dropout_),
-          /*cudnn_mode=*/static_cast<CuDNNMode>(options.activation_)),
-      options(options) {
+              .dropout(options.dropout_)
+              .bidirectional(options.bidirectional_)
+              .batch_first(options.batch_first_),
+          static_cast<CuDNNMode>(options.activation_)),
+      options(options) {}
+
+RNNOutput RNNImpl::forward(Tensor input, Tensor state) {
   switch (options.activation_) {
-    case RNNActivation::ReLU: {
-      activation_function_ = torch::relu;
-      break;
-    }
-    case RNNActivation::Tanh: {
-      activation_function_ = torch::tanh;
-      break;
-    }
+    case RNNActivation::ReLU:
+      return generic_forward(
+          static_cast<RNNFunctionSignature*>(&torch::rnn_relu), input, state);
+    case RNNActivation::Tanh:
+      return generic_forward(
+          static_cast<RNNFunctionSignature*>(&torch::rnn_tanh), input, state);
+    default:
+      AT_ERROR("Unhandled RNN activation function!");
   }
 }
 
-Tensor RNNImpl::cell_forward(Tensor input, Tensor state, int64_t layer) {
-  auto hx = state.defined()
-      ? state
-      : torch::zeros({input.size(0), options.hidden_size_}, input.options());
-
-  auto h = linear(input, w_ih[layer], b_ih[layer]) +
-      linear(hx, w_hh[layer], b_hh[layer]);
-
-  return torch::stack(activation_function_(h));
-}
-
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 LSTMImpl::LSTMImpl(LSTMOptions options)
     : detail::RNNImplBase<LSTMImpl>(
           options,
-          /*cudnn_mode=*/CuDNNMode::LSTM,
-          /*number_of_gates=*/4,
-          /*has_cell_state=*/true) {}
-
-Tensor LSTMImpl::cell_forward(Tensor input, Tensor state, int64_t layer) {
-  auto hid = state.defined()
-      ? state
-      : torch::zeros({2, input.size(0), options.hidden_size_}, input.options());
-  auto hx = hid[0];
-  auto cx = hid[1];
-
-  auto gates = linear(input, w_ih[layer], b_ih[layer]) +
-      linear(hx, w_hh[layer], b_hh[layer]);
-
-  auto chunked = gates.chunk(4, 1);
-  auto in_gate = chunked[0].sigmoid();
-  auto forget_gate = chunked[1].sigmoid();
-  auto cell_gate = chunked[2].tanh();
-  auto out_gate = chunked[3].sigmoid();
-
-  auto cy = (forget_gate * cx) + (in_gate * cell_gate);
-  auto hy = out_gate * cy.tanh();
-
-  return torch::stack({hy, cy}, 0);
+          CuDNNMode::LSTM,
+          /*number_of_gates=*/4) {}
+
+RNNOutput LSTMImpl::forward(Tensor input, Tensor state) {
+  // It would be trickier to adapt the `generic_forward` for the LSTM because
+  // its output has a different dimensionality (3-tuple vs. 2-tuple), while we
+  // always return one state variable (stacking the hidden/cell state into one),
+  // which also makes the state variables going into the `generic_forward`, and
+  // the way we default-initialize the state when it is not passed, slightly
+  // different. So we just re-implement it specifically for the LSTM here.
+  if (!state.defined()) {
+    // 2 for hidden state and cell state, then #layers, batch size, state size
+    const auto batch_size = input.size(options.batch_first_ ? 0 : 1);
+    state = torch::zeros(
+        {2, options.layers_, batch_size, options.hidden_size_},
+        input.options());
+  }
+  Tensor output, hidden_state, cell_state;
+  std::tie(output, hidden_state, cell_state) = torch::lstm(
+      input,
+      {state[0], state[1]},
+      flat_weights_,
+      options.with_bias_,
+      options.layers_,
+      options.dropout_,
+      this->is_training(),
+      options.bidirectional_,
+      options.batch_first_);
+  return {output, torch::stack({hidden_state, cell_state})};
 }
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -379,25 +258,12 @@ Tensor LSTMImpl::cell_forward(Tensor input, Tensor state, int64_t layer) {
 GRUImpl::GRUImpl(GRUOptions options)
     : detail::RNNImplBase<GRUImpl>(
           options,
-          /*cudnn_mode=*/CuDNNMode::GRU,
+          CuDNNMode::GRU,
           /*number_of_gates=*/3) {}
 
-Tensor GRUImpl::cell_forward(Tensor input, Tensor state, int64_t layer) {
-  auto hx = state.defined()
-      ? state
-      : torch::zeros({input.size(0), options.hidden_size_}, input.options());
-
-  auto gi = linear(input, w_ih[layer], b_ih[layer]);
-  auto gh = linear(input, w_hh[layer], b_hh[layer]);
-  auto gic = gi.chunk(3, 1);
-  auto ghc = gh.chunk(3, 1);
-
-  auto reset_gate = (gic[0] + ghc[0]).sigmoid_();
-  auto input_gate = (gic[1] + ghc[1]).sigmoid_();
-  auto new_gate = (gic[2] + reset_gate * ghc[2]).tanh_();
-  auto hy = new_gate + input_gate * (hx - new_gate);
-
-  return torch::stack(hy);
+RNNOutput GRUImpl::forward(Tensor input, Tensor state) {
+  return generic_forward(
+      static_cast<RNNFunctionSignature*>(&torch::gru), input, state);
 }
 } // namespace nn
 } // namespace torch
diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py
index 20ff911ecd2090..963fd4d08e8dad 100644
--- a/torch/nn/modules/rnn.py
+++ b/torch/nn/modules/rnn.py
@@ -101,7 +101,7 @@ def flatten_parameters(self):
             with torch.no_grad():
                 # NB: this is an INPLACE function on weight_arr, that's why the
                 # no_grad() is necessary.
-                weight_buf = torch._cudnn_rnn_flatten_weight(
+                torch._cudnn_rnn_flatten_weight(
                     weight_arr, weight_stride0,
                     self.input_size, rnn.get_cudnn_mode(self.mode), self.hidden_size, self.num_layers,
                     self.batch_first, bool(self.bidirectional))
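
Usage sketch (not part of the patch): the snippet below illustrates how the C++ front-end RNN modules are used after this change. The option names (`layers`, `dropout`, and the newly added `bidirectional` and `batch_first`), the `forward(input, state)` signature, and the `RNNOutput` fields `output` and `state` come from the diff above; the `torch/torch.h` umbrella header and the concrete tensor shapes are assumptions for illustration only.

```cpp
#include <torch/torch.h>

int main() {
  using namespace torch::nn;

  // Sequence length 10, batch size 16, 128 input features, matching the
  // (sequence, batch, features) layout used in test/cpp/api/rnn.cpp when
  // batch_first is false.
  auto x = torch::randn({10, 16, 128});

  LSTM lstm(LSTMOptions(128, 64)
                .layers(2)
                .dropout(0.2)
                .bidirectional(false)  // option added by this patch
                .batch_first(false));  // option added by this patch

  // First call without a state: the module default-initializes it to zeros.
  auto first = lstm->forward(x);

  // The returned state (hidden and cell state stacked into one tensor for the
  // LSTM) can be fed back into the next call.
  auto second = lstm->forward(x, first.state);

  // GRU (and RNN) follow the same pattern, dispatched through the ATen
  // functions via generic_forward().
  GRU gru(GRUOptions(128, 64).layers(2));
  auto gru_out = gru->forward(x);

  return 0;
}
```

Per the `to()` overrides added in this patch, `flatten_parameters()` is invoked automatically on construction and whenever the module is moved to another device or dtype, so no explicit call is needed in typical usage.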