Skip to content

Commit

Permalink
Add --upsample-primary
Browse files Browse the repository at this point in the history
  • Loading branch information
myleott committed Sep 3, 2018
1 parent 5852d3a commit 6296de8
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion fairseq/tasks/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def add_args(parser):
help='max number of tokens in the source sequence')
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence')
parser.add_argument('--upsample-primary', default=1, type=int,
help='amount to upsample primary dataset')

def __init__(self, args, src_dict, tgt_dict):
super().__init__(args)
Expand Down Expand Up @@ -120,12 +122,14 @@ def indexed_dataset(path, dictionary):
src_sizes = src_dataset.sizes
tgt_sizes = tgt_dataset.sizes
else:
if self.args.upsample_primary > 1:
src_datasets.extend([src_datasets[0]] * (self.args.upsample_primary - 1))
tgt_datasets.extend([tgt_datasets[0]] * (self.args.upsample_primary - 1))
src_dataset = ConcatDataset(src_datasets)
tgt_dataset = ConcatDataset(tgt_datasets)
src_sizes = np.concatenate([ds.sizes for ds in src_datasets])
tgt_sizes = np.concatenate([ds.sizes for ds in tgt_datasets])


self.datasets[split] = LanguagePairDataset(
src_dataset, src_sizes, self.src_dict,
tgt_dataset, tgt_sizes, self.tgt_dict,
Expand Down

0 comments on commit 6296de8

Please sign in to comment.