async_preprocessing.py
from multiprocessing import Pool
import os
import re

import torch

from data_utils import indexed_dataset
from data_utils.tokenization import BertWordPieceTokenizer

# Map full-width / typographic punctuation to its ASCII equivalent.
key_word = {
    "…": "...",
    "—": "-",
    "“": "\"",
    "”": "\"",
    "‘": "'",
    "’": "'",
}

# Characters that must not start a new sentence; a fragment beginning with one
# of these is appended to the previous sentence instead.
SPECIAL_SIGNAL = "./';,()\"\"'~`''“”《》<>"

def cut_sentence(paragraph):
    """Split a Chinese paragraph into sentences, merging very short pieces."""
    paragraph = paragraph.replace(" ", "")
    # Split on Chinese/ASCII sentence-ending punctuation, keeping the delimiters.
    sentences = re.split(r'([。!!??])', paragraph)
    if len(sentences) == 1:
        return [sentences[0]]
    new_sents = []
    for i in range(len(sentences) // 2):
        # Re-attach each delimiter to the sentence it terminates.
        sent = sentences[2 * i] + sentences[2 * i + 1]
        if len(new_sents) != 0 and (sent[0] in SPECIAL_SIGNAL or len(new_sents[-1]) < 20):
            # Merge into the previous sentence if this one starts with punctuation
            # or the previous sentence is shorter than 20 characters.
            new_sents[-1] += sent
        else:
            new_sents.append(sent)
    # Handle a trailing fragment that has no sentence-ending delimiter.
    sent = sentences[-1]
    if len(sentences) % 2 == 1 and len(sent) > 0:
        if len(new_sents) != 0 and (sent[0] in SPECIAL_SIGNAL or len(new_sents[-1]) < 20):
            new_sents[-1] += sent
        else:
            new_sents.append(sent)
    return new_sents
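
# Illustrative example (added here for clarity; not part of the original script):
# the second sentence is merged into the first because the first is shorter
# than 20 characters.
#   cut_sentence("今天天气很好。我们出去玩吧!")
#   -> ["今天天气很好。我们出去玩吧!"]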

def replace_text(text):
    """Normalize punctuation in ``text`` using the ``key_word`` mapping."""
    # Plain string replacement is enough here; no regex features are needed.
    for key, value in key_word.items():
        text = text.replace(key, value)
    return text

def safe_readline(f):
    """Read a line, backing up if the seek position landed inside a multi-byte character."""
    pos = f.tell()
    while True:
        try:
            return f.readline()
        except UnicodeDecodeError:
            pos -= 1
            f.seek(pos)  # step back to find where this character begins

def read_split(
    filename, tokenizer, worker_id, num_workers, type_doc, min_lens=10
):
    """Tokenize one chunk of ``filename``; each input line becomes one sample."""
    with open(filename, 'r') as f:
        # Each worker handles the byte range [offset, end) of the file.
        size = os.fstat(f.fileno()).st_size
        chunk_size = size // num_workers
        offset = worker_id * chunk_size
        end = offset + chunk_size
        f.seek(offset)
        if offset > 0:
            safe_readline(f)  # drop first incomplete line
        result = []
        line = f.readline()
        while line:
            line = replace_text(line)
            ids = tokenizer.convert_text_to_ids(line)
            ids = ids[:509]
            if len(ids) >= min_lens:
                ids = [type_doc] + ids
                result.append(ids)
            if f.tell() > end:
                break
            line = f.readline()
        return result
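
# Minimal single-process sketch (assumptions: a local file "corpus.txt" exists;
# the tokenizer arguments are copied from main_multi_task below). With
# num_workers=1 the whole file is processed as one chunk:
#
#   tokenizer = BertWordPieceTokenizer("bert-base-chinese", cache_dir="temp_cache_dir")
#   samples = read_split("corpus.txt", tokenizer, worker_id=0, num_workers=1, type_doc=0)
#   # each sample is [type_doc] followed by up to 509 token ids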

def merge_multi_line(
    filename, tokenizer, worker_id, num_workers, type_doc, min_lens=10
):
    """Tokenize one chunk of ``filename``, packing consecutive lines
    (separated by EOS) into blocks of at most 511 tokens."""
    with open(filename, 'r') as f:
        size = os.fstat(f.fileno()).st_size
        chunk_size = size // num_workers
        offset = worker_id * chunk_size
        end = offset + chunk_size
        eos_id = tokenizer.eos()
        f.seek(offset)
        if offset > 0:
            safe_readline(f)  # drop first incomplete line
        result = []
        line = f.readline()
        tmp_ids = []
        while line:
            line = replace_text(line)
            ids = tokenizer.convert_text_to_ids(line) + [eos_id]
            if len(tmp_ids) + len(ids) > 511:
                # Emit a full block from the buffer, prefixed with the document type id.
                ids_cur = tmp_ids[:511]
                if ids_cur[0] == eos_id:
                    ids_cur[0] = type_doc
                else:
                    ids_cur = [type_doc] + ids_cur
                if ids_cur[-1] == eos_id:
                    ids_cur.pop()
                ids_cur = ids_cur[:511]
                result.append(ids_cur)
                # Carry over the overflow and append the new line if it still fits.
                tmp_ids = tmp_ids[511:]
                if len(tmp_ids) + len(ids) < 511:
                    tmp_ids += ids
                else:
                    tmp_ids = ids[-511:]
            else:
                tmp_ids.extend(ids)
            if f.tell() > end:
                break
            line = f.readline()
        return result

def main_multi_task(args):
    from argparse import ArgumentParser
    parser = ArgumentParser()
    # parser.add_argument("--tokenizer", type=str, help="where to load vocabulary")
    parser.add_argument("--data", type=str)
    parser.add_argument("--out", type=str, help="output path")
    parser.add_argument("--prefix", type=str, default="train")
    parser.add_argument("--workers", type=int, default=6)
    parser.add_argument("--task", type=str, choices=['single', 'multi'], default="single")
    args = parser.parse_args(args)

    tokenizer = BertWordPieceTokenizer("bert-base-chinese", cache_dir="temp_cache_dir")
    data_bin = os.path.join(args.out, "{}-CLM.bin".format(args.prefix))
    data_idx = os.path.join(args.out, "{}-CLM.idx".format(args.prefix))
    data_ds = indexed_dataset.IndexedDatasetBuilder(data_bin)

    def consume(worker_result):
        # Callback runs in the main process: append each worker's samples to the dataset.
        for ids in worker_result:
            data_ds.add_item(torch.IntTensor(ids))

    pool = Pool(processes=args.workers)
    worker_result = []
    if args.task == "single":
        handle_func = read_split        # one sample per input line
    elif args.task == "multi":
        handle_func = merge_multi_line  # pack multiple lines per sample
    for i in range(args.workers):
        w = pool.apply_async(
            handle_func,
            (
                args.data,
                tokenizer,
                i,
                args.workers,
                0,   # type_doc token id
                10   # min_lens
            ),
            callback=consume
        )
        worker_result.append(w)
    pool.close()
    pool.join()
    data_ds.finalize(data_idx)
    print("| write data into {}".format(args.out))


if __name__ == "__main__":
    import sys
    main_multi_task(sys.argv[1:])
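
# Example invocation (paths are placeholders; adjust --data/--out to your layout):
#   python async_preprocessing.py --data corpus.txt --out ./output \
#       --prefix train --workers 6 --task multi
# This writes ./output/train-CLM.bin and ./output/train-CLM.idx.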