Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/baidu/DuReader
Browse files Browse the repository at this point in the history
  • Loading branch information
liuyuuan committed Nov 16, 2017
2 parents eb61409 + b5ba7b9 commit e436c82
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 44 deletions.
13 changes: 6 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,13 @@ Before training the model, we have to make sure that the data is ready. For prep
```
python run.py --prepare --task zhidao
```
You can choose which dataset to use by set the `--task` as `search`, `zhidao` or `both`.
You can specify the files for train/dev/test by setting the `train_files`/`dev_files`/`test_files`. By default, we use the data in `data/demo/`

#### Training
To train the reading comprehension model, you can specify the model type by using `--algo [BIDAF|MLSTM]` and you can also set the hyper-parameters such as the learning rate by using `--learning_rate NUM`. For example, to train a BIDAF model on Zhidao Dataset for 10 epochs, you can run:
To train the reading comprehension model, you can specify the model type by using `--algo [BIDAF|MLSTM]` and you can also set the hyper-parameters such as the learning rate by using `--learning_rate NUM`. For example, to train a BIDAF model for 10 epochs, you can run:

```
python run.py --task zhidao --algo BIDAF --epochs 10
python run.py --train --algo BIDAF --epochs 10
```

The training process includes an evaluation on the dev set after each training epoch. By default, the model with the least Bleu-4 score on the dev set will be saved.
Expand All @@ -111,19 +111,18 @@ The training process includes an evaluation on the dev set after each training e
To conduct a single evaluation on the dev set with the model already trained, you can run the following command:

```
python run.py --evaluate --task zhidao
python run.py --evaluate --algo BIDAF
```

#### Prediction
You can predict answers for the samples in dev set and test set using the following command:
You can also predict answers for the samples in the files specified by `--test_files` using the following command:

```
python run.py --predict --task zhidao
python run.py --predict --algo BIDAF --test_files ../data/demo/search.dev.json
```

By default, the results are saved at `../data/results/` folder. You can change this by specifying `--result_dir DIR_PATH`.


## Copyright and License
Copyright 2017 Baidu.com, Inc. All Rights Reserved

Expand Down
32 changes: 17 additions & 15 deletions tensorflow/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,38 +29,40 @@ class BRCDataset(object):
"""
This module implements the APIs for loading and using baidu reading comprehension dataset
"""
def __init__(self, data_dir, task, max_p_num, max_p_len, max_q_len,
train=True, dev=True, test=True):
def __init__(self, max_p_num, max_p_len, max_q_len,
train_files=[], dev_files=[], test_files=[]):
self.logger = logging.getLogger("brc")
self.data_dir = data_dir
self.task = task
self.max_p_num = max_p_num
self.max_p_len = max_p_len
self.max_q_len = max_q_len

self.train_set = self._load_dataset(self.task + '.train') if train else None
if train:
self.train_set, self.dev_set, self.test_set = [], [], []
if train_files:
for train_file in train_files:
self.train_set += self._load_dataset(train_file, train=True)
self.logger.info('Train set size: {} questions.'.format(len(self.train_set)))

self.dev_set = self._load_dataset(self.task + '.dev') if dev else None
if dev:
if dev_files:
for dev_file in dev_files:
self.dev_set += self._load_dataset(dev_file)
self.logger.info('Dev set size: {} questions.'.format(len(self.dev_set)))

self.test_set = self._load_dataset(self.task + '.test') if test else None
if test:
if test_files:
for test_file in test_files:
self.test_set += self._load_dataset(test_file)
self.logger.info('Test set size: {} questions.'.format(len(self.test_set)))

def _load_dataset(self, prefix):
def _load_dataset(self, data_path, train=False):
"""
Loads the dataset
Args:
prefix: task + 'train/dev/test' indicating the data file to load
data_path: the data file to load
"""
with open(os.path.join(self.data_dir, prefix + '.json')) as fin:
with open(data_path) as fin:
data_set = []
for lidx, line in enumerate(fin):
sample = json.loads(line.strip())
if 'train' in prefix:
if train:
if len(sample['answer_spans']) == 0:
continue
if sample['answer_spans'][0][1] >= self.max_p_len:
Expand All @@ -73,7 +75,7 @@ def _load_dataset(self, prefix):

sample['passages'] = []
for d_idx, doc in enumerate(sample['documents']):
if 'train' in prefix:
if train:
most_related_para = doc['most_related_para']
sample['passages'].append(
{'passage_tokens': doc['segmented_paragraphs'][most_related_para],
Expand Down
52 changes: 30 additions & 22 deletions tensorflow/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ def parse_args():
help='evaluate the model on dev set')
parser.add_argument('--predict', action='store_true',
help='predict the answers for test set with trained model')
parser.add_argument('--task', type=str, default='both',
help='reading comprehension on search/zhidao/both dataset')
parser.add_argument('--gpu', type=str, default='0',
help='specify gpu device')

Expand Down Expand Up @@ -81,6 +79,14 @@ def parse_args():
help='max length of answer')

