# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from argparse import Namespace
from collections import namedtuple
from pathlib import Path

import torch
from tqdm import tqdm

import fairseq
from fairseq import utils
from fairseq.checkpoint_utils import load_model_ensemble_and_task
from fairseq.scoring.bleu import SacrebleuScorer
from tests.speech import TestFairseqSpeech


@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestDualInputS2TTransformer(TestFairseqSpeech):
    def setUp(self):
        self.set_up_mustc_de_fbank()

    def import_user_module(self):
        # Register the speech_text_joint_to_text example directory as a fairseq
        # user module so its custom task/model classes can be resolved at load time.
        user_dir = (
            Path(fairseq.__file__).parent.parent / "examples/speech_text_joint_to_text"
        )
        Arg = namedtuple("Arg", ["user_dir"])
        arg = Arg(user_dir.__str__())
        utils.import_user_module(arg)
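
    # End-to-end decoding check: run the downloaded dual-input S2T checkpoint on the
    # MuST-C En-De tst-COMMON split and verify corpus BLEU against a reference score.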
    @torch.no_grad()
    def test_mustc_de_fbank_dualinput_s2t_transformer_checkpoint(self):
        self.import_user_module()
        checkpoint_filename = "checkpoint_ave_10.pt"
        path = self.download(self.base_url, self.root, checkpoint_filename)
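        # Load the ensemble from the checkpoint; the overrides point "data" and
        # "config_yaml" at the downloaded test fixtures, clear the load_pretrain_*
        # paths saved in the checkpoint's training args, and set beam-search options.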
        models, cfg, task = load_model_ensemble_and_task(
            [path.as_posix()],
            arg_overrides={
                "data": self.root.as_posix(),
                "config_yaml": "config.yaml",
                "load_pretrain_speech_encoder": "",
                "load_pretrain_text_encoder_last": "",
                "load_pretrain_decoder": "",
                "beam": 10,
                "nbest": 1,
                "lenpen": 1.0,
                "load_speech_only": True,
            },
        )
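        # Move the models to GPU when available, build the sequence generator, and
        # create a single-epoch iterator over the tst-COMMON test split.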
        if self.use_cuda:
            for model in models:
                model.cuda()
        generator = task.build_generator(models, cfg)
        test_split = "tst-COMMON"
        task.load_dataset(test_split)
        batch_iterator = task.get_batch_iterator(
            dataset=task.dataset(test_split),
            max_tokens=250_000,
            max_positions=(10_000, 1_024),
            num_workers=1,
        ).next_epoch_itr(shuffle=False)
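
        # Build the tokenizer/BPE helpers from the task config; decode_fn illustrates
        # the detokenization order for a hypothesis string (BPE first, then tokenizer).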
        tokenizer = task.build_tokenizer(cfg.tokenizer)
        bpe = task.build_bpe(cfg.bpe)

        def decode_fn(x):
            if bpe is not None:
                x = bpe.decode(x)
            if tokenizer is not None:
                x = tokenizer.decode(x)
            return x
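
        # Score with sacrebleu: standard "13a" tokenization, case-sensitive, word level.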
        scorer_args = {
            "sacrebleu_tokenizer": "13a",
            "sacrebleu_lowercase": False,
            "sacrebleu_char_level": False,
        }
        scorer = SacrebleuScorer(Namespace(**scorer_args))
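        # Decode batch by batch; strip padding and sentencepiece markers from the
        # reference and hypothesis, print a few pairs from the first batch, and feed
        # each (reference, hypothesis) pair to the scorer.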
        progress = tqdm(enumerate(batch_iterator), total=len(batch_iterator))
        for batch_idx, sample in progress:
            sample = utils.move_to_cuda(sample) if self.use_cuda else sample
            hypo = task.inference_step(generator, models, sample)
            for i, sample_id in enumerate(sample["id"].tolist()):
                tgt_tokens = (
                    utils.strip_pad(sample["target"][i, :], task.tgt_dict.pad())
                    .int()
                    .cpu()
                )
                tgt_str = task.tgt_dict.string(tgt_tokens, "sentencepiece")
                hypo_str = task.tgt_dict.string(
                    hypo[i][0]["tokens"].int().cpu(), "sentencepiece"
                )
                if batch_idx == 0 and i < 3:
                    print(f"T-{sample_id} {tgt_str}")
                    print(f"D-{sample_id} {hypo_str}")
                scorer.add_string(tgt_str, hypo_str)
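
        # Compare corpus BLEU against the expected score for this checkpoint;
        # the run must land within ±0.3 BLEU.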
        reference_bleu = 27.3
        result = scorer.result_string()
        print(result + f" (reference: {reference_bleu})")
        res_bleu = float(result.split()[2])
        self.assertAlmostEqual(res_bleu, reference_bleu, delta=0.3)


if __name__ == "__main__":
    unittest.main()