fix eos token for batched case
lucidrains committed Aug 8, 2021
1 parent 4e94794 commit e3cb1af
Showing 2 changed files with 13 additions and 3 deletions.
14 changes: 12 additions & 2 deletions h_transformer_1d/autoregressive_wrapper.py
@@ -4,6 +4,9 @@
 
 # helper function
 
+def exists(val):
+    return val is not None
+
 def eval_decorator(fn):
     def inner(model, *args, **kwargs):
         was_training = model.training
@@ -56,8 +59,15 @@ def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., fi
 
             out = torch.cat((out, sample), dim=-1)
 
-            if eos_token is not None and (sample == eos_token).all():
-                break
+            if exists(eos_token):
+                is_eos_tokens = (out == eos_token)
+
+                if is_eos_tokens.any(dim = -1).all():
+                    # mask out everything after the eos tokens
+                    shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
+                    mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
+                    out = out.masked_fill(mask, self.pad_value)
+                    break
 
         out = out[:, t:]
 
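Why this fix matters: the previous check, (sample == eos_token).all(), only broke out of the sampling loop when every sequence in the batch emitted EOS at the very same step, which rarely happens for batch sizes above one. The new check stops once every row contains an EOS somewhere, then masks everything after each row's EOS with the pad value. Below is a minimal standalone sketch of that masking step, not part of the commit; the tensor values, eos_token, and pad_value are made up for illustration.

import torch
import torch.nn.functional as F

eos_token = 2  # hypothetical EOS id
pad_value = 0  # stand-in for self.pad_value

# two sequences that hit EOS at different steps -- the old
# (sample == eos_token).all() check would never fire here
out = torch.tensor([
    [5, 7, 2, 9, 4],   # EOS at position 2
    [6, 3, 8, 2, 1],   # EOS at position 3
])

is_eos_tokens = (out == eos_token)
assert is_eos_tokens.any(dim = -1).all()  # every row has an EOS

# shift the EOS flags right by one so the EOS token itself survives,
# then mask every position at or after the shifted flag
shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
out = out.masked_fill(mask, pad_value)

print(out)
# tensor([[5, 7, 2, 0, 0],
#         [6, 3, 8, 2, 0]])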
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'h-transformer-1d',
   packages = find_packages(),
-  version = '0.1.0',
+  version = '0.1.1',
   license='MIT',
   description = 'H-Transformer 1D - Pytorch',
   author = 'Phil Wang',