Skip to content

Commit

Permalink
Update evalution script.
Browse files Browse the repository at this point in the history
Delete 'None' label from yesno classification.

Change-Id: I7a3580e001ee388d0cbbe0d79a48d5e68c065d61
  • Loading branch information
liuyuuan committed Nov 13, 2017
1 parent 39a2b4a commit 02fc1c7
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions utils/dureader_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from rouge_metric.rouge import Rouge

EMPTY = ''
YESNO_LABELS = set(['None', 'Yes', 'No', 'Depends'])
YESNO_LABELS = set(['Yes', 'No', 'Depends'])


def normalize(s):
Expand Down Expand Up @@ -62,19 +62,18 @@ def data_check(obj, task):
assert 'question_type' in obj, \
"Missing 'question_type' field. question_id: {}".format(obj['question_type'])

#if task == 'yesno' and obj['question_type'] == 'YES_NO':
assert 'yesno_answers' in obj, \
"Missing 'yesno_answers' field. question_id: {}".format(obj['question_id'])
assert isinstance(obj['yesno_answers'], list), \
r"""'yesno_answers' field must be a list, if the 'question_type' is not
'YES_NO', then this field should be an empty list.
question_id: {}""".format(obj['question_id'])

#if task == 'entity' and obj['question_type'] == 'ENTITY':
assert 'entities' in obj, \
"Missing 'entities' field. question_id: {}".format(obj['question_id'])
assert isinstance(obj['entities'], list) and len(obj['entities']) > 0, \
r"""'entities' field must be a list, and has at least one element,
assert 'entity_answers' in obj, \
"Missing 'entity_answers' field. question_id: {}".format(obj['question_id'])
assert isinstance(obj['entity_answers'], list) \
and len(obj['entity_answers']) > 0, \
r"""'entity_answers' field must be a list, and has at least one element,
which can be a empty list. question_id: {}""".format(obj['question_id'])


Expand All @@ -89,10 +88,10 @@ def read_file(file_name, task, is_ref=False):
Returns:
A dictionary mapping question_id to the result information. The result
information itself is also a dictionary with has four keys:
- question_type: type of the question.
- question_type: type of the query.
- yesno_answers: A list of yesno answers corresponding to 'answers'.
- answers: A list of predicted answers.
- entities: A list, each element is also a list containing the entities
- entity_answers: A list, each element is also a list containing the entities
tagged out from the corresponding answer string.
"""
def _open(file_name, mode, zip_obj=None):
Expand All @@ -101,7 +100,7 @@ def _open(file_name, mode, zip_obj=None):
return open(file_name, mode)

results = {}
keys = ['answers', 'yesno_answers', 'entities', 'question_type']
keys = ['answers', 'yesno_answers', 'entity_answers', 'question_type']
if is_ref:
keys += ['source']

Expand Down Expand Up @@ -194,8 +193,8 @@ def prepare_prf(pred_dict, ref_dict):
"""
Prepares data for calculation of prf scores.
"""
preds = {k: v['entities'] for k, v in pred_dict.items()}
refs = {k: v['entities'] for k, v in ref_dict.items()}
preds = {k: v['entity_answers'] for k, v in pred_dict.items()}
refs = {k: v['entity_answers'] for k, v in ref_dict.items()}
return preds, refs


Expand Down Expand Up @@ -477,7 +476,7 @@ def format_metrics(metrics, task, err_msg):
if err_msg is not None:
return {'errorMsg': str(err_msg), 'errorCode': 1, 'data': []}
data = []
if task != 'all':
if task != 'all' and task != 'main':
sources = ["both"]

if task == 'all':
Expand Down

0 comments on commit 02fc1c7

Please sign in to comment.