#!/usr/bin/python
# -*- coding: utf-8 -*-
"""paragraph_extraction.py

Paragraph extraction for DuReader samples: reads preprocessed samples from
stdin (one JSON object per line), scores each paragraph against the question,
removes duplicated paragraphs, selects the most informative ones, and writes
the updated samples to stdout.
"""
import copy
import json
import sys

from preprocess import metric_max_over_ground_truths, f1_score
def compute_paragraph_score(sample):
    """
    For each paragraph of each document, compute the F1 score against the question.
    Args:
        sample: a sample in the dataset.
    Returns:
        None
    Raises:
        None
    """
    question = sample['segmented_question']
    for doc in sample['documents']:
        doc['segmented_paragraphs_scores'] = []
        for p_idx, para_tokens in enumerate(doc['segmented_paragraphs']):
            if len(question) > 0:
                related_score = metric_max_over_ground_truths(f1_score,
                                                              para_tokens,
                                                              [question])
            else:
                related_score = 0.0
            doc['segmented_paragraphs_scores'].append(related_score)
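

def _demo_compute_paragraph_score():
    # A minimal sketch, not part of the original script: how
    # compute_paragraph_score is meant to be called. The toy sample below is
    # hypothetical; the exact scores depend on f1_score in preprocess.py.
    sample = {
        'segmented_question': ['everest', 'height'],
        'documents': [{
            'segmented_paragraphs': [
                ['everest', 'height', 'is', '8848', 'm'],
                ['unrelated', 'tokens'],
            ],
        }],
    }
    compute_paragraph_score(sample)
    # One token-overlap F1 score per paragraph, highest for the first one.
    return sample['documents'][0]['segmented_paragraphs_scores']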
def dup_remove(doc):
    """
    For each document, remove the duplicated paragraphs, keeping the first occurrence.
    Args:
        doc: a doc in the sample.
    Returns:
        bool: True if any paragraph was removed, False otherwise.
    Raises:
        None
    """
    paragraphs_his = {}
    del_ids = []
    para_id = None
    if 'most_related_para' in doc:
        para_id = doc['most_related_para']
    doc['paragraphs_length'] = []
    for p_idx, (segmented_paragraph, paragraph_score) in \
            enumerate(zip(doc['segmented_paragraphs'], doc['segmented_paragraphs_scores'])):
        doc['paragraphs_length'].append(len(segmented_paragraph))
        paragraph = ''.join(segmented_paragraph)
        if paragraph in paragraphs_his:
            del_ids.append(p_idx)
            if p_idx == para_id:
                # The most related paragraph is a duplicate; point it at the
                # first occurrence instead.
                para_id = paragraphs_his[paragraph]
            continue
        paragraphs_his[paragraph] = p_idx
    # Delete the duplicates, shifting indices as earlier entries are removed.
    prev_del_num = 0
    del_num = 0
    for p_idx in del_ids:
        if para_id is not None and p_idx < para_id:
            prev_del_num += 1
        del doc['segmented_paragraphs'][p_idx - del_num]
        del doc['segmented_paragraphs_scores'][p_idx - del_num]
        del doc['paragraphs_length'][p_idx - del_num]
        del_num += 1
    if len(del_ids) != 0:
        if 'most_related_para' in doc:
            doc['most_related_para'] = para_id - prev_del_num
        doc['paragraphs'] = []
        for segmented_para in doc['segmented_paragraphs']:
            paragraph = ''.join(segmented_para)
            doc['paragraphs'].append(paragraph)
        return True
    else:
        return False
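

def _demo_dup_remove():
    # A minimal sketch, not part of the original script: dup_remove applied
    # to a hypothetical document with a duplicated paragraph. Field names
    # follow the DuReader preprocessed format.
    doc = {
        'segmented_paragraphs': [['a', 'b'], ['c'], ['a', 'b']],
        'segmented_paragraphs_scores': [0.5, 0.1, 0.5],
        'most_related_para': 2,
    }
    removed = dup_remove(doc)
    # removed is True; the third paragraph is gone and 'most_related_para'
    # now points at its first occurrence (index 0).
    return removed, doc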
def paragraph_selection(sample, mode):
    """
    For each document, select the paragraphs that carry as much information as possible.
    Args:
        sample: a sample in the dataset.
        mode: one of "train", "dev" or "test", indicating which split is being processed.
    Returns:
        None
    Raises:
        None
    """
    # Predefined maximum length of the concatenated paragraph.
    MAX_P_LEN = 500
    # Predefined splitter token inserted in front of each paragraph.
    splitter = '<splitter>'
    # Number of top-scoring paragraphs to choose.
    topN = 3
    doc_id = None
    if 'answer_docs' in sample and len(sample['answer_docs']) > 0:
        doc_id = sample['answer_docs'][0]
        if doc_id >= len(sample['documents']):
            # Data error: the answer doc id exceeds the number of documents.
            # This sample will be filtered out by dataset.py.
            return
    for d_idx, doc in enumerate(sample['documents']):
        if 'segmented_paragraphs_scores' not in doc:
            continue
        dup_remove(doc)
        segmented_title = doc['segmented_title']
        title_len = len(segmented_title)
        para_id = None
        if doc_id is not None:
            para_id = sample['documents'][doc_id]['most_related_para']
        total_len = title_len + sum(doc['paragraphs_length'])
        # Account for one splitter token per paragraph.
        para_num = len(doc['segmented_paragraphs'])
        total_len += para_num
        if total_len <= MAX_P_LEN:
            # Everything fits: concatenate the title and all paragraphs.
            incre_len = title_len
            total_segmented_content = copy.deepcopy(segmented_title)
            for p_idx, segmented_para in enumerate(doc['segmented_paragraphs']):
                if doc_id == d_idx and para_id > p_idx:
                    incre_len += len([splitter] + segmented_para)
                if doc_id == d_idx and para_id == p_idx:
                    incre_len += 1
                total_segmented_content += [splitter] + segmented_para
            if doc_id == d_idx:
                # Shift the gold answer span by the offset of the answer
                # paragraph inside the concatenated content.
                answer_start = incre_len + sample['answer_spans'][0][0]
                answer_end = incre_len + sample['answer_spans'][0][1]
                sample['answer_spans'][0][0] = answer_start
                sample['answer_spans'][0][1] = answer_end
            doc['segmented_paragraphs'] = [total_segmented_content]
            doc['segmented_paragraphs_scores'] = [1.0]
            doc['paragraphs_length'] = [total_len]
            doc['paragraphs'] = [''.join(total_segmented_content)]
            doc['most_related_para'] = 0
            continue
        # Otherwise, find the ids of the topN highest-scoring paragraphs,
        # breaking ties in favor of shorter paragraphs.
        para_infos = []
        for p_idx, (para_tokens, para_scores) in \
                enumerate(zip(doc['segmented_paragraphs'], doc['segmented_paragraphs_scores'])):
            para_infos.append((para_tokens, para_scores, len(para_tokens), p_idx))
        para_infos.sort(key=lambda x: (-x[1], x[2]))
        topN_idx = []
        for para_info in para_infos[:topN]:
            topN_idx.append(para_info[-1])
        final_idx = []
        total_len = title_len
        if doc_id == d_idx:
            if mode == 'train':
                # Always keep the paragraph that contains the gold answer.
                final_idx.append(para_id)
                total_len = title_len + 1 + doc['paragraphs_length'][para_id]
        for p_id in topN_idx:
            if total_len > MAX_P_LEN:
                break
            if doc_id == d_idx and p_id == para_id and mode == 'train':
                continue
            total_len += 1 + doc['paragraphs_length'][p_id]
            final_idx.append(p_id)
        total_segmented_content = copy.deepcopy(segmented_title)
        final_idx.sort()
        incre_len = title_len
        for p_id in final_idx:
            if doc_id == d_idx and p_id < para_id:
                incre_len += 1 + doc['paragraphs_length'][p_id]
            if doc_id == d_idx and p_id == para_id:
                incre_len += 1
            total_segmented_content += [splitter] + doc['segmented_paragraphs'][p_id]
        if doc_id == d_idx:
            answer_start = incre_len + sample['answer_spans'][0][0]
            answer_end = incre_len + sample['answer_spans'][0][1]
            sample['answer_spans'][0][0] = answer_start
            sample['answer_spans'][0][1] = answer_end
        doc['segmented_paragraphs'] = [total_segmented_content]
        doc['segmented_paragraphs_scores'] = [1.0]
        doc['paragraphs_length'] = [total_len]
        doc['paragraphs'] = [''.join(total_segmented_content)]
        doc['most_related_para'] = 0
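

def _demo_paragraph_selection():
    # A minimal sketch, not part of the original script: the layout that
    # paragraph_selection produces for a hypothetical single-document sample,
    #
    #   [title tokens] <splitter> [para tokens] <splitter> [para tokens] ...
    #
    # and how the gold answer span is shifted by the number of tokens that
    # precede the answer paragraph in the concatenation.
    sample = {
        'answer_docs': [0],
        'answer_spans': [[0, 2]],
        'documents': [{
            'segmented_title': ['t1', 't2'],
            'segmented_paragraphs': [['a', 'b', 'c'], ['d', 'e']],
            'segmented_paragraphs_scores': [0.9, 0.2],
            'most_related_para': 0,
        }],
    }
    paragraph_selection(sample, 'train')
    # The document is collapsed to a single paragraph:
    #   ['t1', 't2', '<splitter>', 'a', 'b', 'c', '<splitter>', 'd', 'e']
    # and the answer span becomes [3, 5]: shifted by the title length (2)
    # plus one splitter token.
    return sample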
if __name__ == '__main__':
    # mode is one of "train", "dev" or "test".
    mode = sys.argv[1]
    for line in sys.stdin:
        line = line.strip()
        if line == '':
            continue
        try:
            sample = json.loads(line)
        except ValueError:
            print("Invalid input json format - '{}' will be ignored".format(line),
                  file=sys.stderr)
            continue
        compute_paragraph_score(sample)
        paragraph_selection(sample, mode)
        print(json.dumps(sample, ensure_ascii=False))
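
# Typical invocation (the file names below are illustrative; any DuReader
# preprocessed file with one JSON sample per line should work):
#   python paragraph_extraction.py train < trainset.json > extracted.train.json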