
Commit

Summary:
Adds the ability to tag individual examples with the names of their datasets, along with assorted minor fixes and improvements
Pull Request resolved: fairinternal/fairseq-py#838

Differential Revision: D16919175

Pulled By: alexeib

fbshipit-source-id: 4bf493299645bae63f3ee6382e15f18a9f73666c
alexeib authored and facebook-github-bot committed Aug 21, 2019
1 parent 7a31fe0 commit a2f5361
Showing 11 changed files with 482 additions and 89 deletions.
12 changes: 10 additions & 2 deletions fairseq/data/__init__.py
@@ -27,11 +27,15 @@
from .num_samples_dataset import NumSamplesDataset
from .offset_tokens_dataset import OffsetTokensDataset
from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset
from .prepend_dataset import PrependDataset
from .prepend_token_dataset import PrependTokenDataset
from .raw_label_dataset import RawLabelDataset
from .replace_dataset import ReplaceDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
from .sharded_dataset import ShardedDataset
from .sort_dataset import SortDataset
from .strip_token_dataset import StripTokenDataset
from .subsample_dataset import SubsampleDataset
from .token_block_dataset import TokenBlockDataset
from .transform_eos_dataset import TransformEosDataset
from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
@@ -72,14 +76,18 @@
'NumSamplesDataset',
"OffsetTokensDataset",
'PadDataset',
'PrependDataset',
'PrependTokenDataset',
'RawAudioDataset',
"RawLabelDataset",
'RawLabelDataset',
'ReplaceDataset',
'RightPadDataset',
'RoundRobinZipDatasets',
'ShardedDataset',
'ShardedIterator',
'SortDataset',
"StripTokenDataset",
'StripTokenDataset',
'SubsampleDataset',
'TokenBlockDataset',
'TransformEosDataset',
'TransformEosLangPairDataset',
4 changes: 4 additions & 0 deletions fairseq/data/concat_dataset.py
@@ -64,6 +64,10 @@ def size(self, idx: int):
def num_tokens(self, index: int):
return np.max(self.size(index))

def attr(self, attr: str, index: int):
dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
return getattr(self.datasets[dataset_idx], attr, None)

@property
def sizes(self):
return np.concatenate(
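
As a quick illustration of the routing logic (the lengths below are made up): `bisect.bisect_right` over the cumulative sizes returns the index of the constituent dataset that owns a global example index, and the attribute lookup is then delegated to that dataset.

import bisect

# Two constituent datasets of lengths 3 and 5 give cumulative_sizes == [3, 8]:
# global indices 0-2 belong to dataset 0, indices 3-7 to dataset 1.
cumulative_sizes = [3, 8]
for index in (0, 2, 3, 7):
    dataset_idx = bisect.bisect_right(cumulative_sizes, index)
    print(index, "->", dataset_idx)  # 0 -> 0, 2 -> 0, 3 -> 1, 7 -> 1
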
3 changes: 3 additions & 0 deletions fairseq/data/fairseq_dataset.py
@@ -47,6 +47,9 @@ def supports_prefetch(self):
"""Whether this dataset supports prefetching."""
return False

def attr(self, attr: str, index: int):
return getattr(self, attr, None)

def prefetch(self, indices):
"""Prefetch the data required for this epoch."""
raise NotImplementedError
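
A minimal sketch of the default behavior (the class and name below are made up): the base implementation ignores the example index and falls back to a dataset-level attribute, returning None when the attribute is absent. Combined with ConcatDataset.attr above, this is what lets each example report the name of the dataset it came from.

class ToyDataset:
    name = "toy_corpus"  # hypothetical dataset-level tag

    def attr(self, attr: str, index: int):
        return getattr(self, attr, None)

ds = ToyDataset()
assert ds.attr("name", 0) == "toy_corpus"  # same answer for every index
assert ds.attr("missing", 0) is None
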
28 changes: 28 additions & 0 deletions fairseq/data/prepend_dataset.py
@@ -0,0 +1,28 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import BaseWrapperDataset


class PrependDataset(BaseWrapperDataset):
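    """Prepends a token to each example by overwriting position 0 in place.

    ``prepend_getter(dataset, index)`` must return the int token id to write;
    when ``ensure_first_token_is`` is set, the existing first token is asserted
    to equal it before being replaced.
    """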
def __init__(self, dataset, prepend_getter, ensure_first_token_is=None):
super().__init__(dataset)
self.prepend_getter = prepend_getter
self.ensure_first_token = ensure_first_token_is

def __getitem__(self, idx):
item = self.dataset[idx]
is_tuple = isinstance(item, tuple)
src = item[0] if is_tuple else item

assert self.ensure_first_token is None or src[0] == self.ensure_first_token
prepend_idx = self.prepend_getter(self.dataset, idx)
assert isinstance(prepend_idx, int)
src[0] = prepend_idx
item = tuple((src,) + item[1:]) if is_tuple else src
return item
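
A hedged usage sketch (the token ids and getter are made up; a plain list stands in for a real FairseqDataset, which works because the wrapper only indexes into it). Note that the "prepend" overwrites position 0, so each item must already start with a placeholder token:

import torch
from fairseq.data import PrependDataset

base = [torch.tensor([0, 11, 12]), torch.tensor([0, 21, 22])]

def tag_of(dataset, index):
    return 5 if index == 0 else 6  # e.g. per-example dataset/language tags

tagged = PrependDataset(base, prepend_getter=tag_of, ensure_first_token_is=0)
print(tagged[0])  # tensor([ 5, 11, 12]) -- the placeholder 0 was overwritten
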
26 changes: 26 additions & 0 deletions fairseq/data/replace_dataset.py
@@ -0,0 +1,26 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import BaseWrapperDataset


class ReplaceDataset(BaseWrapperDataset):
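    """Replaces token ids in each example according to ``replace_map``, touching
    only positions from ``offset`` onward; earlier positions are left intact.
    """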
def __init__(self, dataset, replace_map, offset=0):
super().__init__(dataset)
assert len(replace_map) > 0
self.replace_map = replace_map
self.offset = offset

def __getitem__(self, index):
item = self.dataset[index]
is_tuple = isinstance(item, tuple)
src = item[0] if is_tuple else item

        # slice once; masked_fill_ writes through the view into src
        src_off = src[self.offset:]
        for k, v in self.replace_map.items():
            src_off.masked_fill_(src_off == k, v)

item = tuple((src,) + item[1:]) if is_tuple else src
return item
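
A hedged usage sketch with made-up token ids (again using a plain list as the backing dataset); offset=1 shields the first position, e.g. a tag written by PrependDataset:

import torch
from fairseq.data import ReplaceDataset

base = [torch.tensor([3, 1, 3, 2, 3])]
ds = ReplaceDataset(base, replace_map={3: 9}, offset=1)
print(ds[0])  # tensor([3, 1, 9, 2, 9]) -- position 0 keeps its original id
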
60 changes: 60 additions & 0 deletions fairseq/data/sharded_dataset.py
@@ -0,0 +1,60 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import os
import random

from . import BaseWrapperDataset
from fairseq.data import data_utils


class ShardedDataset(BaseWrapperDataset):
"""A :class:`~fairseq.data.FairseqDataset` wrapper that appends/prepends/strips EOS.
Loads a dataset which has been sharded into multiple files. each shard is only loaded for each specific epoch
"""

def __init__(
self,
dictionary,
dataset_impl: str,
path: str,
split: str,
epoch: int,
name: str = None,
combine: bool = False,
seed: int = 0,
):
self._name = name if name is not None else os.path.basename(path)
num_shards = 0
for i in itertools.count():
if not os.path.exists(os.path.join(path, "shard" + str(i))):
break
num_shards += 1

if num_shards > 0 and split == "train":
random.seed(seed ^ epoch)
shard = random.randint(0, num_shards - 1)
split_path = os.path.join(path, "shard" + str(shard), split)
else:
split_path = os.path.join(path, split)
if os.path.isdir(split_path):
split_path = os.path.join(split_path, split)

dataset = data_utils.load_indexed_dataset(
split_path, dictionary, dataset_impl, combine=combine
)
if dataset is None:
raise FileNotFoundError(
"Dataset not found: {} ({})".format(split, split_path)
)

super().__init__(dataset)

@property
def name(self):
return self._name
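
A hedged construction sketch; the directory layout, dictionary path, and arguments below are assumptions, not part of this commit. For split="train" the shard is drawn after random.seed(seed ^ epoch), so the pick is deterministic per epoch but can differ across epochs:

# Assumed on-disk layout (shards are discovered by probing shard0, shard1, ...):
#   /data/corpus/shard0/train.bin  /data/corpus/shard0/train.idx
#   /data/corpus/shard1/train.bin  /data/corpus/shard1/train.idx
from fairseq.data import Dictionary, ShardedDataset

dictionary = Dictionary.load("/data/corpus/dict.txt")  # hypothetical path
train = ShardedDataset(
    dictionary,
    dataset_impl="mmap",  # any impl accepted by load_indexed_dataset
    path="/data/corpus",
    split="train",
    epoch=1,
    name="corpus",
)
print(train.name)  # "corpus"
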
57 changes: 57 additions & 0 deletions fairseq/data/subsample_dataset.py
@@ -0,0 +1,57 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

from . import BaseWrapperDataset


class SubsampleDataset(BaseWrapperDataset):
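    """Subsamples a dataset to ``size_ratio`` of its original length, drawing
    the surviving example indices uniformly at random without replacement.
    """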
    def __init__(self, dataset, size_ratio, shuffle=False):
        super().__init__(dataset)
        assert size_ratio < 1
        self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int)
        self.indices = np.random.choice(
            range(len(self.dataset)), self.actual_size, replace=False
        )
        # ordered_indices() below reads this flag; without it the attribute
        # would be undefined there
        self.shuffle = shuffle
        print(
            f"subsampled dataset from {len(self.dataset)} to {self.actual_size} (ratio={size_ratio})"
        )

def __getitem__(self, index):
return self.dataset[self.indices[index]]

def __len__(self):
return self.actual_size

def collater(self, samples):
return self.dataset.collater(samples)

@property
def sizes(self):
return self.dataset.sizes[self.indices]

@property
def name(self):
return self.dataset.name

def num_tokens(self, index):
return self.dataset.num_tokens(self.indices[index])

def size(self, index):
return self.dataset.size(self.indices[index])

def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
if self.shuffle:
order = [np.random.permutation(len(self))]
else:
order = [np.arange(len(self))]
order.append(self.sizes)
return np.lexsort(order)

def prefetch(self, indices):
self.dataset.prefetch(self.indices[indices])
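
A hedged usage sketch (the backing list is a stand-in for a real dataset); the survivors are drawn with np.random.choice, so seed NumPy first if reproducibility matters:

import numpy as np
from fairseq.data import SubsampleDataset

np.random.seed(0)
base = list(range(10))
sub = SubsampleDataset(base, size_ratio=0.5)
print(len(sub))                           # 5
print([sub[i] for i in range(len(sub))])  # 5 distinct examples from base
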
(The remaining 4 changed files are not shown.)