forked from facebookresearch/fairseq
Commit
Co-authored-by: liezl200 <[email protected]>
Showing 4 changed files with 276 additions and 72 deletions.
@@ -0,0 +1,131 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from fairseq import sequence_generator

from . import FairseqDataset, language_pair_dataset

class BacktranslationDataset(FairseqDataset):
    def __init__(self, args, tgt_dataset, tgt_dict, backtranslation_model):
        """
        Sets up a backtranslation dataset which takes a tgt batch, generates
        a src using a tgt-src backtranslation_model, and returns the
        corresponding {generated src, input tgt} batch.

        Args:
            args: generation args for the backtranslation SequenceGenerator.
                Note that there is no equivalent argparse code for these args
                anywhere in our top-level train scripts yet. Integration is
                still in progress. You can still, however, test out this
                dataset functionality with the appropriate args, as in the
                corresponding unittest: test_backtranslation_dataset.
            tgt_dataset: dataset which will be used to build self.tgt_dataset --
                a LanguagePairDataset with the tgt dataset as the source
                dataset and None as the target dataset. We use
                language_pair_dataset here to encapsulate the tgt_dataset so
                we can re-use the LanguagePairDataset collater to format the
                batches in the structure that SequenceGenerator expects.
            tgt_dict: tgt dictionary (typically a joint src/tgt BPE dictionary)
            backtranslation_model: tgt-src model to use in the
                SequenceGenerator to generate backtranslations from tgt batches
        """
        self.tgt_dataset = language_pair_dataset.LanguagePairDataset(
            src=tgt_dataset,
            src_sizes=None,
            src_dict=tgt_dict,
            tgt=None,
            tgt_sizes=None,
            tgt_dict=None,
        )
        self.backtranslation_generator = sequence_generator.SequenceGenerator(
            [backtranslation_model],
            tgt_dict,
            unk_penalty=args.backtranslation_unkpen,
            sampling=args.backtranslation_sampling,
            beam_size=args.backtranslation_beam,
        )
        self.backtranslation_max_len_a = args.backtranslation_max_len_a
        self.backtranslation_max_len_b = args.backtranslation_max_len_b
        self.backtranslation_beam = args.backtranslation_beam

    def __getitem__(self, index):
        """
        Returns a single sample. Multiple samples are fed to the collater to
        create a backtranslation batch. Note that you should always use
        BacktranslationDataset.collater() below as the collate_fn if given the
        option to specify which collate_fn to use (e.g. in a DataLoader which
        uses this BacktranslationDataset -- see the corresponding unittest for
        an example).
        """
        return self.tgt_dataset[index]

    def __len__(self):
        """
        The length of the backtranslation dataset is the length of tgt.
        """
        return len(self.tgt_dataset)

    def collater(self, samples):
        """
        Using the samples from the tgt dataset, load a collated tgt sample to
        feed to the backtranslation model. Then take the generated translation
        with the best score as the source and the original net input as the
        target.
        """
        collated_tgt_only_sample = self.tgt_dataset.collater(samples)
        backtranslation_hypos = self._generate_hypotheses(collated_tgt_only_sample)

        # Go through each tgt sentence in the batch and its corresponding best
        # generated hypothesis and create a backtranslation data pair:
        # {id: id, source: generated backtranslation, target: original tgt}
        generated_samples = []
        for input_sample, hypos in zip(samples, backtranslation_hypos):
            generated_samples.append(
                {
                    "id": input_sample["id"],
                    "source": hypos[0]["tokens"],  # first hypo is best hypo
                    "target": input_sample["source"],
                }
            )

        return language_pair_dataset.collate(
            samples=generated_samples,
            pad_idx=self.tgt_dataset.src_dict.pad(),
            eos_idx=self.tgt_dataset.src_dict.eos(),
        )

    def get_dummy_batch(self, num_tokens, max_positions):
        """Just use the tgt dataset get_dummy_batch"""
        return self.tgt_dataset.get_dummy_batch(num_tokens, max_positions)

    def num_tokens(self, index):
        """Just use the tgt dataset num_tokens"""
        return self.tgt_dataset.num_tokens(index)

    def ordered_indices(self):
        """Just use the tgt dataset ordered_indices"""
        return self.tgt_dataset.ordered_indices()

    def valid_size(self, index, max_positions):
        """Just use the tgt dataset size"""
        return self.tgt_dataset.valid_size(index, max_positions)

    def _generate_hypotheses(self, sample):
        """
        Generates hypotheses from a LanguagePairDataset collated / batched
        sample. Note that in this case, sample["target"] is None, and
        sample["net_input"]["src_tokens"] is really in the tgt language.
        """
        self.backtranslation_generator.cuda()
        input = sample["net_input"]
        srclen = input["src_tokens"].size(1)
        hypos = self.backtranslation_generator.generate(
            input,
            maxlen=int(
                self.backtranslation_max_len_a * srclen + self.backtranslation_max_len_b
            ),
        )
        return hypos
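
As a usage note (not part of the commit): below is a minimal sketch of how this dataset is meant to be driven, distilled from the unittest in the next file. The toy dictionary, model, and token data come from the tests.utils helpers used by that test; in real training they would come from your own task setup, and the args Namespace carries the backtranslation_* fields the constructor reads.

import argparse

import torch

import tests.utils as test_utils
from fairseq.data.backtranslation_dataset import BacktranslationDataset

# Toy dictionary, vocabulary tokens, source data, and tgt-src model from the
# test helpers; swap in your own objects outside of the test environment.
tgt_dict, w1, w2, src_tokens, src_lengths, model = (
    test_utils.sequence_generator_setup()
)

args = argparse.Namespace(
    backtranslation_unkpen=0,
    backtranslation_sampling=False,
    backtranslation_max_len_a=0,
    backtranslation_max_len_b=200,
    backtranslation_beam=2,
)

dataset = BacktranslationDataset(
    args=args,
    tgt_dataset=test_utils.TestDataset(data=src_tokens),
    tgt_dict=tgt_dict,
    backtranslation_model=model,
)

# Always pass the dataset's own collater as collate_fn so each batch comes
# back as {generated backtranslation src, original tgt} pairs.
loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, collate_fn=dataset.collater
)
batch = next(iter(loader))
backtranslated_src = batch["net_input"]["src_tokens"]
original_tgt = batch["target"]

Note that _generate_hypotheses calls .cuda() on the generator, so this path assumes a GPU is available.
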
@@ -0,0 +1,70 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import argparse
import unittest

import tests.utils as test_utils
import torch
from fairseq.data.backtranslation_dataset import BacktranslationDataset

class TestBacktranslationDataset(unittest.TestCase):
    def setUp(self):
        self.tgt_dict, self.w1, self.w2, self.src_tokens, self.src_lengths, self.model = (
            test_utils.sequence_generator_setup()
        )
        backtranslation_args = argparse.Namespace()

        # Same as the defaults from fairseq/options.py
        backtranslation_args.backtranslation_unkpen = 0
        backtranslation_args.backtranslation_sampling = False
        backtranslation_args.backtranslation_max_len_a = 0
        backtranslation_args.backtranslation_max_len_b = 200
        backtranslation_args.backtranslation_beam = 2

        self.backtranslation_args = backtranslation_args

        dummy_src_samples = self.src_tokens

        self.tgt_dataset = test_utils.TestDataset(data=dummy_src_samples)

||
    def test_backtranslation_dataset(self):
        backtranslation_dataset = BacktranslationDataset(
            args=self.backtranslation_args,
            tgt_dataset=self.tgt_dataset,
            tgt_dict=self.tgt_dict,
            backtranslation_model=self.model,
        )
        dataloader = torch.utils.data.DataLoader(
            backtranslation_dataset,
            batch_size=2,
            collate_fn=backtranslation_dataset.collater,
        )
        backtranslation_batch_result = next(iter(dataloader))

        eos, pad, w1, w2 = self.tgt_dict.eos(), self.tgt_dict.pad(), self.w1, self.w2

        # Note that we sort by src_lengths and add left padding, so the ids
        # will actually look like: [1, 0]
        expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]])
        expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
        generated_src = backtranslation_batch_result["net_input"]["src_tokens"]
        tgt_tokens = backtranslation_batch_result["target"]

        self.assertTensorEqual(expected_src, generated_src)
        self.assertTensorEqual(expected_tgt, tgt_tokens)

||
    def assertTensorEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertEqual(t1.ne(t2).long().sum(), 0)


if __name__ == "__main__":
    unittest.main()
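
For intuition on the expected_src tensor in the test above: the collated sources are sorted by length (longest first) and left-padded to the batch maximum, which is why the shorter hypothesis comes back as [pad, pad, w1, eos] and the example ids end up ordered [1, 0]. Below is a small self-contained sketch of that left-padding step, not fairseq's actual collate code; the numeric ids for w1, w2, eos, and pad are assumptions, not the values returned by sequence_generator_setup().

import torch

def left_pad(seqs, pad_idx):
    """Left-pad variable-length 1D LongTensors to a common length."""
    max_len = max(len(s) for s in seqs)
    out = torch.full((len(seqs), max_len), pad_idx, dtype=torch.long)
    for i, s in enumerate(seqs):
        out[i, max_len - len(s):] = s
    return out

# Hypothetical token ids: w1=4, w2=5, eos=2, pad=1.
w1, w2, eos, pad = 4, 5, 2, 1
hypos = [torch.LongTensor([w1, w2, w1, eos]), torch.LongTensor([w1, eos])]
print(left_pad(hypos, pad))
# tensor([[4, 5, 4, 2],
#         [1, 1, 4, 2]])
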