Support deprecation of volatile Variables in latest PyTorch
myleott committed Jan 22, 2018
1 parent 5637d54 commit 7da4e06
Showing 3 changed files with 25 additions and 14 deletions.
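Background: PyTorch 0.4 deprecates the volatile flag on Variable in favor of gradient-mode context managers, which is the pattern this commit adopts while staying compatible with older releases. A minimal sketch of the replacement pattern, assuming PyTorch 0.4 or later:

    import torch

    # Deprecated style (PyTorch <= 0.3): volatile Variables skipped autograd, e.g.
    #   out = model(Variable(data, volatile=True))
    # Replacement: disable gradient tracking with a context manager instead.
    x = torch.ones(3, requires_grad=True)
    with torch.no_grad():
        y = x * 2
    print(y.requires_grad)  # False: no graph was built inside the block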
29 changes: 15 additions & 14 deletions fairseq/multiprocessing_trainer.py
@@ -227,20 +227,21 @@ def _async_forward(self, rank, device_id, eval=False):
         self.model.train()
         self.optimizer.zero_grad()

-        sample_size, logging_output, oom = 0, {}, False
-        if self._sample is not None:
-            try:
-                # calculate loss and sample size
-                self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
-            except RuntimeError as e:
-                if not eval and 'out of memory' in str(e):
-                    print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
-                    oom = True
-                    self.loss = None
-                    if hasattr(torch.cuda, 'empty_cache'):
-                        torch.cuda.empty_cache()
-                else:
-                    raise e
+        with utils.maybe_no_grad(eval):
+            sample_size, logging_output, oom = 0, {}, False
+            if self._sample is not None:
+                try:
+                    # calculate loss and sample size
+                    self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
+                except RuntimeError as e:
+                    if not eval and 'out of memory' in str(e):
+                        print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
+                        oom = True
+                        self.loss = None
+                        if hasattr(torch.cuda, 'empty_cache'):
+                            torch.cuda.empty_cache()
+                    else:
+                        raise e

         return sample_size, logging_output, oom

8 changes: 8 additions & 0 deletions fairseq/utils.py
@@ -6,6 +6,7 @@
 # can be found in the PATENTS file in the same directory.
 #

+import contextlib
 import logging
 import os
 import torch
@@ -244,3 +245,10 @@ def rstrip_pad(tensor, pad):
     if strip > 0:
         return tensor[:-strip]
     return tensor
+
+
+def maybe_no_grad(condition):
+    if hasattr(torch, 'no_grad') and condition:
+        return torch.no_grad()
+    # no-op context manager
+    return contextlib.ExitStack()
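A short usage sketch of the helper above, assuming a PyTorch build that has torch.no_grad (0.4 or later) so both branches can be observed:

    import torch
    from fairseq.utils import maybe_no_grad

    x = torch.ones(3, requires_grad=True)

    with maybe_no_grad(True):
        y = x * 2
    print(y.requires_grad)  # False: the torch.no_grad() branch is active

    with maybe_no_grad(False):
        z = x * 2
    print(z.requires_grad)  # True: the ExitStack branch is a no-op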
2 changes: 2 additions & 0 deletions generate.py
@@ -35,6 +35,8 @@ def main():
     print(args)

     use_cuda = torch.cuda.is_available() and not args.cpu
+    if hasattr(torch, 'set_grad_enabled'):
+        torch.set_grad_enabled(False)

     # Load dataset
     if args.replace_unk is None:
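The same guard in isolation, as a sketch: torch.set_grad_enabled(False) switches autograd off globally, so generation never builds a graph; the hasattr check keeps the script working on older PyTorch releases that lack the function. The tensor check below assumes the 0.4-style tensor API:

    import torch

    # Disable autograd for the whole process when the API exists (PyTorch >= 0.4).
    if hasattr(torch, 'set_grad_enabled'):
        torch.set_grad_enabled(False)

    x = torch.ones(2, requires_grad=True)
    y = x + 1
    print(y.requires_grad)  # False while grad mode is globally disabled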
