Add "grouped_shuffling" for batch shuffling in groups of total workers (
Browse files Browse the repository at this point in the history
facebookresearch#2391)

Summary:
- Allows for faster training on multiple GPUs when batches are built from
length-sorted input sequences.

- Instead of shuffling batches randomly and then distributing them across
workers, we group the batches into sets whose size equals the total number
of workers and shuffle the groups. When the batches are sorted by length,
this ensures that each worker receives inputs of similar length (see the
sketch below).
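
A minimal, self-contained sketch of this grouping idea. The helper `grouped_shuffle` and the toy batches are illustrative only (not part of the fairseq API); `num_shards` stands for the total number of workers:

```python
import itertools
import numpy as np

def grouped_shuffle(batches, num_shards, seed):
    """Shuffle batches in groups of num_shards instead of individually."""
    rng = np.random.RandomState(seed)
    # Split the (length-sorted) batches into consecutive groups of num_shards.
    groups = [
        batches[i * num_shards : (i + 1) * num_shards]
        for i in range(len(batches) // num_shards)
    ]
    rng.shuffle(groups)  # shuffle the groups, not the individual batches
    return list(itertools.chain(*groups))  # flatten back to a list of batches

# Example: 8 batches sorted by length, sharded across 4 workers.
batches = [[0], [1], [2], [3], [4], [5], [6], [7]]
print(grouped_shuffle(batches, num_shards=4, seed=1))
# Groups of 4 consecutive (length-sorted) batches stay together, so after
# sharding each of the 4 workers draws batches from the same neighbourhood
# of the sorted order, i.e. batches of similar length.
```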

# Before submitting

- [N] Was this discussed/approved via a GitHub issue? (not required for typos or doc improvements)
- [Y] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [N] Did you write any new necessary tests?

## What does this PR do?
Adds option "grouped_shuffling" to the dataclass to allow batches to be first grouped in set of total workers followed by shuffling of the groups. This reduces the sequence length discrepancy among the workers when the batches were created from inputs sorted by sequence lengths.

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃

Pull Request resolved: fairinternal/fairseq-py#2391

Reviewed By: arbabu123

Differential Revision: D31352971

Pulled By: alexeib

fbshipit-source-id: c045bedecb03339c8eb46e7e8c9804a53b35615b
Apoorv Vyas authored and facebook-github-bot committed Nov 16, 2021
1 parent 4ccb288 commit 89ec6e7
Showing 9 changed files with 82 additions and 7 deletions.
5 changes: 4 additions & 1 deletion examples/MMPT/mmpt/tasks/fairseqmmtask.py
@@ -68,6 +68,8 @@ def get_batch_iterator(
epoch=1,
data_buffer_size=0,
disable_iterator_cache=False,
grouped_shuffling=False,
update_epoch_batch_itr=False,
):
random.seed(epoch)
if dataset.mmdataset.split == "train" \
@@ -81,7 +83,8 @@
dataset, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs, required_batch_size_multiple,
seed, num_shards, shard_id, num_workers, epoch,
data_buffer_size, disable_iterator_cache)
data_buffer_size, disable_iterator_cache,
grouped_shuffling, update_epoch_batch_itr)

@property
def source_dictionary(self):
2 changes: 2 additions & 0 deletions examples/laser/laser_src/laser_task.py
@@ -282,6 +282,8 @@ def get_batch_iterator(
epoch=1,
data_buffer_size=0,
disable_iterator_cache=False,
grouped_shuffling=False,
update_epoch_batch_itr=False,
):

assert isinstance(dataset, OrderedDict)
3 changes: 3 additions & 0 deletions examples/speech_text_joint_to_text/tasks/speech_text_joint.py
@@ -326,6 +326,8 @@ def get_batch_iterator(
epoch=0,
data_buffer_size=0,
disable_iterator_cache=False,
grouped_shuffling=False,
update_epoch_batch_itr=False,
):

if not isinstance(dataset, MultiModalityDataset):
@@ -343,6 +345,7 @@
epoch,
data_buffer_size,
disable_iterator_cache,
update_epoch_batch_itr=update_epoch_batch_itr,
)

mult_ratio = [self.args.speech_sample_ratio, self.args.text_sample_ratio]
24 changes: 21 additions & 3 deletions fairseq/data/iterators.py
@@ -55,8 +55,10 @@ def __next__(self):
try:
x = next(self._itr)
except StopIteration:
raise IndexError(f"Iterator expected to have length {self.total}, "
"but exhausted at position {self.n}.")
raise IndexError(
f"Iterator expected to have length {self.total}, "
"but exhausted at position {self.n}."
)
self.n += 1
return x

@@ -263,6 +265,9 @@ class EpochBatchIterator(EpochBatchIterating):
from workers. Should always be non-negative (default: ``0``).
disable_shuffling (bool, optional): force disable shuffling
(default: ``False``).
grouped_shuffling (bool, optional): enable shuffling batches in groups
of num_shards. Ensures that each GPU receives similar length sequences when
batches are sorted by length.
"""

def __init__(
@@ -278,6 +283,7 @@ def __init__(
buffer_size=0,
timeout=0,
disable_shuffling=False,
grouped_shuffling=False,
):
assert isinstance(dataset, torch.utils.data.Dataset)
self.dataset = dataset
@@ -295,6 +301,7 @@ def __init__(
self.buffer_size = min(buffer_size, 20)
self.timeout = timeout
self.disable_shuffling = disable_shuffling
self.grouped_shuffling = grouped_shuffling

self.epoch = max(epoch, 1) # we use 1-based indexing for epochs
self.shuffle = not disable_shuffling
@@ -433,7 +440,17 @@ def _get_iterator_for_epoch(
):
def shuffle_batches(batches, seed):
with data_utils.numpy_seed(seed):
np.random.shuffle(batches)

if self.grouped_shuffling:
grouped_batches = [
batches[(i * self.num_shards) : ((i + 1) * self.num_shards)]
for i in range((len(batches) // self.num_shards))
]
np.random.shuffle(grouped_batches)
batches = list(itertools.chain(*grouped_batches))
else:
np.random.shuffle(batches)

return batches

if self._supports_prefetch:
@@ -639,6 +656,7 @@ def __next__(self):
raise StopIteration()
return item


class GroupedEpochBatchIterator(EpochBatchIterator):
"""Grouped version of EpochBatchIterator
It takes several samplers from different datasets.
18 changes: 18 additions & 0 deletions fairseq/dataclass/configs.py
@@ -540,6 +540,24 @@ class DatasetConfig(FairseqDataclass):
shard_id: int = field(
default=0, metadata={"help": "id of the shard to generate (id < num_shards)"}
)
grouped_shuffling: bool = field(
default=False,
metadata={
"help": "shuffle batches in groups of num_shards to enable similar sequence lengths on each GPU worker when batches are sorted by length",
},
)
update_epoch_batch_itr: bool = field(
default=II("dataset.grouped_shuffling"),
metadata={
"help": "if true then prevents the reuse the epoch batch iterator by setting can_reuse_epoch_itr to false, defaults to --grouped-shuffling )",
},
)
update_ordered_indices_seed: bool = field(
default=False,
metadata={
"help": "if true then increment seed with epoch for getting batch iterators, defautls to False.",
}
)


@dataclass
7 changes: 7 additions & 0 deletions fairseq/options.py
@@ -208,6 +208,13 @@ def parse_args_and_arch(
else:
args.no_seed_provided = False

if getattr(args, "update_epoch_batch_itr", None) is None:
if hasattr(args, "grouped_shuffling"):
args.update_epoch_batch_itr = args.grouped_shuffling
else:
args.grouped_shuffling = False
args.update_epoch_batch_itr = False

# Apply architecture configuration.
if hasattr(args, "arch") and args.arch in ARCH_CONFIG_REGISTRY:
ARCH_CONFIG_REGISTRY[args.arch](args)
15 changes: 13 additions & 2 deletions fairseq/tasks/fairseq_task.py
@@ -220,6 +220,8 @@ def get_batch_iterator(
epoch=1,
data_buffer_size=0,
disable_iterator_cache=False,
grouped_shuffling=False,
update_epoch_batch_itr=False,
):
"""
Get an iterator that yields batches of data from the given dataset.
@@ -252,12 +254,20 @@ def get_batch_iterator(
disable_iterator_cache (bool, optional): don't cache the
EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`)
(default: False).
grouped_shuffling (bool, optional): group batches with each group
containing num_shards batches and shuffle the groups. Reduces the
difference in sequence lengths among workers for batches sorted by length.
update_epoch_batch_itr (bool, optional): if true then do not use the
cached batch iterator for the epoch.
Returns:
~fairseq.iterators.EpochBatchIterator: a batched iterator over the
given dataset split
"""
can_reuse_epoch_itr = not disable_iterator_cache and self.can_reuse_epoch_itr(
dataset
can_reuse_epoch_itr = (
not disable_iterator_cache
and not update_epoch_batch_itr
and self.can_reuse_epoch_itr(dataset)
)
if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter:
logger.debug("reusing EpochBatchIterator for epoch {}".format(epoch))
@@ -297,6 +307,7 @@ def get_batch_iterator(
num_workers=num_workers,
epoch=epoch,
buffer_size=data_buffer_size,
grouped_shuffling=grouped_shuffling,
)

if can_reuse_epoch_itr:
9 changes: 9 additions & 0 deletions fairseq/tasks/translation_multi_simple_epoch.py
@@ -349,6 +349,8 @@ def get_batch_iterator(
epoch=1,
data_buffer_size=0,
disable_iterator_cache=False,
grouped_shuffling=False,
update_epoch_batch_itr=False,
):
"""
Get an iterator that yields batches of data from the given dataset.
@@ -381,6 +383,12 @@ def get_batch_iterator(
disable_iterator_cache (bool, optional): don't cache the
EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`)
(default: False).
grouped_shuffling (bool, optional): group batches with each group
containing num_shards batches and shuffle the groups. Reduces the
difference in sequence lengths among workers for batches sorted by length.
update_epoch_batch_itr (bool, optional): if true then do not use the
cached batch iterator for the epoch.
Returns:
~fairseq.iterators.EpochBatchIterator: a batched iterator over the
given dataset split
@@ -404,6 +412,7 @@
epoch=epoch,
data_buffer_size=data_buffer_size,
disable_iterator_cache=disable_iterator_cache,
update_epoch_batch_itr=update_epoch_batch_itr,
)
self.dataset_to_epoch_iter[dataset] = batch_iter
return batch_iter
6 changes: 5 additions & 1 deletion fairseq/trainer.py
@@ -639,13 +639,17 @@ def get_train_iterator(
),
ignore_invalid_inputs=True,
required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
seed=self.cfg.common.seed,
seed=(self.cfg.common.seed + epoch)
if self.cfg.dataset.update_ordered_indices_seed
else self.cfg.common.seed,
num_shards=self.data_parallel_world_size if shard_batch_itr else 1,
shard_id=self.data_parallel_rank if shard_batch_itr else 0,
num_workers=self.cfg.dataset.num_workers,
epoch=epoch,
data_buffer_size=self.cfg.dataset.data_buffer_size,
disable_iterator_cache=disable_iterator_cache,
grouped_shuffling=self.cfg.dataset.grouped_shuffling,
update_epoch_batch_itr=self.cfg.dataset.update_epoch_batch_itr,
)
self.reset_dummy_batch(batch_iterator.first_batch)
return batch_iterator