path_settings = parser.add_argument_group('path settings')
path_settings.add_argument('--train_files', nargs='+',
default=['../data/demo/search.train.json'],
help='list of files that contains preprocessed train data')
path_settings.add_argument('--dev_files', nargs='+',
default=['../data/demo/search.dev.json'],
help='list of files that contains preprocessed dev data')
path_settings.add_argument('--test_files', nargs='+', default=[],
help='list of files that contains preprocessed test data')
path_settings.add_argument('--brc_dir', default='../data/baidu',
help='the dir with preprocessed baidu reading comprehension data')
path_settings.add_argument('--vocab_dir', default='../data/vocab/',
Expand All @@ -101,17 +107,17 @@ def prepare(args):
checks data, creates the directories, prepare the vocabulary and embeddings
"""
logger = logging.getLogger("brc")
logger.info('Checking the data for {} task...'.format(args.task))
for suffix in ['train.json', 'dev.json', 'test.json']:
data_path = os.path.join(args.brc_dir, args.task + '.' + suffix)
logger.info('Checking the data files...')
for data_path in args.train_files + args.dev_files + args.test_files:
assert os.path.exists(data_path), '{} file does not exist.'.format(data_path)
logger.info('Preparing the directories...')
for dir_path in [args.vocab_dir, args.model_dir, args.result_dir, args.summary_dir]:
if not os.path.exists(dir_path):
os.makedirs(dir_path)

logger.info('Building vocabulary...')
brc_data = BRCDataset(args.brc_dir, args.task, args.max_p_num, args.max_p_len, args.max_q_len)
brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
args.train_files, args.dev_files, args.test_files)
vocab = Vocab(lower=True)
for word in brc_data.word_iter('train'):
vocab.add(word)
Expand All @@ -126,7 +132,7 @@ def prepare(args):
vocab.randomly_init_embeddings(args.embed_size)

logger.info('Saving vocab...')
with open(os.path.join(args.vocab_dir, args.task + '.' + 'vocab.data'), 'wb') as fout:
with open(os.path.join(args.vocab_dir, 'vocab.data'), 'wb') as fout:
pickle.dump(vocab, fout)

logger.info('Done with preparing!')
Expand All @@ -138,65 +144,67 @@ def train(args):
"""
logger = logging.getLogger("brc")
logger.info('Load data_set and vocab...')
with open(os.path.join(args.vocab_dir, args.task + '.' + 'vocab.data'), 'rb') as fin:
with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
vocab = pickle.load(fin)
brc_data = BRCDataset(args.brc_dir, args.task, args.max_p_num, args.max_p_len, args.max_q_len)
brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
args.train_files, args.dev_files)
logger.info('Converting text into ids...')
brc_data.convert_to_ids(vocab)
logger.info('Initialize the model...')
rc_model = RCModel(vocab, args)
logger.info('Training the model...')
rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.model_dir,
save_prefix=args.task + '.' + args.algo,
save_prefix=args.algo,
dropout_keep_prob=args.dropout_keep_prob)
logger.info('Done with model training!')


def evaluate(args):
"""
evaluate the trained model on dev set
evaluate the trained model on dev files
"""
logger = logging.getLogger("brc")
logger.info('Load data_set and vocab...')
with open(os.path.join(args.vocab_dir, args.task + '.' + 'vocab.data'), 'rb') as fin:
with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
vocab = pickle.load(fin)
brc_data = BRCDataset(args.brc_dir, args.task,
args.max_p_num, args.max_p_len, args.max_q_len, train=False, test=False)
assert len(args.dev_files) > 0, 'No dev files are provided.'
brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, dev_files=args.dev_files)
logger.info('Converting text into ids...')
brc_data.convert_to_ids(vocab)
logger.info('Restoring the model...')
rc_model = RCModel(vocab, args)
rc_model.restore(model_dir=args.model_dir, model_prefix=args.task + '.' + args.algo)
rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
logger.info('Evaluating the model on dev set...')
dev_batches = brc_data.gen_mini_batches('dev', args.batch_size,
pad_id=vocab.get_id(vocab.pad_token), shuffle=False)
dev_loss, dev_bleu_rouge = rc_model.evaluate(
dev_batches, result_dir=args.result_dir, result_prefix=args.task + '.dev.predicted')
dev_batches, result_dir=args.result_dir, result_prefix='dev.predicted')
logger.info('Loss on dev set: {}'.format(dev_loss))
logger.info('Result on dev set: {}'.format(dev_bleu_rouge))
logger.info('Predicted answers are saved to {}'.format(os.path.join(args.result_dir)))


def predict(args):
"""
predicts answers for test set
predicts answers for test files
"""
logger = logging.getLogger("brc")
logger.info('Load data_set and vocab...')
with open(os.path.join(args.vocab_dir, args.task + '.' + 'vocab.data'), 'rb') as fin:
with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
vocab = pickle.load(fin)
brc_data = BRCDataset(args.brc_dir, args.task,
args.max_p_num, args.max_p_len, args.max_q_len, train=False, dev=False)
assert len(args.test_files) > 0, 'No test files are provided.'
brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
test_files=args.test_files)
logger.info('Converting text into ids...')
brc_data.convert_to_ids(vocab)
logger.info('Restoring the model...')
rc_model = RCModel(vocab, args)
rc_model.restore(model_dir=args.model_dir, model_prefix=args.task + '.' + args.algo)
rc_model.restore(model_dir=args.model_dir, model_prefix=args.algo)
logger.info('Predicting answers for test set...')
test_batches = brc_data.gen_mini_batches('test', args.batch_size,
pad_id=vocab.get_id(vocab.pad_token), shuffle=False)
rc_model.evaluate(test_batches,
result_dir=args.result_dir, result_prefix=args.task + '.test.predicted')
result_dir=args.result_dir, result_prefix='test.predicted')


def run():
Expand Down

0 comments on commit e436c82

Please sign in to comment.