forked from facebookresearch/fairseq
Summary: Adds ability to tag individual examples with the names of their datasets, along with some minor miscellaneous fixes and improvements.

Pull Request resolved: fairinternal/fairseq-py#838
Differential Revision: D16919175
Pulled By: alexeib
fbshipit-source-id: 4bf493299645bae63f3ee6382e15f18a9f73666c
commit a2f5361 (1 parent: 7a31fe0)
Showing 11 changed files with 482 additions and 89 deletions.
@@ -0,0 +1,28 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import BaseWrapperDataset


class PrependDataset(BaseWrapperDataset):
    """Overwrites the first token of each example with a value computed by
    `prepend_getter(dataset, idx)`, e.g. a per-dataset tag symbol. If
    `ensure_first_token_is` is given, asserts that the token being replaced
    currently has that value."""

    def __init__(self, dataset, prepend_getter, ensure_first_token_is=None):
        super().__init__(dataset)
        self.prepend_getter = prepend_getter
        self.ensure_first_token = ensure_first_token_is

    def __getitem__(self, idx):
        item = self.dataset[idx]
        is_tuple = isinstance(item, tuple)
        src = item[0] if is_tuple else item

        assert self.ensure_first_token is None or src[0] == self.ensure_first_token
        prepend_idx = self.prepend_getter(self.dataset, idx)
        assert isinstance(prepend_idx, int)
        src[0] = prepend_idx
        item = tuple((src,) + item[1:]) if is_tuple else src
        return item
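A minimal usage sketch (not part of the commit), assuming `base_dataset` yields 1-D token tensors that start with a BOS symbol, and that `tag_idx` and `bos_idx` are integer indices taken from a fairseq `Dictionary`:

# A real getter could instead look up which underlying dataset `idx` came from;
# here a constant tag is used for illustration.
def dataset_tag_getter(dataset, idx):
    return tag_idx  # must be a plain int (the class asserts this)

tagged = PrependDataset(
    base_dataset,
    prepend_getter=dataset_tag_getter,
    ensure_first_token_is=bos_idx,  # sanity-check what is being overwritten
)
sample = tagged[0]  # same tensor as base_dataset[0], but sample[0] == tag_idx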
@@ -0,0 +1,26 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import BaseWrapperDataset


class ReplaceDataset(BaseWrapperDataset):
    """Replaces token ids in each example according to `replace_map`,
    leaving the first `offset` positions untouched."""

    def __init__(self, dataset, replace_map, offset=0):
        super().__init__(dataset)
        assert len(replace_map) > 0
        self.replace_map = replace_map
        self.offset = offset

    def __getitem__(self, index):
        item = self.dataset[index]
        is_tuple = isinstance(item, tuple)
        src = item[0] if is_tuple else item

        for k, v in self.replace_map.items():
            # slicing returns a view, so masked_fill_ modifies `src` in place
            src_off = src[self.offset:]
            src_off.masked_fill_(src_off == k, v)

        item = tuple((src,) + item[1:]) if is_tuple else src
        return item
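An illustrative sketch of how this wrapper might be used (the names are assumptions, not part of the commit): remap one special-token id to another everywhere except the first position:

# `old_idx` and `new_idx` are assumed to be indices from the model's vocabulary.
remapped = ReplaceDataset(base_dataset, replace_map={old_idx: new_idx}, offset=1)
sample = remapped[0]  # base_dataset[0] with every old_idx after position 0 replaced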
@@ -0,0 +1,60 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import os
import random

from . import BaseWrapperDataset
from fairseq.data import data_utils


class ShardedDataset(BaseWrapperDataset):
    """A :class:`~fairseq.data.FairseqDataset` wrapper that loads a dataset
    sharded into multiple files.

    For the train split, a single shard is chosen per epoch (deterministically
    from `seed` and `epoch`); other splits are loaded unsharded.
    """

    def __init__(
        self,
        dictionary,
        dataset_impl: str,
        path: str,
        split: str,
        epoch: int,
        name: str = None,
        combine: bool = False,
        seed: int = 0,
    ):
        self._name = name if name is not None else os.path.basename(path)

        # count directories named shard0, shard1, ... under `path`
        num_shards = 0
        for i in itertools.count():
            if not os.path.exists(os.path.join(path, "shard" + str(i))):
                break
            num_shards += 1

        if num_shards > 0 and split == "train":
            random.seed(seed ^ epoch)
            shard = random.randint(0, num_shards - 1)
            split_path = os.path.join(path, "shard" + str(shard), split)
        else:
            split_path = os.path.join(path, split)
            if os.path.isdir(split_path):
                split_path = os.path.join(split_path, split)

        dataset = data_utils.load_indexed_dataset(
            split_path, dictionary, dataset_impl, combine=combine
        )
        if dataset is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        super().__init__(dataset)

    @property
    def name(self):
        return self._name
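An illustrative usage sketch, with assumed names (`data_dir`, `dictionary`) and an assumed on-disk layout of `data_dir/shard0/train`, `data_dir/shard1/train`, ... plus an unsharded `valid` split:

# `dictionary` is a fairseq Dictionary; "mmap" assumes the data was binarized
# with that indexed-dataset implementation.
train = ShardedDataset(dictionary, "mmap", data_dir, "train", epoch=3, name="books")
valid = ShardedDataset(dictionary, "mmap", data_dir, "valid", epoch=3)
print(train.name)  # "books"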
@@ -0,0 +1,57 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

from . import BaseWrapperDataset


class SubsampleDataset(BaseWrapperDataset):
    """Subsamples a dataset to `size_ratio` of its original size by drawing
    a random subset of indices without replacement."""

    def __init__(self, dataset, size_ratio, shuffle=True):
        super().__init__(dataset)
        assert size_ratio < 1
        self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int)
        self.indices = np.random.choice(
            range(len(self.dataset)), self.actual_size, replace=False
        )
        # consulted by ordered_indices() below
        self.shuffle = shuffle
        print(
            f"subsampled dataset from {len(self.dataset)} to {self.actual_size} (ratio={size_ratio})"
        )

    def __getitem__(self, index):
        return self.dataset[self.indices[index]]

    def __len__(self):
        return self.actual_size

    def collater(self, samples):
        return self.dataset.collater(samples)

    @property
    def sizes(self):
        return self.dataset.sizes[self.indices]

    @property
    def name(self):
        return self.dataset.name

    def num_tokens(self, index):
        return self.dataset.num_tokens(self.indices[index])

    def size(self, index):
        return self.dataset.size(self.indices[index])

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        order.append(self.sizes)
        return np.lexsort(order)

    def prefetch(self, indices):
        self.dataset.prefetch(self.indices[indices])
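A short illustrative sketch (assumed names, not part of the commit): keep a random ~10% of a wrapped dataset while still supporting length-based batching:

subset = SubsampleDataset(base_dataset, size_ratio=0.1)
len(subset)    # ceil(0.1 * len(base_dataset))
subset.sizes   # sizes of just the sampled examples, used for bucketing
batch = subset.collater([subset[i] for i in range(4)])  # delegates to base_dataset.collater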