Skip to content

Commit

Permalink
Ensure attention-mask tensors are created with dtype uint8
Browse files Browse the repository at this point in the history
  • Loading branch information
thomwolf committed Sep 4, 2019
1 parent 38b79b5 commit 0be6a2a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pytorch_transformers/modeling_transfo_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,7 +1135,7 @@ def _forward(self, dec_inp, mems=None, head_mask=None):
mlen = mems[0].size(0) if mems is not None else 0
klen = mlen + qlen
if self.same_length:
all_ones = word_emb.new_ones(qlen, klen)
all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
mask_len = klen - self.mem_len
if mask_len > 0:
mask_shift_len = qlen - mask_len
Expand All @@ -1145,7 +1145,7 @@ def _forward(self, dec_inp, mems=None, head_mask=None):
+ torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1
else:
dec_attn_mask = torch.triu(
word_emb.new_ones(qlen, klen), diagonal=1+mlen)[:,:,None]
word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1+mlen)[:,:,None]

hids = []
attentions = []
Expand Down

0 comments on commit 0be6a2a

Please sign in to comment.