Check dropout probability
apaszke authored and soumith committed Jan 16, 2017
1 parent 676ffee commit 59bc96b
Showing 3 changed files with 35 additions and 11 deletions.
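
In short, this commit makes every dropout entry point validate its probability argument. A small sketch of the new user-facing behavior, using only the modules touched by the diff below (the error wording comes from the added check):

    import torch.nn as nn

    nn.Dropout(0.5)       # still fine: p lies in [0, 1]
    nn.Dropout2d(0.0)     # the boundaries 0 and 1 remain accepted
    nn.Dropout3d(1.0)

    try:
        nn.Dropout(1.1)   # out-of-range p is now rejected at construction time
    except ValueError as e:
        print(e)          # dropout probability has to be between 0 and 1, but got 1.1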
12 changes: 12 additions & 0 deletions test/test_nn.py
@@ -8,6 +8,7 @@
from functools import wraps

import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel as dp
from torch.autograd import Variable
from torch.nn import Parameter
@@ -877,6 +878,17 @@ def test_RNN_cell(self):

hx.sum().backward()

def test_invalid_dropout_p(self):
v = Variable(torch.ones(1))
self.assertRaises(ValueError, lambda: nn.Dropout(-0.1))
self.assertRaises(ValueError, lambda: nn.Dropout(1.1))
self.assertRaises(ValueError, lambda: nn.Dropout2d(-0.1))
self.assertRaises(ValueError, lambda: nn.Dropout2d(1.1))
self.assertRaises(ValueError, lambda: nn.Dropout3d(-0.1))
self.assertRaises(ValueError, lambda: nn.Dropout3d(1.1))
self.assertRaises(ValueError, lambda: F.dropout(v, -0.1))
self.assertRaises(ValueError, lambda: F.dropout(v, 1.1))

def test_LSTM_cell(self):
# this is just a smoke test; these modules are implemented through
# autograd so no Jacobian test is needed
3 changes: 3 additions & 0 deletions torch/nn/_functions/dropout.py
@@ -7,6 +7,9 @@ class Dropout(InplaceFunction):

def __init__(self, p=0.5, train=False, inplace=False):
super(Dropout, self).__init__()
if p < 0 or p > 1:
raise ValueError("dropout probability has to be between 0 and 1, "
"but got {}".format(p))
self.p = p
self.train = train
self.inplace = inplace
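
Because F.dropout is implemented on top of this autograd function, the functional interface picks up the same check. A minimal sketch, assuming the wrapper simply forwards the user-supplied p (the wrapper itself is not part of this diff):

    import torch
    import torch.nn.functional as F
    from torch.autograd import Variable

    v = Variable(torch.ones(1))
    F.dropout(v, 0.3, training=True)   # valid probability, returns a Variable
    F.dropout(v, 1.5)                  # raises ValueError with the message added above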
31 changes: 20 additions & 11 deletions torch/nn/modules/dropout.py
@@ -11,7 +11,7 @@ class Dropout(Module):
Shape:
- Input: `Any`. Input can be of any shape
- Output: `Same`. Output is of the same shape as input
Examples::
@@ -21,6 +21,9 @@ class Dropout(Module):
"""
def __init__(self, p=0.5, inplace=False):
super(Dropout, self).__init__()
if p < 0 or p > 1:
raise ValueError("dropout probability has to be between 0 and 1, "
"but got {}".format(p))
self.p = p
self.inplace = inplace

@@ -40,11 +43,11 @@ class Dropout2d(Module):
*Usually the input comes from Conv2d modules.*
As described in the paper
`Efficient Object Localization Using Convolutional Networks`_ ,
if adjacent pixels within feature maps are strongly correlated
(as is normally the case in early convolution layers) then iid dropout
will not regularize the activations and will otherwise just result
in an effective learning rate decrease.
In this case, :func:`nn.Dropout2d` will help promote independence between
@@ -69,6 +72,9 @@ class Dropout2d(Module):
"""
def __init__(self, p=0.5, inplace=False):
super(Dropout2d, self).__init__()
if p < 0 or p > 1:
raise ValueError("dropout probability has to be between 0 and 1, "
"but got {}".format(p))
self.p = p
self.inplace = inplace

@@ -87,11 +93,11 @@ class Dropout3d(Module):
*Usually the input comes from Conv3d modules.*
As described in the paper
`Efficient Object Localization Using Convolutional Networks`_ ,
if adjacent pixels within feature maps are strongly correlated
(as is normally the case in early convolution layers) then iid dropout
will not regularize the activations and will otherwise just result
in an effective learning rate decrease.
In this case, :func:`nn.Dropout3d` will help promote independence between
@@ -116,6 +122,9 @@ class Dropout3d(Module):
"""
def __init__(self, p=0.5, inplace=False):
super(Dropout3d, self).__init__()
if p < 0 or p > 1:
raise ValueError("dropout probability has to be between 0 and 1, "
"but got {}".format(p))
self.p = p
self.inplace = inplace

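
For completeness, a small usage sketch of the channel-wise variants whose constructors now carry the same check; the 4D/5D input shapes follow the usual N x C x H x W and N x C x D x H x W convention for these modules (an assumption, since the docstrings' shape sections are collapsed in this diff):

    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    m2 = nn.Dropout2d(p=0.2)
    out2 = m2(Variable(torch.randn(4, 16, 32, 32)))     # randomly zeroes whole feature maps

    m3 = nn.Dropout3d(p=0.2)
    out3 = m3(Variable(torch.randn(4, 16, 8, 32, 32)))

    nn.Dropout3d(p=-0.5)   # now raises ValueError instead of silently storing a bad p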
