Skip to content

Commit

Permalink
Update yesno to 3 classes.
Browse files Browse the repository at this point in the history
support multiple file names for trainset and testset

Change-Id: I26fd79bed1129ffc0d2ab875cd2f0d34bc8c0c54
  • Loading branch information
liuyuuan committed Nov 13, 2017
1 parent ffe31b5 commit b141cde
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 23 deletions.
26 changes: 14 additions & 12 deletions paddle/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self,
append_raw=False,
is_infer=False,
max_p_len=500):
self.file_name = file_name
self.file_names = file_name
self.data = []
self.raw = []
self.vocab = self.read_vocab(vocab_file, vocab_size) \
Expand All @@ -70,9 +70,10 @@ def load(self):
Loads all data records into self.data.
"""
self.data = []
with open(self.file_name, 'r') as src:
for line in src:
self.data += self.parse(line.strip())
for file_name in self.file_names:
with open(self.file_name, 'r') as src:
for line in src:
self.data += self.parse(line.strip())
if self.shuffle:
logger.info('Shuffling data...')
random.shuffle(self.data)
Expand Down Expand Up @@ -138,13 +139,14 @@ def _reader_preload():
yield line

def _reader_stream():
with open(self.file_name, 'r') as fn:
for line in fn:
data = self.parse(line.strip())
if not data:
continue
for d in data:
yield d
for file_name in self.file_names:
with open(self.file_name, 'r') as fn:
for line in fn:
data = self.parse(line.strip())
if not data:
continue
for d in data:
yield d

if not self.preload:
return _reader_stream
Expand All @@ -156,7 +158,7 @@ class DuReaderYesNo(Dataset):
Implements parser for yesno task.
"""
def __init__(self, *args, **kwargs):
self.labels = {'None': 0, 'Yes': 1, 'No': 2, 'Depends': 3}
self.labels = {'Yes': 0, 'No': 1, 'Depends': 2}
super(DuReaderYesNo, self).__init__(*args, **kwargs)
self.schema = ['q_ids', 'a_ids', 'label']
self.feeding = {name: i for i, name in enumerate(self.schema)}
Expand Down
4 changes: 2 additions & 2 deletions paddle/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ def parse_args():
Parses command line arguments.
"""
parser = argparse.ArgumentParser()
parser.add_argument('--trainset', help='train dataset')
parser.add_argument('--testset', help='test dataset')
parser.add_argument('--trainset', nargs='+', help='train dataset')
parser.add_argument('--testset', nargs='+', help='test dataset')
parser.add_argument('--test_period', type=int, default=10)
parser.add_argument('--vocab_file', help='dict')
parser.add_argument('--batch_size', help='batch size',
Expand Down
13 changes: 4 additions & 9 deletions paddle/yesno.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,8 @@
# ==============================================================================
"""
This module implements an opinion classification model to classify a
question answer pair into 4 categories: None(no opinion), Yes(positive opinion),
question answer pair into 3 categories: Yes(positive opinion),
No(negative opinion), Depends(depends on conditions).
Authors: liuyuan([email protected])
Date: 2017/09/20 12:00:00
"""

import logging
Expand Down Expand Up @@ -49,9 +46,9 @@ def __init__(self, name, inputs, *args, **kwargs):
self.emb_dim = kwargs['emb_dim']
self.vocab_size = kwargs['vocab_size']
self.is_infer = kwargs['is_infer']
self.label_dim = 4
self.static_emb = kwargs['static_emb']
self.labels = ['None', 'Yes', 'No', 'Depends']
self.label_dim = 3
self.static_emb = kwargs.get('static_emb', False)
self.labels = ['Yes', 'No', 'Depends']
self.label_dict = {v: idx for idx, v in enumerate(self.labels)}
super(OpinionClassifier, self).__init__(name, inputs, *args, **kwargs)

Expand Down Expand Up @@ -133,8 +130,6 @@ def train(self):
input=cls, name='label1', label=self.label, positive_label=1)
evaluator_2 = paddle.evaluator.precision_recall(
input=cls, name='label2', label=self.label, positive_label=2)
evaluator_3 = paddle.evaluator.precision_recall(
input=cls, name='label3', label=self.label, positive_label=3)
evaluator_all = paddle.evaluator.precision_recall(
input=cls, name='label_all', label=self.label)
return loss
Expand Down

0 comments on commit b141cde

Please sign in to comment.