predict_debug.py (forked from grammarly/gector)
from utils.helpers import read_lines, normalize
from gector.gec_model import GecBERTModel


def predict_for_file(input_file, output_file, model, batch_size=32, to_normalize=False):
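    """Correct sentences from input_file with the GEC model and write them to output_file.

    Sentences are processed in batches of batch_size; the total number of
    corrections applied (as reported by model.handle_batch) is returned.
    """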
    test_data = read_lines(input_file)
    predictions = []
    cnt_corrections = 0
    batch = []
    for sent in test_data:
        batch.append(sent.split())
        if len(batch) == batch_size:
            preds, cnt = model.handle_batch(batch)
            predictions.extend(preds)
            cnt_corrections += cnt
            batch = []
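
    # process any remaining sentences that did not fill a complete batch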
    if batch:
        preds, cnt = model.handle_batch(batch)
        predictions.extend(preds)
        cnt_corrections += cnt

    result_lines = [" ".join(x) for x in predictions]
    if to_normalize:
        result_lines = [normalize(line) for line in result_lines]
    with open(output_file, 'w') as f:
        f.write("\n".join(result_lines) + '\n')
    return cnt_corrections


def main(model_path, vocab_path, input_file, output_file, max_len=50, min_len=3, batch_size=128,
         lowercase_tokens=0, transformer_model='roberta', iteration_count=5, additional_confidence=0,
         additional_del_confidence=0, min_error_probability=0.0, special_tokens_fix=1, is_ensemble=0,
         weights=None, to_normalize=False):
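    """Build a GecBERTModel from the given checkpoint(s) and run it over input_file."""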
    # build the GEC model
    model = GecBERTModel(vocab_path=vocab_path,
                         model_paths=model_path,
                         max_len=max_len, min_len=min_len,
                         iterations=iteration_count,
                         min_error_probability=min_error_probability,
                         lowercase_tokens=lowercase_tokens,
                         model_name=transformer_model,
                         special_tokens_fix=special_tokens_fix,
                         log=False,
                         confidence=additional_confidence,
                         del_confidence=additional_del_confidence,
                         is_ensemble=is_ensemble,
                         weigths=weights)  # 'weigths' (sic) is the keyword name used by the upstream GecBERTModel
    cnt_corrections = predict_for_file(input_file, output_file, model,
                                       batch_size=batch_size,
                                       to_normalize=to_normalize)
    # the output file can then be evaluated externally, e.g. with the M2 scorer or ERRANT
    print(f"Produced overall corrections: {cnt_corrections}")


if __name__ == '__main__':
    # hard-coded parameters (this debug script does not parse command-line arguments)
    model_path = ['./pretrained-models/xlnet_0_gectorv2.th']
    vocab_path = './data/output_vocabulary'  # replace with the actual vocabulary path
    input_file = './model-io/model-input.txt'  # replace with the actual input file path
    output_file = './model-io/model-output.txt'  # replace with the actual output file path
    max_len = 50
    min_len = 3
    batch_size = 128
    lowercase_tokens = 0
    transformer_model = 'roberta'
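    # note: the checkpoint path above appears to be an XLNet model, so
    # transformer_model may need to be 'xlnet' rather than 'roberta' to match it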
    iteration_count = 5
    additional_confidence = 0
    additional_del_confidence = 0
    min_error_probability = 0.0
    special_tokens_fix = 1
    is_ensemble = 0
    weights = None
    to_normalize = False
    main(model_path, vocab_path, input_file, output_file, max_len, min_len, batch_size, lowercase_tokens,
         transformer_model, iteration_count, additional_confidence, additional_del_confidence, min_error_probability,
         special_tokens_fix, is_ensemble, weights, to_normalize)