Commit
Better support for torch.no_grad (since volatile is deprecated)
myleott committed Jan 22, 2018
1 parent 0b84ab1 commit 907ca92
Showing 3 changed files with 12 additions and 8 deletions.
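All three files apply the same pattern: inference-only code that previously relied on volatile Variables is wrapped in utils.maybe_no_grad(), which returns torch.no_grad() on PyTorch builds that provide it and otherwise acts as a no-op. A minimal usage sketch of that pattern (the model and batch names are illustrative, not from this commit):

```python
from fairseq import utils

# Inference-only forward pass: on PyTorch >= 0.4 no autograd graph is built,
# while older versions fall through to the previous volatile-Variable behaviour.
with utils.maybe_no_grad():
    scores = model(batch)               # hypothetical model and batch
    predictions = scores.max(dim=-1)[1]
```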
3 changes: 2 additions & 1 deletion fairseq/modules/linearized_convolution.py
@@ -69,7 +69,8 @@ def incremental_forward(self, input):
             # append next input
             self.input_buffer[:, -1, :] = input[:, -1, :]
             input = utils.volatile_variable(self.input_buffer)
-        output = F.linear(input.view(bsz, -1), weight, self.bias)
+        with utils.maybe_no_grad():
+            output = F.linear(input.view(bsz, -1), weight, self.bias)
         return output.view(bsz, 1, -1)
 
     def clear_incremental_state(self):
8 changes: 5 additions & 3 deletions fairseq/sequence_generator.py
@@ -71,8 +71,9 @@ def generate_batched_itr(self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=
             srclen = input['src_tokens'].size(1)
             if timer is not None:
                 timer.start()
-            hypos = self.generate(input['src_tokens'], beam_size=beam_size,
-                                  maxlen=int(maxlen_a*srclen + maxlen_b))
+            with utils.maybe_no_grad():
+                hypos = self.generate(input['src_tokens'], beam_size=beam_size,
+                                      maxlen=int(maxlen_a*srclen + maxlen_b))
             if timer is not None:
                 timer.stop(s['ntokens'])
             for i, id in enumerate(s['id'].data):
@@ -327,7 +328,8 @@ def _decode(self, tokens, encoder_outs):
         avg_probs = None
         avg_attn = None
         for model, encoder_out in zip(self.models, encoder_outs):
-            decoder_out, attn = model.decoder(tokens, encoder_out)
+            with utils.maybe_no_grad():
+                decoder_out, attn = model.decoder(tokens, encoder_out)
             probs = model.get_normalized_probs(decoder_out[:, -1, :], log_probs=False).data
             if avg_probs is None:
                 avg_probs = probs
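Note that generation now enters the guard twice: generate_batched_itr wraps the call to generate, and _decode wraps each decoder call made beneath it. That nesting is harmless, because torch.no_grad contexts can be entered recursively; a quick illustration on a recent PyTorch:

```python
import torch

x = torch.ones(3, requires_grad=True)
with torch.no_grad():          # outer guard, as in generate_batched_itr
    with torch.no_grad():      # inner guard, as in _decode
        y = x * 2
print(y.requires_grad)         # False: no graph is recorded under the nested guards
```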
9 changes: 5 additions & 4 deletions fairseq/utils.py
@@ -176,7 +176,7 @@ def _upgrade_args(args):
     return args
 
 
-def maybe_no_grad(condition):
+def maybe_no_grad(condition=True):
     if hasattr(torch, 'no_grad') and condition:
         return torch.no_grad()
     # no-op context manager
@@ -185,9 +185,10 @@ def maybe_no_grad(condition):

 def volatile_variable(*args, **kwargs):
     if hasattr(torch, 'no_grad'):
-        with torch.no_grad():
-            return Variable(*args, **kwargs)
-    return Variable(*args, **kwargs, volatile=True)
+        # volatile has been deprecated, use the no_grad context manager instead
+        return Variable(*args, **kwargs)
+    else:
+        return Variable(*args, **kwargs, volatile=True)
 
 
 def make_variable(sample, volatile=False, cuda_device=None):
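After this commit, the two helpers in fairseq/utils.py behave roughly as sketched below; the body hidden behind the collapsed "# no-op context manager" comment is assumed here to be a contextlib-based empty context, not copied from the repository:

```python
import contextlib
import torch
from torch.autograd import Variable

def maybe_no_grad(condition=True):
    if hasattr(torch, 'no_grad') and condition:
        return torch.no_grad()
    # no-op fallback for older PyTorch (assumed implementation; elided in the diff)
    return contextlib.ExitStack()

def volatile_variable(*args, **kwargs):
    if hasattr(torch, 'no_grad'):
        # volatile has been deprecated, use the no_grad context manager instead
        return Variable(*args, **kwargs)
    else:
        return Variable(*args, **kwargs, volatile=True)
```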
