-
Notifications
You must be signed in to change notification settings - Fork 0
/
translate_adapter.py
159 lines (131 loc) · 6.8 KB
/
translate_adapter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Copyright (c) 2019-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Translate sentences read from the file given by --sentences_path.
# The model will be faster if sentences are sorted by length.
# Input sentences must have the same tokenization and BPE codes as the ones used to train the model.
#
# Usage:
#     python translate_adapter.py --exp_name translate \
#         --src_lang en --tgt_lang fr \
#         --sentences_path source_sentences.bpe \
#         --model_path trained_model.pth --output_path output
#
import os
import io
import argparse
import torch
from src.utils import AttrDict
from src.utils import bool_flag, initialize_exp
from src.data.dictionary import Dictionary
from src.model.transformer import TransformerModel
def get_parser():
    """
    Build the command-line parser of the translation script.

    Returns an `argparse.ArgumentParser` exposing the experiment options,
    the model / output / sentences paths, the language pair and the
    beam-search settings.
    """
    # (flag, keyword arguments forwarded to `add_argument`), in registration order
    options = [
        # main parameters
        ("--dump_path", dict(type=str, default="./dumped/", help="Experiment dump path")),
        ("--exp_name", dict(type=str, default="", help="Experiment name")),
        ("--exp_id", dict(type=str, default="", help="Experiment ID")),
        ("--batch_size", dict(type=int, default=32, help="Number of sentences per batch")),
        # model / output paths
        ("--model_path", dict(type=str, default="", help="Model path")),
        ("--output_path", dict(type=str, default="", help="Output path")),
        ("--sentences_path", dict(type=str, default="", help="Sentences to translate path")),
        # source language / target language
        ("--src_lang", dict(type=str, default="", help="Source language")),
        ("--tgt_lang", dict(type=str, default="", help="Target language")),
        # beam search
        ("--early_stopping", dict(
            type=bool_flag, default=False,
            help="Early stopping, stop as soon as we have `beam_size` hypotheses, although longer ones may have better scores.")),
        ("--length_penalty", dict(
            type=float, default=1,
            help="Length penalty, values < 1.0 favor shorter sentences, while values > 1.0 favor longer ones.")),
        ("--beam_size", dict(
            type=int, default=1,
            help="Beam size, default = 1 (greedy decoding)")),
    ]

    arg_parser = argparse.ArgumentParser(description="Translate sentences")
    for flag, kwargs in options:
        arg_parser.add_argument(flag, **kwargs)
    return arg_parser
