Skip to content

Commit

Permalink
fix multi modality dataset error with resampling datasets after first…
Browse files Browse the repository at this point in the history
  • Loading branch information
mgaido91 authored Jul 7, 2022
1 parent 570c942 commit cba35cd
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def resolve_data_type(cls, split, use_sup_speech_ctc):
"sup_speech_s2s",
"unsup_speech",
"sup_speech_ctc",
)
), f"failed resolving {split} (it resulted into: {dtype} ; is_train={is_train})"
return is_train, dtype

def create_modalitydatasetitem(self, dtype, dataset):
Expand Down
26 changes: 21 additions & 5 deletions fairseq/data/audio/multi_modality_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import List, Optional, NamedTuple

import numpy as np
from fairseq.data.resampling_dataset import ResamplingDataset
import torch
from fairseq.data import (
ConcatDataset,
Expand All @@ -30,6 +31,16 @@ class ModalityDatasetItem(NamedTuple):
max_sentences: Optional[int] = None


def resampling_dataset_present(ds):
if isinstance(ds, ResamplingDataset):
return True
if isinstance(ds, ConcatDataset):
return any(resampling_dataset_present(d) for d in ds.datasets)
if hasattr(ds, "dataset"):
return resampling_dataset_present(ds.dataset)
return False


# MultiModalityDataset: it concate multiple datasets with different modalities.
# Compared with ConcatDataset it can 1) sample data given the ratios for different datasets
# 2) it adds mode to indicate what type of the data samples come from.
Expand Down Expand Up @@ -88,7 +99,7 @@ def size(self, index: int):
def sizes(self):
if len(self.datasets) == 1:
return self.datasets[0].sizes
super().sizes
return super().sizes

def ordered_indices(self):
"""
Expand All @@ -106,12 +117,14 @@ def ordered_indices(self):
return indices_group

def get_raw_batch_samplers(self, required_batch_size_multiple, seed):
if len(self.raw_sub_batch_samplers) > 0:
logger.info(" raw_sub_batch_samplers exists. No action is taken")
return
with data_utils.numpy_seed(seed):
indices = self.ordered_indices()
for i, ds in enumerate(self.datasets):
# If we have ResamplingDataset, the same id can correpond to a different
# sample in the next epoch, so we need to rebuild this at every epoch
if i < len(self.raw_sub_batch_samplers) and not resampling_dataset_present(ds):
logger.info(f"dataset {i} is valid and it is not re-sampled")
continue
indices[i] = ds.filter_indices_by_size(
indices[i],
self.max_positions[i],
Expand All @@ -122,7 +135,10 @@ def get_raw_batch_samplers(self, required_batch_size_multiple, seed):
max_sentences=self.max_sentences[i],
required_batch_size_multiple=required_batch_size_multiple,
)
self.raw_sub_batch_samplers.append(sub_batch_sampler)
if i < len(self.raw_sub_batch_samplers):
self.raw_sub_batch_samplers[i] = sub_batch_sampler
else:
self.raw_sub_batch_samplers.append(sub_batch_sampler)

def get_batch_samplers(self, mult_ratios, required_batch_size_multiple, seed):
self.get_raw_batch_samplers(required_batch_size_multiple, seed)
Expand Down

0 comments on commit cba35cd

Please sign in to comment.