diff --git a/examples/MMPT/mmpt/tasks/fairseqmmtask.py b/examples/MMPT/mmpt/tasks/fairseqmmtask.py
index fa7dae7a6c..78ef7ba17c 100644
--- a/examples/MMPT/mmpt/tasks/fairseqmmtask.py
+++ b/examples/MMPT/mmpt/tasks/fairseqmmtask.py
@@ -68,6 +68,8 @@ def get_batch_iterator(
         epoch=1,
         data_buffer_size=0,
         disable_iterator_cache=False,
+        grouped_shuffling=False,
+        update_epoch_batch_itr=False,
     ):
         random.seed(epoch)
         if dataset.mmdataset.split == "train" \
@@ -81,7 +83,8 @@ def get_batch_iterator(
             dataset, max_tokens, max_sentences, max_positions,
             ignore_invalid_inputs, required_batch_size_multiple,
             seed, num_shards, shard_id, num_workers, epoch,
-            data_buffer_size, disable_iterator_cache)
+            data_buffer_size, disable_iterator_cache,
+            grouped_shuffling, update_epoch_batch_itr)
 
     @property
     def source_dictionary(self):
diff --git a/examples/laser/laser_src/laser_task.py b/examples/laser/laser_src/laser_task.py
index e4152fde68..43416e0a0d 100644
--- a/examples/laser/laser_src/laser_task.py
+++ b/examples/laser/laser_src/laser_task.py
@@ -282,6 +282,8 @@ def get_batch_iterator(
         epoch=1,
         data_buffer_size=0,
         disable_iterator_cache=False,
+        grouped_shuffling=False,
+        update_epoch_batch_itr=False,
     ):
 
         assert isinstance(dataset, OrderedDict)
diff --git a/examples/speech_text_joint_to_text/tasks/speech_text_joint.py b/examples/speech_text_joint_to_text/tasks/speech_text_joint.py
index f2b3966d2d..800ccd782a 100644
--- a/examples/speech_text_joint_to_text/tasks/speech_text_joint.py
+++ b/examples/speech_text_joint_to_text/tasks/speech_text_joint.py
@@ -326,6 +326,8 @@ def get_batch_iterator(
         epoch=0,
         data_buffer_size=0,
         disable_iterator_cache=False,
+        grouped_shuffling=False,
+        update_epoch_batch_itr=False,
     ):
 
         if not isinstance(dataset, MultiModalityDataset):
@@ -343,6 +345,7 @@ def get_batch_iterator(
                 epoch,
                 data_buffer_size,
                 disable_iterator_cache,
+                update_epoch_batch_itr=update_epoch_batch_itr,
             )
 
         mult_ratio = [self.args.speech_sample_ratio, self.args.text_sample_ratio]
diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py
index 1ce26e57e5..14b4f83330 100644
--- a/fairseq/data/iterators.py
+++ b/fairseq/data/iterators.py
@@ -55,8 +55,10 @@ def __next__(self):
         try:
             x = next(self._itr)
         except StopIteration:
-            raise IndexError(f"Iterator expected to have length {self.total}, "
-                             "but exhausted at position {self.n}.")
+            raise IndexError(
+                f"Iterator expected to have length {self.total}, "
+                f"but exhausted at position {self.n}."
+            )
         self.n += 1
         return x
 
@@ -263,6 +265,9 @@ class EpochBatchIterator(EpochBatchIterating):
             from workers. Should always be non-negative (default: ``0``).
         disable_shuffling (bool, optional): force disable shuffling
             (default: ``False``).
+        grouped_shuffling (bool, optional): enable shuffling batches in groups
+            of num_shards. Ensures that each GPU receives similar length sequences when
+            batches are sorted by length.
     """
 
     def __init__(
@@ -278,6 +283,7 @@ def __init__(
         buffer_size=0,
         timeout=0,
         disable_shuffling=False,
+        grouped_shuffling=False,
     ):
         assert isinstance(dataset, torch.utils.data.Dataset)
         self.dataset = dataset
@@ -295,6 +301,7 @@ def __init__(
         self.buffer_size = min(buffer_size, 20)
         self.timeout = timeout
         self.disable_shuffling = disable_shuffling
+        self.grouped_shuffling = grouped_shuffling
 
         self.epoch = max(epoch, 1)  # we use 1-based indexing for epochs
         self.shuffle = not disable_shuffling
@@ -433,7 +440,17 @@ def _get_iterator_for_epoch(
     ):
         def shuffle_batches(batches, seed):
             with data_utils.numpy_seed(seed):
-                np.random.shuffle(batches)
+
+                if self.grouped_shuffling:
+                    grouped_batches = [
+                        batches[(i * self.num_shards) : ((i + 1) * self.num_shards)]
+                        for i in range((len(batches) // self.num_shards))
+                    ]
+                    np.random.shuffle(grouped_batches)
+                    batches = list(itertools.chain(*grouped_batches))
+                else:
+                    np.random.shuffle(batches)
+
             return batches
 
         if self._supports_prefetch:
@@ -639,6 +656,7 @@ def __next__(self):
             raise StopIteration()
         return item
 
+
 class GroupedEpochBatchIterator(EpochBatchIterator):
     """Grouped version of EpochBatchIterator
     It takes several samplers from different datasets.
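For context on the new `shuffle_batches` branch above: grouped shuffling keeps every run of `num_shards` length-sorted batches together and only shuffles the order of those runs, so after sharding each GPU draws a batch of roughly the same length at every step. Below is a minimal standalone sketch of that logic (not fairseq code); `grouped_shuffle` and the toy batch list are illustrative, and, as in the hunk above, any trailing partial group is dropped.

```python
import itertools

import numpy as np


def grouped_shuffle(batches, num_shards, seed):
    """Sketch of the grouped shuffling added above: keep each run of num_shards
    length-sorted batches together and shuffle only the order of the runs."""
    rng = np.random.RandomState(seed)
    groups = [
        batches[i * num_shards : (i + 1) * num_shards]
        # trailing partial group is dropped, as in the hunk above
        for i in range(len(batches) // num_shards)
    ]
    rng.shuffle(groups)  # shuffle group order only, not the batches inside a group
    return list(itertools.chain(*groups))


# Toy example: 8 batches already sorted by length, to be sharded across 2 GPUs.
batches = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]]
shuffled = grouped_shuffle(batches, num_shards=2, seed=1)
# Pairs such as [4, 5] and [6, 7] stay adjacent, so shard 0 and shard 1 receive
# batches of similar length in the same update once the list is sharded.
print(shuffled)
```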
""" def __init__( @@ -278,6 +283,7 @@ def __init__( buffer_size=0, timeout=0, disable_shuffling=False, + grouped_shuffling=False, ): assert isinstance(dataset, torch.utils.data.Dataset) self.dataset = dataset @@ -295,6 +301,7 @@ def __init__( self.buffer_size = min(buffer_size, 20) self.timeout = timeout self.disable_shuffling = disable_shuffling + self.grouped_shuffling = grouped_shuffling self.epoch = max(epoch, 1) # we use 1-based indexing for epochs self.shuffle = not disable_shuffling @@ -433,7 +440,17 @@ def _get_iterator_for_epoch( ): def shuffle_batches(batches, seed): with data_utils.numpy_seed(seed): - np.random.shuffle(batches) + + if self.grouped_shuffling: + grouped_batches = [ + batches[(i * self.num_shards) : ((i + 1) * self.num_shards)] + for i in range((len(batches) // self.num_shards)) + ] + np.random.shuffle(grouped_batches) + batches = list(itertools.chain(*grouped_batches)) + else: + np.random.shuffle(batches) + return batches if self._supports_prefetch: @@ -639,6 +656,7 @@ def __next__(self): raise StopIteration() return item + class GroupedEpochBatchIterator(EpochBatchIterator): """Grouped version of EpochBatchIterator It takes several samplers from different datasets. diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 289bb8896b..fcf2678244 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -540,6 +540,24 @@ class DatasetConfig(FairseqDataclass): shard_id: int = field( default=0, metadata={"help": "id of the shard to generate (id < num_shards)"} ) + grouped_shuffling: bool = field( + default=False, + metadata={ + "help": "shuffle batches in groups of num_shards to enable similar sequence lengths on each GPU worker when batches are sorted by length", + }, + ) + update_epoch_batch_itr: bool = field( + default=II("dataset.grouped_shuffling"), + metadata={ + "help": "if true then prevents the reuse the epoch batch iterator by setting can_reuse_epoch_itr to false, defaults to --grouped-shuffling )", + }, + ) + update_ordered_indices_seed: bool = field( + default=False, + metadata={ + "help": "if true then increment seed with epoch for getting batch iterators, defautls to False.", + } + ) @dataclass diff --git a/fairseq/options.py b/fairseq/options.py index b4d350f902..920591635a 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -208,6 +208,13 @@ def parse_args_and_arch( else: args.no_seed_provided = False + if getattr(args, "update_epoch_batch_itr", None) is None: + if hasattr(args, "grouped_shuffling"): + args.update_epoch_batch_itr = args.grouped_shuffling + else: + args.grouped_shuffling = False + args.update_epoch_batch_itr = False + # Apply architecture configuration. if hasattr(args, "arch") and args.arch in ARCH_CONFIG_REGISTRY: ARCH_CONFIG_REGISTRY[args.arch](args) diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 8148c77fe1..ec62464f03 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -220,6 +220,8 @@ def get_batch_iterator( epoch=1, data_buffer_size=0, disable_iterator_cache=False, + grouped_shuffling=False, + update_epoch_batch_itr=False, ): """ Get an iterator that yields batches of data from the given dataset. @@ -252,12 +254,20 @@ def get_batch_iterator( disable_iterator_cache (bool, optional): don't cache the EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`) (default: False). + grouped_shuffling (bool, optional): group batches with each groups + containing num_shards batches and shuffle groups. 
diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py
index 8148c77fe1..ec62464f03 100644
--- a/fairseq/tasks/fairseq_task.py
+++ b/fairseq/tasks/fairseq_task.py
@@ -220,6 +220,8 @@ def get_batch_iterator(
         epoch=1,
         data_buffer_size=0,
         disable_iterator_cache=False,
+        grouped_shuffling=False,
+        update_epoch_batch_itr=False,
     ):
         """
         Get an iterator that yields batches of data from the given dataset.
@@ -252,12 +254,20 @@ def get_batch_iterator(
             disable_iterator_cache (bool, optional): don't cache the
                 EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`)
                 (default: False).
+            grouped_shuffling (bool, optional): group batches with each group
+                containing num_shards batches and shuffle groups. Reduces difference
+                between sequence lengths among workers for batches sorted by length.
+            update_epoch_batch_itr (bool, optional): if true then do not use the
+                cached batch iterator for the epoch
+
         Returns:
             ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
                 given dataset split
         """
-        can_reuse_epoch_itr = not disable_iterator_cache and self.can_reuse_epoch_itr(
-            dataset
+        can_reuse_epoch_itr = (
+            not disable_iterator_cache
+            and not update_epoch_batch_itr
+            and self.can_reuse_epoch_itr(dataset)
         )
         if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter:
             logger.debug("reusing EpochBatchIterator for epoch {}".format(epoch))
@@ -297,6 +307,7 @@ def get_batch_iterator(
             num_workers=num_workers,
             epoch=epoch,
             buffer_size=data_buffer_size,
+            grouped_shuffling=grouped_shuffling,
         )
 
         if can_reuse_epoch_itr:
diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py
index 6f36e5b93e..f4797f5676 100644
--- a/fairseq/tasks/translation_multi_simple_epoch.py
+++ b/fairseq/tasks/translation_multi_simple_epoch.py
@@ -349,6 +349,8 @@ def get_batch_iterator(
         epoch=1,
         data_buffer_size=0,
         disable_iterator_cache=False,
+        grouped_shuffling=False,
+        update_epoch_batch_itr=False,
     ):
         """
         Get an iterator that yields batches of data from the given dataset.
@@ -381,6 +383,12 @@ def get_batch_iterator(
             disable_iterator_cache (bool, optional): don't cache the
                 EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`)
                 (default: False).
+            grouped_shuffling (bool, optional): group batches with each group
+                containing num_shards batches and shuffle groups. Reduces difference
+                between sequence lengths among workers for batches sorted by length.
+            update_epoch_batch_itr (bool, optional): if true then do not use the
+                cached batch iterator for the epoch
+
         Returns:
             ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
                 given dataset split
@@ -404,6 +412,7 @@ def get_batch_iterator(
             epoch=epoch,
             data_buffer_size=data_buffer_size,
             disable_iterator_cache=disable_iterator_cache,
+            update_epoch_batch_itr=update_epoch_batch_itr,
         )
         self.dataset_to_epoch_iter[dataset] = batch_iter
         return batch_iter
diff --git a/fairseq/trainer.py b/fairseq/trainer.py
index 94130c8c3a..6413411604 100644
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -639,13 +639,17 @@ def get_train_iterator(
             ),
             ignore_invalid_inputs=True,
             required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
-            seed=self.cfg.common.seed,
+            seed=(self.cfg.common.seed + epoch)
+            if self.cfg.dataset.update_ordered_indices_seed
+            else self.cfg.common.seed,
            num_shards=self.data_parallel_world_size if shard_batch_itr else 1,
             shard_id=self.data_parallel_rank if shard_batch_itr else 0,
             num_workers=self.cfg.dataset.num_workers,
             epoch=epoch,
             data_buffer_size=self.cfg.dataset.data_buffer_size,
             disable_iterator_cache=disable_iterator_cache,
+            grouped_shuffling=self.cfg.dataset.grouped_shuffling,
+            update_epoch_batch_itr=self.cfg.dataset.update_epoch_batch_itr,
         )
         self.reset_dummy_batch(batch_iterator.first_batch)
         return batch_iterator
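End to end: `--grouped-shuffling` (named in the help text above) turns the feature on, `update_epoch_batch_itr` then defaults to true so the epoch batch iterator is rebuilt rather than reused, and with `update_ordered_indices_seed` the trainer passes `seed + epoch` so the group order changes every epoch. Under fairseq's usual underscore-to-dash conversion the new fields would surface as `--update-epoch-batch-itr` and `--update-ordered-indices-seed`, though that mapping is assumed here rather than shown in the diff. The sketch below exercises the new `grouped_shuffling` flag on `EpochBatchIterator` directly; it assumes a CPU-only run and the existing constructor signature, and `ToyDataset` plus the toy `batch_sampler` are illustrative.

```python
import torch

from fairseq.data.iterators import EpochBatchIterator


class ToyDataset(torch.utils.data.Dataset):
    """Illustrative dataset: item i is just the integer i."""

    def __getitem__(self, index):
        return index

    def __len__(self):
        return 16


# 8 batches of 2 indices each; pretend they are already sorted by length.
batch_sampler = [[i, i + 1] for i in range(0, 16, 2)]

itr = EpochBatchIterator(
    dataset=ToyDataset(),
    collate_fn=lambda items: items,  # keep batches as plain lists of ints
    batch_sampler=batch_sampler,
    seed=1,
    num_shards=2,  # pretend two GPUs
    shard_id=0,    # this process plays the role of shard 0
    grouped_shuffling=True,
)

# Shard 0 receives the first batch of each shuffled group; shard 1 would receive
# the second, so both shards see batches of similar length in the same update.
for batch in itr.next_epoch_itr(shuffle=True):
    print(batch)
```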