Skip to content

Commit 02fc1c7

Browse files
committed
Update evaluation script.
Delete 'None' label from yesno classification. Change-Id: I7a3580e001ee388d0cbbe0d79a48d5e68c065d61
1 parent 39a2b4a commit 02fc1c7

File tree

1 file changed

+12
-13
lines changed

1 file changed

+12
-13
lines changed

utils/dureader_eval.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from rouge_metric.rouge import Rouge
3030

3131
EMPTY = ''
32-
YESNO_LABELS = set(['None', 'Yes', 'No', 'Depends'])
32+
YESNO_LABELS = set(['Yes', 'No', 'Depends'])
3333

3434

3535
def normalize(s):
@@ -62,19 +62,18 @@ def data_check(obj, task):
6262
assert 'question_type' in obj, \
6363
"Missing 'question_type' field. question_id: {}".format(obj['question_type'])
6464

65-
#if task == 'yesno' and obj['question_type'] == 'YES_NO':
6665
assert 'yesno_answers' in obj, \
6766
"Missing 'yesno_answers' field. question_id: {}".format(obj['question_id'])
6867
assert isinstance(obj['yesno_answers'], list), \
6968
r"""'yesno_answers' field must be a list, if the 'question_type' is not
7069
'YES_NO', then this field should be an empty list.
7170
question_id: {}""".format(obj['question_id'])
7271

73-
#if task == 'entity' and obj['question_type'] == 'ENTITY':
74-
assert 'entities' in obj, \
75-
"Missing 'entities' field. question_id: {}".format(obj['question_id'])
76-
assert isinstance(obj['entities'], list) and len(obj['entities']) > 0, \
77-
r"""'entities' field must be a list, and has at least one element,
72+
assert 'entity_answers' in obj, \
73+
"Missing 'entity_answers' field. question_id: {}".format(obj['question_id'])
74+
assert isinstance(obj['entity_answers'], list) \
75+
and len(obj['entity_answers']) > 0, \
76+
r"""'entity_answers' field must be a list, and has at least one element,
7877
which can be a empty list. question_id: {}""".format(obj['question_id'])
7978

8079

@@ -89,10 +88,10 @@ def read_file(file_name, task, is_ref=False):
8988
Returns:
9089
A dictionary mapping question_id to the result information. The result
9190
information itself is also a dictionary with has four keys:
92-
- question_type: type of the question.
91+
- question_type: type of the query.
9392
- yesno_answers: A list of yesno answers corresponding to 'answers'.
9493
- answers: A list of predicted answers.
95-
- entities: A list, each element is also a list containing the entities
94+
- entity_answers: A list, each element is also a list containing the entities
9695
tagged out from the corresponding answer string.
9796
"""
9897
def _open(file_name, mode, zip_obj=None):
@@ -101,7 +100,7 @@ def _open(file_name, mode, zip_obj=None):
101100
return open(file_name, mode)
102101

103102
results = {}
104-
keys = ['answers', 'yesno_answers', 'entities', 'question_type']
103+
keys = ['answers', 'yesno_answers', 'entity_answers', 'question_type']
105104
if is_ref:
106105
keys += ['source']
107106

@@ -194,8 +193,8 @@ def prepare_prf(pred_dict, ref_dict):
194193
"""
195194
Prepares data for calculation of prf scores.
196195
"""
197-
preds = {k: v['entities'] for k, v in pred_dict.items()}
198-
refs = {k: v['entities'] for k, v in ref_dict.items()}
196+
preds = {k: v['entity_answers'] for k, v in pred_dict.items()}
197+
refs = {k: v['entity_answers'] for k, v in ref_dict.items()}
199198
return preds, refs
200199

201200

@@ -477,7 +476,7 @@ def format_metrics(metrics, task, err_msg):
477476
if err_msg is not None:
478477
return {'errorMsg': str(err_msg), 'errorCode': 1, 'data': []}
479478
data = []
480-
if task != 'all':
479+
if task != 'all' and task != 'main':
481480
sources = ["both"]
482481

483482
if task == 'all':

0 commit comments

Comments
 (0)