Handle 3+ dimensional input in sequence_generator + nits
Summary: sequence_generator assumes that the model input is a 2-D tensor of longs, but it can also be something like a 3-D tensor of floats, and we should be able to handle this as long as the first dimension is the batch size, followed by the source lengths.

Reviewed By: myleott

Differential Revision: D14420044

fbshipit-source-id: bf8b1e42ad1873f7b803c1a377b0af21648db015
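
To make the summary concrete, here is a minimal sketch (illustrative only, not fairseq code) of reading the batch size and source length from the first two dimensions, which works the same way whether the input is a 2-D long tensor of token ids or a 3-D float tensor of features:

import torch

# Illustrative sketch, not fairseq code: take bsz / src_len from the first two
# dimensions so both token-id and dense-feature inputs are accepted.
def batch_shape(src_tokens):
    input_size = src_tokens.size()
    return input_size[0], input_size[1]    # (bsz, src_len), for 2-D or 3-D input

tokens = torch.randint(0, 100, (4, 11))    # 2-D long tensor of token ids
features = torch.randn(4, 11, 80)          # hypothetical 3-D float feature tensor
assert batch_shape(tokens) == batch_shape(features) == (4, 11)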
Dmytro Okhonko authored and facebook-github-bot committed Mar 12, 2019
1 parent d17fa85 commit 860010e
Showing 6 changed files with 14 additions and 6 deletions.
4 changes: 4 additions & 0 deletions fairseq/options.py
@@ -223,6 +223,8 @@ def add_dataset_args(parser, train=False, gen=False):
help='maximum number of tokens in a batch')
group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
help='maximum number of sentences in a batch')
group.add_argument('--required-batch-size-multiple', default=8, type=int, metavar='N',
help='batch size will be a multiple of this value')
if train:
group.add_argument('--train-subset', default='train', metavar='SPLIT',
choices=['train', 'valid', 'test'],
@@ -359,6 +361,8 @@ def add_common_eval_args(group):
group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
help='a dictionary used to override model args at generation '
'that were used during model training')
group.add_argument('--results-path', metavar='RESDIR', type=str, default=None,
help='path to save eval results (optional)')
# fmt: on


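For context on the new --required-batch-size-multiple option, the tiny helper below is a hypothetical illustration (not fairseq's batching code) of what it means for the number of sentences in a batch to be a multiple of the configured value:

# Hypothetical helper, not fairseq's batching code: illustrate a batch size
# that is a multiple of --required-batch-size-multiple.
def round_down_to_multiple(num_sentences, multiple):
    """Largest batch size <= num_sentences that is a multiple of `multiple`."""
    return (num_sentences // multiple) * multiple

print(round_down_to_multiple(13, 8))  # 8,  with the default multiple of 8
print(round_down_to_multiple(13, 1))  # 13, with --required-batch-size-multiple 1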
7 changes: 5 additions & 2 deletions fairseq/sequence_generator.py
@@ -127,7 +127,10 @@ def generate(

src_tokens = encoder_input['src_tokens']
src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
bsz, src_len = src_tokens.size()
input_size = src_tokens.size()
# batch dimension goes first followed by source lengths
bsz = input_size[0]
src_len = input_size[1]
beam_size = self.beam_size

if self.match_source_len:
@@ -148,7 +151,7 @@
# initialize buffers
scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0)
scores_buf = scores.clone()
tokens = src_tokens.new(bsz * beam_size, max_len + 2).fill_(self.pad)
tokens = src_tokens.data.new(bsz * beam_size, max_len + 2).long().fill_(self.pad)
tokens_buf = tokens.clone()
tokens[:, 0] = bos_token or self.eos
attn, attn_buf = None, None
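The explicit .long() cast in the tokens buffer above matters once src_tokens may be a float tensor: tensor.new(...) inherits the dtype of the source tensor, so without the cast the buffer of token ids would be float. A minimal sketch of that behaviour (shapes are assumptions, not fairseq code):

import torch

# Assumed shapes, not fairseq code: with float features, new(...) yields a
# float buffer, so the explicit .long() keeps the token ids integral.
src_tokens = torch.randn(2, 7, 80)   # hypothetical 3-D float input
pad, beam_size, max_len = 1, 5, 10
tokens = src_tokens.data.new(2 * beam_size, max_len + 2).long().fill_(pad)
assert tokens.dtype == torch.long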
2 changes: 1 addition & 1 deletion fairseq/trainer.py
@@ -202,7 +202,7 @@ def train_step(self, samples, dummy_batch=False):
sample_sizes.append(sample_size)
except RuntimeError as e:
if 'out of memory' in str(e):
print('| WARNING: ran out of memory, skipping batch')
print(('| WARNING: ran out of memory with exception: {};\n Skipping batch').format(str(e)))
ooms += 1
self.zero_grad()
else:
2 changes: 1 addition & 1 deletion generate.py
@@ -74,7 +74,7 @@ def main(args):
*[model.max_positions() for model in models]
),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=8,
required_batch_size_multiple=args.required_batch_size_multiple,
num_shards=args.num_shards,
shard_id=args.shard_id,
num_workers=args.num_workers,
1 change: 1 addition & 0 deletions tests/test_binaries.py
@@ -234,6 +234,7 @@ def test_optimizers(self):
if os.path.exists(last_checkpoint):
os.remove(last_checkpoint)
train_translation_model(data_dir, 'lstm', [
'--required-batch-size-multiple', '1',
'--encoder-layers', '1',
'--encoder-hidden-size', '32',
'--decoder-layers', '1',
4 changes: 2 additions & 2 deletions train.py
@@ -82,7 +82,7 @@ def main(args, init_distributed=False):
max_sentences=args.max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=True,
required_batch_size_multiple=8,
required_batch_size_multiple=args.required_batch_size_multiple,
seed=args.seed,
num_shards=args.distributed_world_size,
shard_id=args.distributed_rank,
@@ -220,7 +220,7 @@ def validate(args, trainer, task, epoch_itr, subsets):
trainer.get_model().max_positions(),
),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=8,
required_batch_size_multiple=args.required_batch_size_multiple,
seed=args.seed,
num_shards=args.distributed_world_size,
shard_id=args.distributed_rank,
