Implement bilinear double backward.
gchanan authored and soumith committed Jul 3, 2017
1 parent 1aa145d commit daa84e7
Showing 3 changed files with 25 additions and 23 deletions.
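
For reference, F.bilinear computes, for each sample n and output feature k, y[n, k] = x1[n] @ weight[k] @ x2[n] + bias[k]; the backward implemented in this commit differentiates exactly this form. A minimal sketch of the equivalence (not part of this commit; assumes a recent PyTorch with torch.einsum):

import torch
import torch.nn.functional as F

x1 = torch.randn(4, 3, dtype=torch.double)
x2 = torch.randn(4, 5, dtype=torch.double)
w = torch.randn(2, 3, 5, dtype=torch.double)   # (out_features, in1_features, in2_features)
b = torch.randn(2, dtype=torch.double)

# y[n, k] = x1[n] @ w[k] @ x2[n] + b[k]
ref = torch.einsum('ni,kij,nj->nk', x1, w, x2) + b
assert torch.allclose(F.bilinear(x1, x2, w, b), ref)
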
test/test_nn.py — 10 changes: 5 additions & 5 deletions
@@ -241,7 +241,7 @@ def _get_parameters(self, module):
            d_params.append(p.grad.data)
        return params, d_params

-    def _assert_grad_and_gradgradchecks(self, apply_fn, inputs):
+    def _assertGradAndGradgradChecks(self, apply_fn, inputs):
        self.assertTrue(gradcheck(apply_fn, inputs))
        dummy_out = apply_fn(*inputs)
        if isinstance(dummy_out, tuple):
@@ -981,9 +981,9 @@ def test_InstanceNorm3d(self):

    def test_pad(self):
        inputs = Variable(torch.randn(1, 3, 4, 4), requires_grad=True)
-        self._assert_grad_and_gradgradchecks(lambda x: F.pad(x, (1, 1, 1, 1)), (inputs,))
-        self._assert_grad_and_gradgradchecks(lambda x: F.pad(x, (-1, 1, -2, 1)), (inputs,))
-        self._assert_grad_and_gradgradchecks(lambda x: F.pad(x, (-1, 1, -2, 1), value=2), (inputs,))
+        self._assertGradAndGradgradChecks(lambda x: F.pad(x, (1, 1, 1, 1)), (inputs,))
+        self._assertGradAndGradgradChecks(lambda x: F.pad(x, (-1, 1, -2, 1)), (inputs,))
+        self._assertGradAndGradgradChecks(lambda x: F.pad(x, (-1, 1, -2, 1), value=2), (inputs,))
        self.assertTrue(gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1), mode='replicate'), (inputs,)))
        self.assertTrue(gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1), mode='reflect'), (inputs,)))

@@ -2388,7 +2388,7 @@ def test_bilinear(self):
        self.assertEqual(module.weight.grad.data, module_legacy.gradWeight)
        self.assertEqual(module.bias.grad.data, module_legacy.gradBias)

-        self.assertTrue(gradcheck(lambda x1, x2: F.bilinear(x1, x2, module.weight, module.bias), (input1_1, input2_1)))
+        self._assertGradAndGradgradChecks(lambda x1, x2: F.bilinear(x1, x2, module.weight, module.bias), (input1_1, input2_1))

    def run_conv_double_back_test(self, kern, stride, padding, chan_in, chan_out, batch_size,
                                  inp_size, dilation, no_weight, use_cuda=False, use_bias=True):
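The test above now exercises both gradcheck and gradgradcheck, so the new double backward is verified numerically. A standalone sketch of the same check (assumes a recent PyTorch; double precision keeps the numerical Jacobians stable):

import torch
import torch.nn.functional as F
from torch.autograd import gradcheck, gradgradcheck

x1 = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
x2 = torch.randn(4, 5, dtype=torch.double, requires_grad=True)
w = torch.randn(2, 3, 5, dtype=torch.double)
b = torch.randn(2, dtype=torch.double)

fn = lambda a, c: F.bilinear(a, c, w, b)
assert gradcheck(fn, (x1, x2))      # first-order gradients
assert gradgradcheck(fn, (x1, x2))  # second-order, i.e. the new double backward
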
torch/nn/_functions/linear.py — 34 changes: 18 additions & 16 deletions
@@ -4,8 +4,9 @@

class Bilinear(Function):

-    def forward(self, input1, input2, weight, bias=None):
-        self.save_for_backward(input1, input2, weight, bias)
+    @staticmethod
+    def forward(ctx, input1, input2, weight, bias=None):
+        ctx.save_for_backward(input1, input2, weight, bias)

        output = input1.new(input1.size(0), weight.size(0))

@@ -22,35 +23,36 @@ def forward(self, input1, input2, weight, bias=None):

        return output

-    def backward(self, grad_output):
-        input1, input2, weight, bias = self.saved_tensors
+    @staticmethod
+    def backward(ctx, grad_output):
+        input1, input2, weight, bias = ctx.saved_variables
        grad_input1 = grad_input2 = grad_weight = grad_bias = None

-        buff = input1.new()
+        buff = Variable(input1.data.new())

-        if self.needs_input_grad[0] or self.needs_input_grad[1]:
+        if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
            grad_input1 = torch.mm(input2, weight[0].t())
-            grad_input1.mul_(grad_output.narrow(1, 0, 1).expand(grad_input1.size()))
+            grad_input1 = grad_input1.mul(grad_output.narrow(1, 0, 1).expand(grad_input1.size()))
            grad_input2 = torch.mm(input1, weight[0])
-            grad_input2.mul_(grad_output.narrow(1, 0, 1).expand(grad_input2.size()))
+            grad_input2 = grad_input2.mul(grad_output.narrow(1, 0, 1).expand(grad_input2.size()))

            for k in range(1, weight.size(0)):
-                torch.mm(input2, weight[k].t(), out=buff)
-                buff.mul_(grad_output.narrow(1, k, 1).expand(grad_input1.size()))
+                buff = input2.mm(weight[k].t())
+                buff = buff.mul(grad_output.narrow(1, k, 1).expand(grad_input1.size()))
                grad_input1.add_(buff)

-                torch.mm(input1, weight[k], out=buff)
-                buff.mul_(grad_output.narrow(1, k, 1).expand(grad_input2.size()))
+                buff = input1.mm(weight[k])
+                buff = buff.mul(grad_output.narrow(1, k, 1).expand(grad_input2.size()))
                grad_input2.add_(buff)

-        grad_weight = weight.new(weight.size())
-        if self.needs_input_grad[2]:
+        grad_weight = Variable(weight.data.new(weight.size()))
+        if ctx.needs_input_grad[2]:
            # accumulate parameter gradients:
            for k in range(weight.size(0)):
-                torch.mul(input1, grad_output.narrow(1, k, 1).expand_as(input1), out=buff)
+                buff = input1.mul(grad_output.narrow(1, k, 1).expand_as(input1))
                grad_weight[k] = torch.mm(buff.t(), input2)

        if bias is not None and self.needs_input_grad[3]:
-        if bias is not None and self.needs_input_grad[3]:
+        if bias is not None and ctx.needs_input_grad[3]:
            grad_bias = grad_output.sum(0, keepdim=False)

        return grad_input1, grad_input2, grad_weight, grad_bias
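
The backward above is now written as out-of-place operations on Variables (mul instead of mul_, fresh buff results instead of out= buffers, saved_variables instead of saved_tensors), so autograd can record it and differentiate it again. A small illustration of taking a gradient of a gradient through F.bilinear (illustrative only; written against today's tensor API rather than Variables):

import torch
import torch.nn.functional as F

x1 = torch.randn(4, 3, requires_grad=True)
x2 = torch.randn(4, 5, requires_grad=True)
w = torch.randn(2, 3, 5, requires_grad=True)

out = F.bilinear(x1, x2, w)
grad_out = torch.ones_like(out)

# First backward, recorded as a graph because create_graph=True.
(g1,) = torch.autograd.grad(out, x1, grad_out, create_graph=True)

# Second backward: differentiate the gradient itself, here w.r.t. the weight.
(gw,) = torch.autograd.grad(g1.sum(), w)
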
torch/nn/functional.py — 4 changes: 2 additions & 2 deletions
@@ -545,9 +545,9 @@ def linear(input, weight, bias=None):

def bilinear(input1, input2, weight, bias=None):
    if bias is None:
-        return Bilinear()(input1, input2, weight)
+        return Bilinear.apply(input1, input2, weight)
    else:
-        return Bilinear()(input1, input2, weight, bias)
+        return Bilinear.apply(input1, input2, weight, bias)


def batch_norm(input, running_mean, running_var, weight=None, bias=None,
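New-style autograd Functions are invoked through Class.apply(...) rather than by instantiating the class, which is why Bilinear()(...) becomes Bilinear.apply(...) above. A minimal sketch of the pattern (Scale is a hypothetical example, not part of PyTorch or this patch):

import torch
from torch.autograd import Function

class Scale(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        # stash the non-tensor argument for backward
        ctx.alpha = alpha
        return x * alpha

    @staticmethod
    def backward(ctx, grad_output):
        # one return value per forward() input; None for the non-tensor alpha
        return grad_output * ctx.alpha, None

y = Scale.apply(torch.randn(3, requires_grad=True), 2.0)
y.sum().backward()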
