
Commit

Making our code compatible with the latest pytorch (facebookresearch#223)

* Making our code compatible with the latest pytorch

* revert

* torch.nn.utils.clip_grad_norm now returns tensor
edunov authored and Sergey Edunov committed Feb 27, 2018
1 parent 9438019 commit 2f976aa
Showing 5 changed files with 13 additions and 6 deletions.
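
Background on the bullets above: from PyTorch 0.4 onward a reduced loss is a zero-dimensional tensor, so the old loss.data[0] indexing no longer yields a Python number, and utilities such as gradient-norm clipping likewise hand back tensors instead of floats. A minimal sketch of the replacement pattern this commit adopts (not part of the commit; the model and shapes are illustrative):

import torch
import torch.nn.functional as F

model = torch.nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.zeros(8, 1)

loss = F.mse_loss(model(x), y)   # zero-dimensional tensor on PyTorch >= 0.4
print(loss.dim())                # 0 -- so loss.data[0] is no longer valid indexing

# .item() is the portable way to pull out a plain Python float for logging
loss_value = loss.item()
print(loss_value)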
4 changes: 2 additions & 2 deletions fairseq/criterions/cross_entropy.py
@@ -9,7 +9,7 @@
 import torch.nn.functional as F
 
 from . import FairseqCriterion, register_criterion
-
+from fairseq import utils
 
 @register_criterion('cross_entropy')
 class CrossEntropyCriterion(FairseqCriterion):
@@ -33,7 +33,7 @@ def forward(self, model, sample, reduce=True):
             reduce=reduce)
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
-            'loss': loss.data[0] if reduce else loss.data,
+            'loss': utils.item(loss.data) if reduce else loss.data,
             'ntokens': sample['ntokens'],
             'sample_size': sample_size,
         }
4 changes: 2 additions & 2 deletions fairseq/criterions/label_smoothed_cross_entropy.py
@@ -79,8 +79,8 @@ def forward(self, model, sample, reduce=True):
         nll_loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx, reduce=reduce)
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
-            'loss': loss.data[0] if reduce else loss.data,
-            'nll_loss': nll_loss.data[0] if reduce else loss.data,
+            'loss': utils.item(loss.data) if reduce else loss.data,
+            'nll_loss': utils.item(nll_loss.data) if reduce else loss.data,
             'ntokens': sample['ntokens'],
             'sample_size': sample_size,
         }
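
In both criterions only the reduce=True branch is unwrapped, since that is the case where the loss has been summed to a single scalar; with reduce=False the per-element tensor is logged as-is. A small sketch of that distinction (illustrative only, using the modern reduction= keyword rather than the deprecated size_average/reduce flags shown in the diff):

import torch
import torch.nn.functional as F

lprobs = torch.log_softmax(torch.randn(5, 10), dim=-1)
target = torch.randint(0, 10, (5,))

for reduce in (True, False):
    # reduce=True sums to a 0-dim scalar; reduce=False keeps a per-element tensor
    loss = F.nll_loss(lprobs, target, reduction='sum' if reduce else 'none')
    # mirrors the logging_output lines above: only the scalar case is unwrapped
    logged = loss.item() if reduce else loss.data
    print(reduce, logged)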
2 changes: 1 addition & 1 deletion fairseq/distributed_utils.py
@@ -116,7 +116,7 @@ def all_gather_list(data, max_size=4096):
     if len(enc) >= max_size:
         raise ValueError('encoded data exceeds max_size: {}'.format(len(enc)))
     in_buffer[0] = len(enc)
-    in_buffer[1:len(enc)+1] = torch.ByteTensor(enc)
+    in_buffer[1:len(enc)+1] = torch.ByteTensor(list(enc))
 
     torch.distributed.all_gather(out_buffers, in_buffer.cuda())
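
The list() conversion matters because the legacy ByteTensor constructor expects a sequence of integers, and passing the raw pickled bytes object stopped being handled that way on newer PyTorch. A self-contained round-trip sketch in the spirit of all_gather_list (buffer size and names are illustrative, and no distributed backend is involved):

import pickle
import torch

data = {'step': 3, 'loss': 1.25}
enc = pickle.dumps(data)

# Length goes in the first byte, the payload after it; list(enc) yields ints in Python 3.
in_buffer = torch.ByteTensor(4096).zero_()
in_buffer[0] = len(enc)                      # assumes len(enc) < 256 for this tiny example
in_buffer[1:len(enc) + 1] = torch.ByteTensor(list(enc))

# Receiving side (in the real code the buffer arrives via torch.distributed.all_gather):
size = in_buffer[0].item()
decoded = pickle.loads(bytes(in_buffer[1:size + 1].tolist()))
assert decoded == data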
2 changes: 1 addition & 1 deletion fairseq/trainer.py
@@ -190,7 +190,7 @@ def _backward_and_opt(self, loss, grad_denom):
 
         # clip grads
         if self.args.clip_norm > 0:
-            grad_norm = torch.nn.utils.clip_grad_norm(self.model.parameters(), self.args.clip_norm)
+            grad_norm = utils.item(torch.nn.utils.clip_grad_norm(self.model.parameters(), self.args.clip_norm))
         else:
             grad_norm = math.sqrt(sum(p.grad.data.norm()**2 for p in self.model.parameters()))
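
Wrapping the clipping call in utils.item keeps grad_norm a plain Python float in both branches: the else-branch already produces one via math.sqrt, while clip_grad_norm on recent PyTorch returns a tensor. A small sketch of that type mismatch (illustrative model and clip value; the trailing-underscore spelling is the one current PyTorch documents):

import math
import torch

model = torch.nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

clipped = torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)             # tensor on >= 0.4
manual = math.sqrt(sum(p.grad.data.norm() ** 2 for p in model.parameters()))  # Python float

# Without unwrapping, the two branches of _backward_and_opt would log different
# types; .item() (or the new utils.item) makes them agree.
print(type(clipped), type(clipped.item()), type(manual))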
7 changes: 7 additions & 0 deletions fairseq/utils.py
@@ -304,3 +304,10 @@ def convert_padding_direction(
     else:
         index = torch.remainder(range + num_pads, max_len)
     return src_tokens.gather(1, index)
+
+def item(tensor):
+    if hasattr(tensor, 'item'):
+        return tensor.item()
+    if hasattr(tensor, '__getitem__'):
+        return tensor[0]
+    return tensor
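
The helper degrades gracefully: tensors expose .item(), indexable objects fall back to [0], and plain numbers pass through unchanged. A quick usage sketch with the function inlined so it runs standalone rather than importing fairseq.utils:

import torch

def item(tensor):
    # same definition as the one added to fairseq/utils.py above
    if hasattr(tensor, 'item'):
        return tensor.item()
    if hasattr(tensor, '__getitem__'):
        return tensor[0]
    return tensor

print(item(torch.tensor(3.5)))     # 0-dim tensor      -> 3.5
print(item(torch.tensor([2.0])))   # 1-element tensor  -> 2.0
print(item([7]))                   # plain indexable   -> 7
print(item(1.25))                  # plain number      -> 1.25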
