From 89ec6e7efff867d258947acafc57189b257212d0 Mon Sep 17 00:00:00 2001
From: Apoorv Vyas
Date: Tue, 16 Nov 2021 13:52:15 -0800
Subject: [PATCH] Add "grouped_shuffling" for batch shuffling in groups of
 total workers (#2391)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
- Allows for faster training on multiple GPUs when batches are built from length-sorted input sequences.
- Instead of shuffling batches randomly and then distributing them across workers, we group the batches into sets of size equal to the total number of workers and shuffle the groups. When the batches are sorted by length, this ensures that each worker receives inputs of similar length (see the sketch below).
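A minimal, self-contained sketch of the grouping scheme (illustrative only, not code from this patch; the batch contents and `num_shards` value are invented):

```python
import itertools

import numpy as np

# Six length-sorted batches for two workers (num_shards = 2); values invented.
batches = [[0], [1], [2], [3], [4], [5]]
num_shards = 2

# Group consecutive (similar-length) batches into groups of num_shards...
grouped = [
    batches[i * num_shards : (i + 1) * num_shards]
    for i in range(len(batches) // num_shards)
]
# ...and shuffle the groups rather than the individual batches.
np.random.shuffle(grouped)
batches = list(itertools.chain(*grouped))
# Adjacent batches stay paired, so at every step worker 0 and worker 1
# draw neighbouring, similar-length batches.
```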
# Before submitting

- [N] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [Y] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [N] Did you write any new necessary tests?

## What does this PR do?

Adds the option "grouped_shuffling" to the dataset dataclass so that batches are first grouped into sets of size equal to the total number of workers, and the groups are then shuffled. This reduces the sequence-length discrepancy among workers when batches are created from inputs sorted by sequence length.

## PR review

Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?

Make sure you had fun coding 🙂

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2391

Reviewed By: arbabu123

Differential Revision: D31352971

Pulled By: alexeib

fbshipit-source-id: c045bedecb03339c8eb46e7e8c9804a53b35615b
---
 examples/MMPT/mmpt/tasks/fairseqmmtask.py     |  5 +++-
 examples/laser/laser_src/laser_task.py        |  2 ++
 .../tasks/speech_text_joint.py                |  3 +++
 fairseq/data/iterators.py                     | 24 ++++++++++++++++---
 fairseq/dataclass/configs.py                  | 18 ++++++++++++++
 fairseq/options.py                            |  7 ++++++
 fairseq/tasks/fairseq_task.py                 | 15 ++++++++++--
 .../tasks/translation_multi_simple_epoch.py   |  9 +++++++
 fairseq/trainer.py                            |  6 ++++-
 9 files changed, 82 insertions(+), 7 deletions(-)

diff --git a/examples/MMPT/mmpt/tasks/fairseqmmtask.py b/examples/MMPT/mmpt/tasks/fairseqmmtask.py
index fa7dae7a6c..78ef7ba17c 100644
--- a/examples/MMPT/mmpt/tasks/fairseqmmtask.py
+++ b/examples/MMPT/mmpt/tasks/fairseqmmtask.py
@@ -68,6 +68,8 @@ def get_batch_iterator(
         epoch=1,
         data_buffer_size=0,
         disable_iterator_cache=False,
+        grouped_shuffling=False,
+        update_epoch_batch_itr=False,
     ):
         random.seed(epoch)
         if dataset.mmdataset.split == "train" \
@@ -81,7 +83,8 @@ def get_batch_iterator(
             dataset, max_tokens, max_sentences, max_positions,
             ignore_invalid_inputs, required_batch_size_multiple,
             seed, num_shards, shard_id, num_workers, epoch,
-            data_buffer_size, disable_iterator_cache)
+            data_buffer_size, disable_iterator_cache,
+            grouped_shuffling, update_epoch_batch_itr)

     @property
     def source_dictionary(self):
diff --git a/examples/laser/laser_src/laser_task.py b/examples/laser/laser_src/laser_task.py
index e4152fde68..43416e0a0d 100644
--- a/examples/laser/laser_src/laser_task.py
+++ b/examples/laser/laser_src/laser_task.py
@@ -282,6 +282,8 @@ def get_batch_iterator(
         epoch=1,
         data_buffer_size=0,
         disable_iterator_cache=False,
+        grouped_shuffling=False,
+        update_epoch_batch_itr=False,
     ):

         assert isinstance(dataset, OrderedDict)
diff --git a/examples/speech_text_joint_to_text/tasks/speech_text_joint.py b/examples/speech_text_joint_to_text/tasks/speech_text_joint.py
index f2b3966d2d..800ccd782a 100644
--- a/examples/speech_text_joint_to_text/tasks/speech_text_joint.py
+++ b/examples/speech_text_joint_to_text/tasks/speech_text_joint.py
@@ -326,6 +326,8 @@ def get_batch_iterator(
         epoch=0,
         data_buffer_size=0,
         disable_iterator_cache=False,
+        grouped_shuffling=False,
+        update_epoch_batch_itr=False,
     ):

         if not isinstance(dataset, MultiModalityDataset):
@@ -343,6 +345,7 @@ def get_batch_iterator(
                 epoch,
                 data_buffer_size,
                 disable_iterator_cache,
+                update_epoch_batch_itr=update_epoch_batch_itr,
             )

         mult_ratio = [self.args.speech_sample_ratio, self.args.text_sample_ratio]
diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py
index 1ce26e57e5..14b4f83330 100644
--- a/fairseq/data/iterators.py
+++ b/fairseq/data/iterators.py
@@ -55,8 +55,10 @@ def __next__(self):
         try:
             x = next(self._itr)
         except StopIteration:
-            raise IndexError(f"Iterator expected to have length {self.total}, "
-                             "but exhausted at position {self.n}.")
+            raise IndexError(
+                f"Iterator expected to have length {self.total}, "
+                f"but exhausted at position {self.n}."
+            )
         self.n += 1
         return x

@@ -263,6 +265,9 @@ class EpochBatchIterator(EpochBatchIterating):
            from workers. Should always be non-negative (default: ``0``).
        disable_shuffling (bool, optional): force disable shuffling
            (default: ``False``).
+       grouped_shuffling (bool, optional): enable shuffling batches in groups
+           of num_shards. Ensures that each GPU receives similar-length
+           sequences when batches are sorted by length.
    """

    def __init__(
@@ -278,6 +283,7 @@ def __init__(
        buffer_size=0,
        timeout=0,
        disable_shuffling=False,
+       grouped_shuffling=False,
    ):
        assert isinstance(dataset, torch.utils.data.Dataset)
        self.dataset = dataset
@@ -295,6 +301,7 @@ def __init__(
        self.buffer_size = min(buffer_size, 20)
        self.timeout = timeout
        self.disable_shuffling = disable_shuffling
+       self.grouped_shuffling = grouped_shuffling
        self.epoch = max(epoch, 1)  # we use 1-based indexing for epochs
        self.shuffle = not disable_shuffling

@@ -433,7 +440,17 @@ def _get_iterator_for_epoch(
    ):
        def shuffle_batches(batches, seed):
            with data_utils.numpy_seed(seed):
-                np.random.shuffle(batches)
+
+                if self.grouped_shuffling:
+                    grouped_batches = [
+                        batches[(i * self.num_shards) : ((i + 1) * self.num_shards)]
+                        for i in range((len(batches) // self.num_shards))
+                    ]
+                    np.random.shuffle(grouped_batches)
+                    batches = list(itertools.chain(*grouped_batches))
+                else:
+                    np.random.shuffle(batches)
+
                return batches

        if self._supports_prefetch:
@@ -639,6 +656,7 @@ def __next__(self):
            raise StopIteration()
        return item

+
class GroupedEpochBatchIterator(EpochBatchIterator):
    """Grouped version of EpochBatchIterator
    It takes several samplers from different datasets.
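For intuition about why the group size is `num_shards`: downstream of this code, fairseq hands batches to workers round-robin (its `ShardedIterator` slices every `num_shards`-th batch), so after grouped shuffling each shard draws one batch from every group of adjacent, similar-length batches. A rough sketch of that interaction, assuming the round-robin slicing mirrors `ShardedIterator`'s behaviour (batch values invented):

```python
import itertools

batches = list(range(8))  # stand-ins for length-sorted batches, post-shuffle
num_shards = 2

# Round-robin sharding across two workers:
shard0 = list(itertools.islice(batches, 0, None, num_shards))  # [0, 2, 4, 6]
shard1 = list(itertools.islice(batches, 1, None, num_shards))  # [1, 3, 5, 7]
# At each step the two shards hold adjacent batches (0/1, 2/3, ...), so their
# sequence lengths stay close and neither GPU idles waiting for the other.
```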
diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py
index 289bb8896b..fcf2678244 100644
--- a/fairseq/dataclass/configs.py
+++ b/fairseq/dataclass/configs.py
@@ -540,6 +540,24 @@ class DatasetConfig(FairseqDataclass):
     shard_id: int = field(
         default=0, metadata={"help": "id of the shard to generate (id < num_shards)"}
     )
+    grouped_shuffling: bool = field(
+        default=False,
+        metadata={
+            "help": "shuffle batches in groups of num_shards to enable similar sequence lengths on each GPU worker when batches are sorted by length",
+        },
+    )
+    update_epoch_batch_itr: bool = field(
+        default=II("dataset.grouped_shuffling"),
+        metadata={
+            "help": "if true then prevents reuse of the epoch batch iterator by setting can_reuse_epoch_itr to false; defaults to --grouped-shuffling",
+        },
+    )
+    update_ordered_indices_seed: bool = field(
+        default=False,
+        metadata={
+            "help": "if true then increment seed with epoch for getting batch iterators; defaults to False",
+        },
+    )


 @dataclass
diff --git a/fairseq/options.py b/fairseq/options.py
index b4d350f902..920591635a 100644
--- a/fairseq/options.py
+++ b/fairseq/options.py
@@ -208,6 +208,13 @@ def parse_args_and_arch(
     else:
         args.no_seed_provided = False

+    if getattr(args, "update_epoch_batch_itr", None) is None:
+        if hasattr(args, "grouped_shuffling"):
+            args.update_epoch_batch_itr = args.grouped_shuffling
+        else:
+            args.grouped_shuffling = False
+            args.update_epoch_batch_itr = False
+
     # Apply architecture configuration.
     if hasattr(args, "arch") and args.arch in ARCH_CONFIG_REGISTRY:
         ARCH_CONFIG_REGISTRY[args.arch](args)
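A rough sketch of how the new fields interact, assuming OmegaConf resolves the `II(...)` interpolation against the root config the way fairseq's Hydra-backed configs do (illustrative only, not code from this patch):

```python
from omegaconf import OmegaConf

from fairseq.dataclass.configs import DatasetConfig

# update_epoch_batch_itr defaults to II("dataset.grouped_shuffling"), so
# enabling grouped shuffling should also force a fresh epoch batch iterator.
cfg = OmegaConf.create({"dataset": DatasetConfig(grouped_shuffling=True)})
print(cfg.dataset.update_epoch_batch_itr)       # True (resolved interpolation)
print(cfg.dataset.update_ordered_indices_seed)  # False unless set explicitly
```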
diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py
index 8148c77fe1..ec62464f03 100644
--- a/fairseq/tasks/fairseq_task.py
+++ b/fairseq/tasks/fairseq_task.py
@@ -220,6 +220,8 @@ def get_batch_iterator(
         epoch=1,
         data_buffer_size=0,
         disable_iterator_cache=False,
+        grouped_shuffling=False,
+        update_epoch_batch_itr=False,
     ):
         """
         Get an iterator that yields batches of data from the given dataset.
@@ -252,12 +254,20 @@ def get_batch_iterator(
             disable_iterator_cache (bool, optional): don't cache the
                 EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`)
                 (default: False).
+            grouped_shuffling (bool, optional): group batches with each group
+                containing num_shards batches, and shuffle the groups. Reduces
+                the difference in sequence lengths among workers when batches
+                are sorted by length.
+            update_epoch_batch_itr (bool, optional): if true then do not use
+                the cached batch iterator for the epoch.

         Returns:
             ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
                 given dataset split
         """
-        can_reuse_epoch_itr = not disable_iterator_cache and self.can_reuse_epoch_itr(
-            dataset
+        can_reuse_epoch_itr = (
+            not disable_iterator_cache
+            and not update_epoch_batch_itr
+            and self.can_reuse_epoch_itr(dataset)
         )
         if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter:
             logger.debug("reusing EpochBatchIterator for epoch {}".format(epoch))
@@ -297,6 +307,7 @@ def get_batch_iterator(
             num_workers=num_workers,
             epoch=epoch,
             buffer_size=data_buffer_size,
+            grouped_shuffling=grouped_shuffling,
         )

         if can_reuse_epoch_itr:
diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py
index 6f36e5b93e..f4797f5676 100644
--- a/fairseq/tasks/translation_multi_simple_epoch.py
+++ b/fairseq/tasks/translation_multi_simple_epoch.py
@@ -349,6 +349,8 @@ def get_batch_iterator(
         epoch=1,
         data_buffer_size=0,
         disable_iterator_cache=False,
+        grouped_shuffling=False,
+        update_epoch_batch_itr=False,
     ):
         """
         Get an iterator that yields batches of data from the given dataset.
@@ -381,6 +383,12 @@ def get_batch_iterator(
             disable_iterator_cache (bool, optional): don't cache the
                 EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`)
                 (default: False).
+            grouped_shuffling (bool, optional): group batches with each group
+                containing num_shards batches, and shuffle the groups. Reduces
+                the difference in sequence lengths among workers when batches
+                are sorted by length.
+            update_epoch_batch_itr (bool, optional): if true then do not use
+                the cached batch iterator for the epoch.

         Returns:
             ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
                 given dataset split
@@ -404,6 +412,7 @@ def get_batch_iterator(
                 epoch=epoch,
                 data_buffer_size=data_buffer_size,
                 disable_iterator_cache=disable_iterator_cache,
+                update_epoch_batch_itr=update_epoch_batch_itr,
             )
             self.dataset_to_epoch_iter[dataset] = batch_iter
         return batch_iter
diff --git a/fairseq/trainer.py b/fairseq/trainer.py
index 94130c8c3a..6413411604 100644
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -639,13 +639,17 @@ def get_train_iterator(
             ),
             ignore_invalid_inputs=True,
             required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
-            seed=self.cfg.common.seed,
+            seed=(self.cfg.common.seed + epoch)
+            if self.cfg.dataset.update_ordered_indices_seed
+            else self.cfg.common.seed,
             num_shards=self.data_parallel_world_size if shard_batch_itr else 1,
             shard_id=self.data_parallel_rank if shard_batch_itr else 0,
             num_workers=self.cfg.dataset.num_workers,
             epoch=epoch,
             data_buffer_size=self.cfg.dataset.data_buffer_size,
             disable_iterator_cache=disable_iterator_cache,
+            grouped_shuffling=self.cfg.dataset.grouped_shuffling,
+            update_epoch_batch_itr=self.cfg.dataset.update_epoch_batch_itr,
         )
         self.reset_dummy_batch(batch_iterator.first_batch)
         return batch_iterator
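The trainer.py change above condenses to the following standalone helper, shown here for clarity (the function name is invented; the logic mirrors the diff):

```python
def ordered_indices_seed(cfg, epoch: int) -> int:
    # With update_ordered_indices_seed enabled, the seed used to build the
    # epoch's batch iterator varies per epoch, so grouped shuffling (and any
    # other seeded ordering) yields a different order each epoch.
    if cfg.dataset.update_ordered_indices_seed:
        return cfg.common.seed + epoch
    return cfg.common.seed
```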