Skip to content

Commit

Permalink
Document the Conv module (pytorch#11566)
Browse files Browse the repository at this point in the history
Summary:
Document the C++ API conv module. No code changes.

ebetica ezyang soumith
Pull Request resolved: pytorch#11566

Differential Revision: D9793665

Pulled By: goldsborough

fbshipit-source-id: 5f7f0605f952fadc62ffbcb8eca4183d4142c451
  • Loading branch information
goldsborough authored and facebook-github-bot committed Sep 12, 2018
1 parent 130d55a commit 5b2efcf
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 8 deletions.
81 changes: 80 additions & 1 deletion torch/csrc/api/include/torch/nn/modules/conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,70 +10,149 @@

namespace torch {
namespace nn {

/// Options for a `D`-dimensional convolution module.
/// Collects the three required constructor arguments plus all optional
/// hyperparameters; each `TORCH_ARG` below generates a getter/setter for the
/// named option.
template <size_t D>
struct ConvOptions {
/// Constructs options from the three required arguments. All remaining
/// options take the default values shown below and can be changed through
/// the generated `TORCH_ARG` accessors.
ConvOptions(
int64_t input_channels,
int64_t output_channels,
ExpandingArray<D> kernel_size);

/// The number of channels the input volumes will have.
/// Changing this parameter after construction __has no effect__.
TORCH_ARG(int64_t, input_channels);

/// The number of output channels the convolution should produce.
/// Changing this parameter after construction __has no effect__.
TORCH_ARG(int64_t, output_channels);

/// The kernel size to use.
/// For a `D`-dim convolution, must be a single number or a list of `D`
/// numbers.
/// This parameter __can__ be changed after construction.
TORCH_ARG(ExpandingArray<D>, kernel_size);

/// The stride of the convolution.
/// For a `D`-dim convolution, must be a single number or a list of `D`
/// numbers.
/// This parameter __can__ be changed after construction.
TORCH_ARG(ExpandingArray<D>, stride) = 1;

/// The padding to add to the input volumes.
/// For a `D`-dim convolution, must be a single number or a list of `D`
/// numbers.
/// This parameter __can__ be changed after construction.
TORCH_ARG(ExpandingArray<D>, padding) = 0;

/// The kernel dilation.
/// For a `D`-dim convolution, must be a single number or a list of `D`
/// numbers.
/// This parameter __can__ be changed after construction.
TORCH_ARG(ExpandingArray<D>, dilation) = 1;

/// For transpose convolutions, the padding to add to output volumes.
/// For a `D`-dim convolution, must be a single number or a list of `D`
/// numbers.
/// This parameter __can__ be changed after construction.
TORCH_ARG(ExpandingArray<D>, output_padding) = 0;

/// If true, convolutions will be transpose convolutions (a.k.a.
/// deconvolutions).
/// Changing this parameter after construction __has no effect__.
TORCH_ARG(bool, transposed) = false;

/// Whether to add a bias after individual applications of the kernel.
/// Changing this parameter after construction __has no effect__.
TORCH_ARG(bool, with_bias) = true;

/// The number of convolution groups.
/// This parameter __can__ be changed after construction.
TORCH_ARG(int64_t, groups) = 1;
};

/// Base class for all (dimension-specialized) convolution modules.
/// `Derived` is the concrete module type (CRTP via `Cloneable`), and `D` is
/// the spatial dimensionality of the convolution (1, 2 or 3).
template <size_t D, typename Derived>
class ConvImpl : public torch::nn::Cloneable<Derived> {
 public:
  /// Constructs the module from the three required options; all other
  /// options keep the defaults documented on `ConvOptions`.
  ///
  /// NOTE(review): the scraped diff carried this delegating-constructor line
  /// twice (old + new diff line); collapsed here to a single occurrence.
  ConvImpl(
      int64_t input_channels,
      int64_t output_channels,
      ExpandingArray<D> kernel_size)
      : ConvImpl(ConvOptions<D>(input_channels, output_channels, kernel_size)) {
  }

  /// Constructs the module from the full set of options.
  explicit ConvImpl(ConvOptions<D> options);

  /// (Re-)initializes the module's parameters (defined out of line in
  /// conv.cpp).
  void reset() override;

  /// The options with which this `Module` was constructed.
  ConvOptions<D> options;

  /// The learned kernel (or "weight").
  Tensor weight;

  /// The learned bias. Only defined if the `with_bias` option was true.
  Tensor bias;
};

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Applies convolution over a 1-D input.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.Conv1d to learn about
/// the exact behavior of this module.
class Conv1dImpl : public ConvImpl<1, Conv1dImpl> {
public:
using ConvImpl<1, Conv1dImpl>::ConvImpl;
/// Convolves `input` with this module's learned `weight` (and `bias`, when
/// the `with_bias` option is true) and returns the result.
Tensor forward(Tensor input);
};

/// `ConvOptions` specialized for 1-D convolution.
using Conv1dOptions = ConvOptions<1>;

/// A `ModuleHolder` subclass for `Conv1dImpl`.
/// See the documentation for `Conv1dImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(Conv1d);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Applies convolution over a 2-D input.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.Conv2d to learn about
/// the exact behavior of this module.
class Conv2dImpl : public ConvImpl<2, Conv2dImpl> {
public:
using ConvImpl<2, Conv2dImpl>::ConvImpl;
/// Convolves `input` with this module's learned `weight` (and `bias`, when
/// the `with_bias` option is true) and returns the result.
Tensor forward(Tensor input);
};

/// `ConvOptions` specialized for 2-D convolution.
using Conv2dOptions = ConvOptions<2>;

/// A `ModuleHolder` subclass for `Conv2dImpl`.
/// See the documentation for `Conv2dImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(Conv2d);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Applies convolution over a 3-D input.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.Conv3d to learn about
/// the exact behavior of this module.
class Conv3dImpl : public ConvImpl<3, Conv3dImpl> {
public:
using ConvImpl<3, Conv3dImpl>::ConvImpl;
/// Convolves `input` with this module's learned `weight` (and `bias`, when
/// the `with_bias` option is true) and returns the result.
Tensor forward(Tensor input);
};

/// `ConvOptions` specialized for 3-D convolution.
using Conv3dOptions = ConvOptions<3>;

/// A `ModuleHolder` subclass for `Conv3dImpl`.
/// See the documentation for `Conv3dImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(Conv3d);

} // namespace nn
Expand Down
12 changes: 6 additions & 6 deletions torch/csrc/api/include/torch/nn/modules/rnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ struct RNNOptions {
};

/// A multi-layer Elman RNN module with Tanh or ReLU activation.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.RNN for more
/// documenation.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.RNN to learn about the
/// exact behavior of this module.
class RNNImpl : public detail::RNNImplBase<RNNImpl> {
public:
RNNImpl(int64_t input_size, int64_t hidden_size)
Expand All @@ -198,8 +198,8 @@ TORCH_MODULE(RNN);
using LSTMOptions = detail::RNNOptionsBase;

/// A multi-layer long-short-term-memory (LSTM) module.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.LSTM for more
/// documenation.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.LSTM to learn about the
/// exact behavior of this module.
class LSTMImpl : public detail::RNNImplBase<LSTMImpl> {
public:
LSTMImpl(int64_t input_size, int64_t hidden_size)
Expand All @@ -224,8 +224,8 @@ TORCH_MODULE(LSTM);
using GRUOptions = detail::RNNOptionsBase;

/// A multi-layer gated recurrent unit (GRU) module.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.GRU for more
/// documenation.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.GRU to learn about the
/// exact behavior of this module.
class GRUImpl : public detail::RNNImplBase<GRUImpl> {
public:
GRUImpl(int64_t input_size, int64_t hidden_size)
Expand Down
3 changes: 2 additions & 1 deletion torch/csrc/api/src/nn/modules/conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <cmath>
#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

namespace torch {
Expand Down Expand Up @@ -61,7 +62,7 @@ void ConvImpl<D, Derived>::reset() {
options.input_channels_,
std::multiplies<int64_t>{});
const auto stdv = 1.0 / std::sqrt(number_of_features);
NoGradGuard no_grad;;
NoGradGuard no_grad;
for (auto& p : this->parameters()) {
p->uniform_(-stdv, stdv);
}
Expand Down

0 comments on commit 5b2efcf

Please sign in to comment.