forked from baidu/DuReader
Commit
1. Change 'query' to 'question'.
2. Remove authors.
3. Other refinements.

Change-Id: Ic237e7dc6ad624a97e6c10c463018730b73c842e
Showing 15 changed files with 378 additions and 369 deletions.
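The substantive change across these files is a schema rename: every 'query' field becomes 'question' (question_id, question_type, segmented_question, and so on). For anyone holding preprocessed JSON-lines data in the old schema, a one-off migration along the following lines should work; this script is not part of the commit, and the key map is only inferred from the fields renamed in the diffs below.

import json
import sys

# Old-schema -> new-schema key map, inferred from the fields renamed in the
# diffs below; extend it if your files carry other 'query*' keys.
KEY_MAP = {
    'query': 'question',
    'query_id': 'question_id',
    'query_type': 'question_type',
    'segmented_query': 'segmented_question',
}

def migrate(line):
    obj = json.loads(line)
    for old_key, new_key in KEY_MAP.items():
        if old_key in obj:
            obj[new_key] = obj.pop(old_key)
    return json.dumps(obj, ensure_ascii=False)

if __name__ == '__main__':
    # usage: python migrate_schema.py < old_format.json > new_format.json
    for line in sys.stdin:
        print(migrate(line.strip()))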
@@ -17,9 +17,6 @@
 """
 This module implements the BiDAF algorithm described in
 https://arxiv.org/abs/1611.01603
-Authors: liuyuan([email protected])
-Date: 2017/09/20 12:00:00
 """

 import paddle.v2.layer as layer
@@ -28,13 +25,13 @@
 import paddle.v2 as paddle
 import paddle.v2.networks as networks

-from qa_model import QAModel
+from rc_model import RCModel

-class BiDAF(QAModel):
+class BiDAF(RCModel):
     """
     Implements BiDAF.
     """
-    def __get_enc(self, input, type='q'):
+    def _get_enc(self, input, type='q'):
         embs = self.get_embs(input)
         enc = networks.bidirectional_lstm(
             input=embs,
@@ -51,7 +48,7 @@ def __get_enc(self, input, type='q'):
         enc_dropped = self.drop_out(enc, drop_rate=0.25)
         return enc_dropped

-    def __step_basic(self, h_cur, u):
+    def _step_basic(self, h_cur, u):
         expanded_h = layer.expand(input=h_cur, expand_as=u)
         hu = layer.concat(input=[expanded_h, u])
         with layer.mixed(bias_attr=False) as dot_hu:
@@ -63,13 +60,13 @@ def __step_basic(self, h_cur, u):
             input=cat_all)
         return s

-    def __h_step(self, h_cur, u):
-        s = self.__step_basic(h_cur, u)
+    def _h_step(self, h_cur, u):
+        s = self._step_basic(h_cur, u)
         step_max = layer.pooling(input=s, pooling_type=paddle.pooling.Max())
         return step_max

-    def __u_step(self, h_cur, u):
-        s = self.__step_basic(h_cur, u)
+    def _u_step(self, h_cur, u):
+        s = self._step_basic(h_cur, u)
         with layer.mixed(size=1,
                          bias_attr=False,
                          act=Act.SequenceSoftmax()) as h_weights:
@@ -79,45 +76,18 @@ def __u_step(self, h_cur, u):
                               pooling_type=paddle.pooling.Sum())
         return u_ctx

-    def __union_step(self, h_cur, u):
-        s = self.__step_basic(h_cur, u)
-        step_max = layer.pooling(input=s, pooling_type=paddle.pooling.Max())
-        with layer.mixed(size=1,
-                         bias_attr=False,
-                         act=Act.SequenceSoftmax()) as h_weights:
-            h_weights += layer.identity_projection(s)
-        applied_weights = layer.scaling(input=u, weight=h_weights)
-        u_ctx = layer.pooling(input=applied_weights,
-                              pooling_type=paddle.pooling.Sum())
-        return [step_max, u_ctx]
-
-    def __beta(self, h, u_expr, h_expr):
+    def _beta(self, h, u_expr, h_expr):
         with layer.mixed(bias_attr=False) as dot_h_u_expr:
             dot_h_u_expr += layer.dotmul_operator(a=h, b=u_expr)
         with layer.mixed(bias_attr=False) as dot_h_h_expr:
             dot_h_h_expr += layer.dotmul_operator(a=h, b=h_expr)
         cat_all = layer.concat(input=[h, u_expr, dot_h_u_expr, dot_h_h_expr])
         return cat_all

-    def __attention_flow2(self, h, u):
-        bs, u_expr = layer.recurrent_group(
-            input=[h, layer.StaticInput(u)],
-            step=self.__u_step,
-            reverse=False)
-        b_weights = layer.mixed(act=Act.SequenceSoftmax(),
-                                bias_attr=False,
-                                input=layer.identity_projection(bs))
-        h_step_scaled = layer.scaling(input=h, weight=b_weights)
-        h_step = layer.pooling(input=h_step_scaled,
-                               pooling_type=paddle.pooling.Sum())
-        h_expr = layer.expand(input=h_step, expand_as=h)
-        g = self.__beta(h, u_expr, h_expr)
-        return g
-
-    def __attention_flow(self, h, u):
+    def _attention_flow(self, h, u):
         bs = layer.recurrent_group(
             input=[h, layer.StaticInput(u)],
-            step=self.__h_step,
+            step=self._h_step,
             reverse=False)
         b_weights = layer.mixed(act=Act.SequenceSoftmax(),
                                 bias_attr=False,
@@ -128,9 +98,9 @@ def __attention_flow(self, h, u):
         h_expr = layer.expand(input=h_step, expand_as=h)
         u_expr = layer.recurrent_group(
             input=[h, layer.StaticInput(u)],
-            step=self.__u_step,
+            step=self._u_step,
             reverse=False)
-        g = self.__beta(h, u_expr, h_expr)
+        g = self._beta(h, u_expr, h_expr)
         return g

     def network(self):
@@ -143,12 +113,12 @@ def network(self):
         """
         self.check_and_create_data()
         self.create_shared_params()
-        u = self.__get_enc(self.q_ids, type='q')
+        u = self._get_enc(self.q_ids, type='q')
         m1s = []
         m2s = []
         for p in self.p_ids:
-            h = self.__get_enc(p, type='q')
-            g = self.__attention_flow(h, u)
+            h = self._get_enc(p, type='q')
+            g = self._attention_flow(h, u)
             m1 = networks.bidirectional_lstm(
                 fwd_mat_param_attr=Attr.Param('_f_m1_mat.w'),
                 fwd_bias_param_attr=Attr.Param('_f_m1.bias',
@@ -187,17 +157,3 @@ def network(self):
         start = self.decode('start', all_m1)
         end = self.decode('end', all_m2)
         return start, end
-
-    def __fusion_layer(self, input1, input2):
-        # fusion layer
-        neg_input2 = layer.slope_intercept(input=input2,
-                                           slope=-1.0,
-                                           intercept=0.0)
-        diff1 = layer.addto(input=[input1, neg_input2],
-                            act=Act.Identity(),
-                            bias_attr=False)
-        diff2 = layer.mixed(bias_attr=False,
-                            input=layer.dotmul_operator(a=input1, b=input2))
-
-        fused = layer.concat(input=[input1, input2, diff1, diff2])
-        return fused
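The BiDAF hunks above make two kinds of changes: the base class moves from QAModel to RCModel, and every double-underscore helper (`__get_enc`, `__step_basic`, `__attention_flow`, ...) is renamed to a single-underscore one, while the unused `__union_step`, `__attention_flow2`, and `__fusion_layer` helpers are dropped outright. A plausible motivation for the underscore change, though the commit message only says "other refinements", is that Python name-mangles double-underscore attributes to `_ClassName__attr`, which makes them awkward to reach from subclasses, whereas a single underscore is a convention with no mangling. A minimal sketch of the difference (illustrative class names only):

class RCModel(object):
    def __hidden(self):      # name-mangled to _RCModel__hidden
        return 'mangled'

    def _internal(self):     # single underscore: convention only, no mangling
        return 'accessible'

class BiDAF(RCModel):
    def call_helpers(self):
        # self.__hidden() here would be mangled to self._BiDAF__hidden and
        # raise AttributeError; the parent's mangled name must be spelled out.
        mangled = self._RCModel__hidden()
        plain = self._internal()  # works as expected across the hierarchy
        return mangled, plain

print(BiDAF().call_helpers())  # ('mangled', 'accessible')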
@@ -16,23 +16,18 @@
 # ==============================================================================
 """
 Implements data parsers for different tasks on DuReader dataset.
-Authors: liuyuan([email protected])
-Data: 2017/09/20 12:00:00
 """

 import copy
 import hashlib
 import itertools
 import logging
 import json
 import numpy as np
 import random
 import sys
 from collections import OrderedDict
 import paddle.v2 as paddle

-from utils import find_best_query_match
+from utils import find_best_question_match

 logger = logging.getLogger("paddle")
 logger.setLevel(logging.INFO)
@@ -168,7 +163,7 @@ def __init__(self, *args, **kwargs):
         if self.is_infer:
             assert self.shuffle == False, 'Shuffling is forbidden for inference'

-    def __get_id(self, s):
+    def _get_id(self, s):
         s_ids = []
         if not isinstance(s, list):
             s = s.split(' ')
@@ -189,18 +184,18 @@ def parse_train(self, line):
         obj = json.loads(line.strip())
         ret = []
-        if obj['query_type'] != 'YES_NO':
+        if obj['question_type'] != 'YES_NO':
             return ret
         label_ids = [self.labels[l] for l in obj['yesno_answers']]
-        query = [
+        question = [
             self.vocab.get(x, self.unk_id)
-            for x in obj['segmented_query']]
-        paras = map(self.__get_id, obj['segmented_answers'])
+            for x in obj['segmented_question']]
+        paras = map(self._get_id, obj['segmented_answers'])

-        if not query or not paras:
+        if not question or not paras:
             return ret
         for para, lbl in zip(paras, label_ids):
-            ret.append((query, para, lbl))
+            ret.append((question, para, lbl))
         return ret

     def parse_infer(self, line):
@@ -215,15 +210,15 @@ def parse_infer(self, line):
         """
         obj = json.loads(line.strip())
         ret = []
-        paras = map(self.__get_id, obj['answers'])
-        query = [self.vocab.get(x, self.unk_id) for x in obj['query']]
+        paras = map(self._get_id, obj['answers'])
+        question = [self.vocab.get(x, self.unk_id) for x in obj['question']]
         fake_label = 0
         for idx, para in enumerate(paras):
             info = copy.deepcopy(obj)
             info['answer_idx'] = idx
             info['yesno_answers_ref'] = info['yesno_answers_ref']
             info['yesno_answers'] = []
-            ret.append((query, para, fake_label, info))
+            ret.append((question, para, fake_label, info))
         return ret

     def parse(self, line):
@@ -260,7 +255,7 @@ def __init__(self, *args, **kwargs):
         self.feeding = {name: i for i, name in enumerate(self.schema)}

-    def __find_ans_span(self, query_tokens, span, answer_docs, docs):
+    def _find_ans_span(self, question_tokens, span, answer_docs, docs):
         assert len(span) == 1, 'Multiple spans: {}'.format(span)
         assert len(answer_docs) == 1, \
             'Multiple answer docs: {}'.format(answer_docs)
@@ -271,7 +266,7 @@ def __find_ans_span(self, query_tokens, span, answer_docs, docs):
             if not self.is_infer:
                 para_idx = doc['most_related_para']
             else:
-                para_idx = find_best_query_match(doc, query_tokens)
+                para_idx = find_best_question_match(doc, question_tokens)
             para = doc['segmented_paragraphs'][para_idx]
             if len(para) == 0:
                 continue
@@ -291,8 +286,8 @@ def __find_ans_span(self, query_tokens, span, answer_docs, docs):
             para_tokens.append(para)
         return selected_paras, para_tokens

-    def __make_sample(self, query_ids, para_infos):
-        def __get_label(idx, ref):
+    def _make_sample(self, question_ids, para_infos):
+        def _get_label(idx, ref):
             ret = [0.0] * len(ref)
             if idx > 0:
                 ret[idx] = 1.0
@@ -305,23 +300,23 @@ def __get_label(idx, ref):
         selected += [default_para_info] * (self.doc_num - len(selected))
         for para_ids, ans_span in selected:
             s, e = ans_span
-            start_label = __get_label(s, para_ids)
-            end_label = __get_label(e, para_ids)
+            start_label = _get_label(s, para_ids)
+            end_label = _get_label(e, para_ids)
             paras.append(para_ids)
             start_labels.append(start_label)
             end_labels.append(end_label)
             para_lens.append([[len(para_ids)]])
-        sample = [query_ids] + paras + para_lens + start_labels + end_labels
+        sample = [question_ids] + paras + para_lens + start_labels + end_labels
         return sample

-    def __get_infer_info(self, obj, paras):
+    def _get_infer_info(self, obj, paras):
         info = {}
         info['tokens'] = list(itertools.chain(*paras))
         info['answers'] = []
         info['answers_ref'] = obj.get('segmented_answers', [])
-        info['query'] = obj['segmented_query']
-        info['query_id'] = obj['query_id']
-        info['query_type'] = obj['query_type']
+        info['question'] = obj['segmented_question']
+        info['question_id'] = obj['question_id']
+        info['question_type'] = obj['question_type']
         info['yesno_answers_ref'] = obj.get('yesno_answers', [])
         info['yesno_answers'] = []
         info['entities'] = obj.get('entity_answers', [[]])
@@ -345,20 +340,19 @@ def parse(self, line):
         if obj['answer_docs'][0] > 5:
             logger.info('skip, answer doc out of range.')
             return ret
-        q_ids = [self.vocab.get(x, self.unk_id) for x in obj['segmented_query']]
+        q_ids = [self.vocab.get(x, self.unk_id) for x in obj['segmented_question']]
         if len(q_ids) == 0:
             return ret
-        selected_paras, para_tokens = self.__find_ans_span(
-            obj['segmented_query'],
+        selected_paras, para_tokens = self._find_ans_span(
+            obj['segmented_question'],
             obj['answer_spans'],
             obj['answer_docs'],
             obj['documents'])
         if not selected_paras:
             return ret
-        sample = self.__make_sample(q_ids, selected_paras)
-        #assert len(sample) == len(self.schema)
+        sample = self._make_sample(q_ids, selected_paras)
         if self.is_infer:
-            sample.append(self.__get_infer_info(obj, para_tokens))
+            sample.append(self._get_infer_info(obj, para_tokens))
         ret.append(sample)
         return ret
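In the span-reading parser above, `_make_sample` turns each answer span into one-hot start and end vectors over the selected paragraph via the nested `_get_label` helper. A standalone sketch of that labeling, with hypothetical token ids; note the guard `if idx > 0` is copied from the diff, so position 0 is never marked:

def get_label(idx, ref):
    # One-hot vector over the paragraph tokens; mirrors _get_label above,
    # including its behavior of leaving position 0 unlabeled.
    ret = [0.0] * len(ref)
    if idx > 0:
        ret[idx] = 1.0
    return ret

para_ids = [101, 23, 57, 9, 340]  # hypothetical token ids for one paragraph
answer_span = (1, 3)              # answer covers tokens 1 through 3
start_label = get_label(answer_span[0], para_ids)
end_label = get_label(answer_span[1], para_ids)
print(start_label)  # [0.0, 1.0, 0.0, 0.0, 0.0]
print(end_label)    # [0.0, 0.0, 0.0, 1.0, 0.0]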
@@ -16,9 +16,6 @@
 # ==============================================================================
 """
 This module implements an inferer for basic inference and evaluation functions.
-Authors: liuyuan([email protected])
-Date: 2017/09/20 12:00:00
 """

 import argparse
@@ -52,9 +49,9 @@ def __init__(self,
         self.test_reader = datasets[1]
         self.feeding = datasets[1].feeding
         self.costs = []
-        self.__prepare()
+        self._prepare()

-    def __prepare(self):
+    def _prepare(self):
         # prepare reader
         self.test_reader = paddle.batch(
             reader=self.test_reader.create_reader(),
@@ -97,7 +94,7 @@ def get_infer_file(self):
         is_exist = os.path.isfile(infer_file)
         return is_exist, infer_file

-    def run(self):
+    def start(self):
         """
         Runs the whole inferring process.
         """
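The inference entry point is renamed from `run` to `start`, so any driver script that constructs the inferer needs the matching call-site update. A minimal stub to illustrate; only the rename itself comes from the diff, and the alias is an optional compatibility shim, not part of the commit:

class Inferer(object):
    # Stub mirroring only the renamed entry point; the real __init__ takes
    # config, parameters, and datasets (signature not shown in this diff).
    def start(self):  # formerly run()
        print('running the whole inferring process')

    run = start  # optional backwards-compatible alias for old callers

Inferer().start()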