def main(params):
    """
    Translate the sentences stored in `params.sentences_path`.

    The checkpoint at `params.model_path` is reloaded and used to rebuild the
    dictionary, the encoder and the decoder (both moved to CUDA and put in
    eval mode). Sentences are then translated from `params.src_lang` to
    `params.tgt_lang` in batches of `params.batch_size` with
    `decoder.generate_beam`, passing `extra_adapters_flag=True` to both the
    encoder and the decoder.

    Files written:
      - `params.output_path + 'src_sent'`: a copy of the source sentences,
        one per line;
      - `params.output_path`: the translations, one per line, in input order.

    `params` is the namespace produced by `get_parser().parse_args()`; it is
    updated in place with the dictionary indices of the reloaded model and
    with `src_id` / `tgt_id`.
    """
    # initialize the experiment
    logger = initialize_exp(params)

    # NOTE: `params` is used as given. Re-parsing the command line here would
    # silently discard the namespace passed by the caller.
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    logger.info("Supported languages: %s" % ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
    encoder = TransformerModel(model_params, dico, is_encoder=True, with_output=True).cuda().eval()
    decoder = TransformerModel(model_params, dico, is_encoder=False, with_output=True).cuda().eval()
    encoder.load_state_dict(reloaded['encoder'])
    decoder.load_state_dict(reloaded['decoder'])
    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]
    logger.info("encoder: {}".format(encoder))
    logger.info("decoder: {}".format(decoder))

    # read sentences from the sentences file (UTF-8, like the files written below)
    with io.open(params.sentences_path, 'r', encoding='utf-8') as src_file:
        src_sent = list(src_file)
    logger.info("Read %i sentences from sentences file.Writing them to a src file. Translating ..." % len(src_sent))

    # keep a copy of the source sentences next to the translations; lines read
    # from the file already end with a newline, so strip it before adding one
    # (otherwise every sentence is followed by an empty line)
    with io.open(params.output_path + 'src_sent', 'w', encoding='utf-8') as f:
        for sentence in src_sent:
            f.write(sentence.rstrip("\n") + "\n")
    logger.info("Wrote them to a src file")

    # the context manager closes the output file even if translation fails
    with io.open(params.output_path, 'w', encoding='utf-8') as f:

        for i in range(0, len(src_sent), params.batch_size):

            # prepare batch: one column per sentence, <eos> w1 ... wn <eos>, padded with pad_index
            word_ids = [torch.LongTensor([dico.index(w) for w in s.strip().split()])
                        for s in src_sent[i:i + params.batch_size]]
            sent_lens = [len(s) + 2 for s in word_ids]  # +2 for the two delimiters
            lengths = torch.LongTensor(sent_lens)
            max_src_len = max(sent_lens)
            batch = torch.LongTensor(max_src_len, len(sent_lens)).fill_(params.pad_index)
            batch[0] = params.eos_index
            for j, (s, n) in enumerate(zip(word_ids, sent_lens)):
                if n > 2:  # if sentence not empty
                    batch[1:n - 1, j].copy_(s)
                batch[n - 1, j] = params.eos_index
            langs = batch.clone().fill_(params.src_id)

            # encode source batch and translate it
            lengths_gpu = lengths.cuda()
            encoded, _ = encoder('fwd', x=batch.cuda(), lengths=lengths_gpu, langs=langs.cuda(), causal=False,
                                 encoder_only=False, extra_adapters_flag=True)
            encoded = encoded.transpose(0, 1)
            decoded, dec_lengths = decoder.generate_beam(
                encoded, lengths_gpu, params.tgt_id, beam_size=params.beam_size,
                length_penalty=params.length_penalty, early_stopping=params.early_stopping,
                max_len=int(1.5 * max_src_len + 10), extra_adapters_flag=True)

            # convert sentences to words
            for j in range(decoded.size(1)):

                # remove delimiters: drop the leading eos_index and everything
                # from the second eos_index (if any) onwards
                sent = decoded[:, j]
                delimiters = (sent == params.eos_index).nonzero().view(-1)
                assert len(delimiters) >= 1 and delimiters[0].item() == 0
                sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]

                # output translation
                source = src_sent[i + j].strip()
                target = " ".join(dico[idx] for idx in sent.tolist())
                if (i + j) % 10000 == 0:  # only log an occasional sample
                    logger.info("Translation of %i / %i:\n Source sentence: %s \n Translation: %s\n" % (i + j, len(src_sent), source, target))
                f.write(target + "\n")
if __name__ == '__main__':

    # generate parser / parse parameters
    parser = get_parser()
    params = parser.parse_args()

    # check parameters
    # Explicit checks instead of `assert`: assertions are stripped under
    # `python -O` and give no hint about which option is wrong.
    # `parser.error` prints the usage plus the message and exits with status 2.
    if not os.path.isfile(params.model_path):
        parser.error("--model_path must point to an existing file (got %r)" % params.model_path)
    # fail before the model is loaded rather than when the file is opened
    if not os.path.exists(params.sentences_path):
        parser.error("--sentences_path must point to an existing file (got %r)" % params.sentences_path)
    if params.src_lang == '' or params.tgt_lang == '':
        parser.error("--src_lang and --tgt_lang are both required")
    if params.src_lang == params.tgt_lang:
        parser.error("--src_lang and --tgt_lang must be different (both are %r)" % params.src_lang)
    if not params.output_path:
        parser.error("--output_path is required")
    if os.path.isfile(params.output_path):
        parser.error("--output_path already exists, refusing to overwrite it (got %r)" % params.output_path)

    # translate
    with torch.no_grad():
        main(params)