Improve this project to reach the scores reported in the original paper (#5) and make it more efficient and more readable.

major modifications:
- use one-hot encodings for the POS & NER features instead of separate trainable embeddings (see the sketch after this list); change dropout to 0.4 accordingly.
- if two raw words become identical after normalization, their pretrained embeddings are averaged.
- replace each invisible character ('\s') with a single space instead of collapsing runs of whitespace into one. This preserves more training examples, because merging multiple spaces into one misaligns the answer spans.
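
A rough sketch of the first two changes (illustrative only; the helper names are hypothetical and the project's actual normalization may differ):

```python
import torch

def one_hot_tags(tag_ids, num_tags):
    """Turn a (batch, seq_len) LongTensor of POS/NER tag ids into
    (batch, seq_len, num_tags) float features, which are concatenated to the
    document-encoder input in place of separate trainable tag embeddings."""
    one_hot = torch.zeros(tag_ids.size(0), tag_ids.size(1), num_tags)
    return one_hot.scatter_(2, tag_ids.unsqueeze(2), 1.0)

def average_normalized_duplicates(words, vectors):
    """Average the pretrained vectors of raw words that collapse to the same
    normalized form (simple lower-casing is used here as a stand-in)."""
    merged = {}
    for word, vec in zip(words, vectors):
        merged.setdefault(word.lower(), []).append(vec)
    return {w: torch.stack(vs).mean(0) for w, vs in merged.items()}
```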

improved efficiency:
- code in "prepro.py" is refactored and made more efficient: its running time drops from 515s to 172s (3x faster) on a machine with 8 i7 CPUs and 16GB RAM.
- the function "get_answer_index" is simplified and made much more readable.

other improvements:
- the id of each example is preserved for ease of debugging.
- the vocabularies of POS and NER tags are saved for ease of debugging.
- other small improvements to make the code more readable.
hitvoice committed Oct 25, 2017
1 parent 6886ee7 commit 14cd4c2
Showing 9 changed files with 2,476 additions and 3,299 deletions.
91 changes: 17 additions & 74 deletions README.md
@@ -1,28 +1,24 @@
DrQA
---
> Facebook has finally released its official implementation of the paper half a month after this project emerged. This project will stick to its original purpose: it will provide a **clean and minimized** implementation of the paper, so one can quickly read through the code to understand the model, easily make some modifications to test new ideas, or plug the necessary parts into a larger framework. If you plan to deploy this model in a more industrial environment, please refer to [facebookresearch/DrQA](https://github.com/facebookresearch/DrQA). If you would like to embed this model in a chatbot framework, please refer to [facebookresearch/ParlAI](https://github.com/facebookresearch/ParlAI/).

A pytorch implementation of the ACL 2017 paper [Reading Wikipedia to Answer Open-Domain Questions](http://www-cs.stanford.edu/people/danqi/papers/acl2017.pdf) (DrQA).

Reading comprehension is the task of producing an answer given a question and one or more pieces of evidence (usually natural-language paragraphs). Compared to question answering over knowledge bases, reading comprehension models are more flexible and have shown great potential for zero-shot learning.

[SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) is a reading comprehension benchmark in which there is only a single piece of evidence and the answer is guaranteed to be part of it. Since the publication of the SQuAD dataset, reading comprehension research has progressed rapidly and a number of strong models have emerged. DrQA is conceptually simpler than most of them yet still yields strong performance as a single model.

The motivation for this project is to offer a clean version of DrQA for the machine reading comprehension task, so one can quickly do some modifications and try out new ideas. Most of the model code is borrowed from [ParlAI](https://github.com/facebookresearch/ParlAI/). Click [here](#detailed-comparisons) to see the comparison with what's described in the original paper and with two "offical" projects ParlAI and DrQA.
The motivation for this project is to offer a clean version of DrQA for the machine reading comprehension task, so one can quickly do some modifications and try out new ideas. Click [here](#detailed-comparisons) to see the comparison with what's described in the original paper and with two "official" projects ParlAI and DrQA.

## Requirements
- python >=3.5
- pytorch 0.2.0 (please refer to [the previous version](https://github.com/hitvoice/DrQA/tree/bc0152c7ad69c56fda23f50adabd4355559b3a74) if you use pytorch 0.1.12)
- numpy
- pandas
- msgpack
- spacy 1.x

## Quick Start
### Setup
- download the project via `git clone https://github.com/hitvoice/DrQA.git; cd DrQA`
- make sure python 3 and pip is installed.
- make sure python 3, pip, wget and unzip are installed.
- install [pytorch](http://pytorch.org/) matching your OS, Python and CUDA versions.
- install the remaining requirements via `pip install -r requirements.txt`
- download the SQuAD datafile, GloVe word vectors and Spacy English language models using `bash download.sh`.
@@ -32,92 +28,39 @@ The motivation for this project is to offer a clean version of DrQA for the mach
```bash
# prepare the data
python prepro.py
# train for 20 epoches with batchsize 32
python train.py -e 20 -bs 32
# train for 40 epochs with batchsize 32
python train.py -e 40 -bs 32
```

## Results
### EM & F1
||EM|F1|
|---|---|---|
|in original paper|69.5|78.8|
|in this project|69.3|78.6|
|in the original paper|69.5|78.8|
|in this project|69.64|78.76|
|official (spaCy)|69.71|78.94|
|official (CoreNLP)|69.76|79.09|

Compared to the implementation in ParlAI:
Compared with the official implementation:

<img src="https://rawgit.com/hitvoice/DrQA/master/img/em.svg" width="500">

<img src="https://rawgit.com/hitvoice/DrQA/master/img/f1.svg" width="500">

The command to run the ParlAI implementation:
```bash
git clone https://github.com/facebookresearch/ParlAI.git ~/ParlAI
cd ~/ParlAI; python setup.py develop
python examples/train_model.py -m drqa -t squad -bs 32 -e 30 -vp 20 -dbf True --validation-every-n-secs 400 -mf /home/ubuntu/ParlAI/models --embedding_file /home/ubuntu/glove/glove.840B.300d.txt --embedding_dim 300 --fix_embeddings False --tune_partial 1000 --dropout_rnn 0.3 --dropout_emb 0.3 | tee output.log
```

### training time
The experiments were run on a machine with a single NVIDIA Tesla K80 GPU, 8 CPUs (2.3GHz) and 59GB of RAM.

||training time (seconds/epoch)|
|---|---|
|implementation in this project|770|
|implementation in ParlAI|850|

### related discussions
Here's what the paper says when introducing the embedding layer:
> We keep most of the pre-trained word embeddings fixed and only fine-tune the 1000 **most frequent question words** because the representations of some keywords such as *what*, *how*, *which*, *many* could be crucial for QA systems.

So what's the difference between the most frequent words overall and the most frequent question words? Here are the top 20 words of each:

||sort by all|sort by question|
|---|---|---|
|1|the|?|
|2|,|the|
|3|of|What|
|4|.|of|
|5|and|in|
|6|in|to|
|7|to|was|
|8|a|is|
|9|"|did|
|10|is|what|
|11|-|a|
|12|was|'s|
|13|The|Who|
|14|as|How|
|15|(|for|
|16|)|and|
|17|?|,|
|18|for|are|
|19|by|many|
|20|that|When|

The Venn diagram:

<img src="https://rawgit.com/hitvoice/DrQA/master/img/vocab.svg" width="500">

26% of the top 1000 words differ between the two vocabularies. When the 1000 most frequent question words are tuned instead of the 1000 most frequent words overall, an F1 boost of about 1.5% is observed.
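
To tune only the most frequent question words, it is enough to order the vocabulary so that question words come first. A minimal sketch of the idea (not the project's exact preprocessing code; helper and special-token names here are illustrative):

```python
from collections import Counter

def build_vocab(question_tokens, context_tokens):
    """question_tokens / context_tokens: lists of tokenized examples."""
    q_count = Counter(w for q in question_tokens for w in q)
    all_count = q_count + Counter(w for c in context_tokens for w in c)
    # words seen in questions come first, ordered by question frequency;
    # the rest follow, ordered by overall frequency
    words = sorted(all_count, key=lambda w: (-q_count[w], -all_count[w]))
    return ['<PAD>', '<UNK>'] + words
```

With this ordering, "fine-tune the top-k embeddings" automatically means "fine-tune the k most frequent question words".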

### Detailed Comparisons

Compared to what's described in the original paper:
- The grammatical features are generated by SpaCy instead of Stanford Core NLP. It's much faster (5 minutes vs 20 hours) but less accurate.
- The training samples are shuffled completely in each epoch. Performance degrades significantly when the samples are sorted by length, divided into mini-batches, and only the mini-batches are shuffled, as described in the paper (see the sketch after this list).
- The original paper does not make it clear whether POS and NER is a one-hot feature or has its own trainable embedding matrix. This implementation treats these two tags as discrete features with their own embedding matrices, which is found to be better in performance and makes the model more flexible.
- The grammatical features are generated by [spaCy](https://spacy.io) instead of [Stanford CoreNLP](https://stanfordnlp.github.io/CoreNLP/). It's much faster and produces similar scores.
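
For reference, a minimal sketch of the two batching schemes contrasted in the shuffling bullet above (hypothetical helper code, not taken from this project):

```python
import random

def fully_shuffled_batches(examples, batch_size):
    """Reshuffle all samples each epoch, as done in this project."""
    order = list(range(len(examples)))
    random.shuffle(order)
    for i in range(0, len(order), batch_size):
        yield [examples[j] for j in order[i:i + batch_size]]

def length_sorted_batches(examples, batch_size, length_fn=len):
    """Sort by length, cut into mini-batches, then shuffle only the
    mini-batches, as described in the paper."""
    order = sorted(range(len(examples)), key=lambda j: length_fn(examples[j]))
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    random.shuffle(batches)
    for batch in batches:
        yield [examples[j] for j in batch]
```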

Compared to the code in ParlAI:
- The DrQA model is not longer wrapped in a chatbot framework, which makes the code more readable, easier to modify and is faster to train. The preprocessing for text corpus is performed only once, while in a dialog framework raw text is transmitted each time and preprocessing for the same text must be done again and again.
- This is a full implementation of the original paper, while the model in ParlAI is a partial implementation, missing all grammatical features (lemma, POS tags and named entity tags).
- When tuning top-k embeddings, the model will tune the embeddings of top-k question words as the original paper states, while the word dictionary in ParlAI is sorted by the frequency of all words. This does make a difference (see the discussion above).
- Some minor bug fixes and enhancements. Some of them have been merged into ParlAI.

Compared to the code in facebookresearch/DrQA:
- This project is much more light-weighted, while lacking the document retriever, the inference and interactive inference API, the extendibility to other datasets and some other enhancements.
- The implementation in facebookresearch/DrQA tokenizes the dataset using a Java-coded Stanford CoreNLP, while in this project we use a faster and simpler Spacy.
- The implementation in facebookresearch/DrQA treats the POS and NER tags as one-hot features, this implementation treats these two tags as discrete features with their own embedding matrice.
Compared to the code in [facebookresearch/DrQA](https://github.com/facebookresearch/DrQA/):
- This project is much more lightweight and focuses solely on training and evaluating on the SQuAD dataset; it lacks the document retriever, the interactive inference API, and some other features.
- The implementation in facebookresearch/DrQA is able to train on multiple GPUs, while (currently, for simplicity) this implementation only supports single-GPU training.

Compared to the code in [facebookresearch/ParlAI](https://github.com/facebookresearch/ParlAI/):
- The DrQA model is no longer wrapped in a chatbot framework, which makes the code more readable, easier to modify, and faster to train. Preprocessing of the text corpus is performed only once, while in a dialog framework the raw text is transmitted each time and the same text has to be preprocessed over and over.
- This is a full implementation of the original paper, while the model in ParlAI is a partial implementation, missing all grammatical features (lemma, POS tags and named entity tags).
- Some minor bug fixes. Some of them have been merged into ParlAI.

### About
Maintainer: [Runqi Yang](https://hitvoice.github.io/about/).

11 changes: 9 additions & 2 deletions drqa/layers.py
@@ -85,7 +85,7 @@ def _forward_padded(self, x, x_mask):
"""Slower (significantly), but more precise,
encoding that handles padding."""
# Compute sorted sequence lengths
lengths = x_mask.data.eq(0).long().sum(1)
lengths = x_mask.data.eq(0).long().sum(1).squeeze()
_, idx_sort = torch.sort(lengths, dim=0, descending=True)
_, idx_unsort = torch.sort(idx_sort, dim=0)

@@ -130,6 +130,13 @@ def _forward_padded(self, x, x_mask):
output = output.transpose(0, 1)
output = output.index_select(0, idx_unsort)

# Pad up to original batch sequence length
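# (note: pad_packed_sequence only pads up to the longest sequence present
#  in this batch, so the RNN output can be narrower than x_mask)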
if output.size(1) != x_mask.size(1):
    padding = torch.zeros(output.size(0),
                          x_mask.size(1) - output.size(1),
                          output.size(2)).type(output.data.type())
    output = torch.cat([output, Variable(padding)], 1)

# Dropout on output layer
if self.dropout_output and self.dropout_rate > 0:
    output = F.dropout(output,
@@ -246,7 +253,7 @@ def uniform_weights(x, x_mask):
if x.data.is_cuda:
    alpha = alpha.cuda()
alpha = alpha * x_mask.eq(0).float()
alpha = alpha / alpha.sum(1, keepdim=True).expand(alpha.size())
alpha = alpha / alpha.sum(1).expand(alpha.size())
return alpha


16 changes: 5 additions & 11 deletions drqa/rnn_reader.py
@@ -26,7 +26,7 @@ def __init__(self, opt, padding_idx=0, embedding=None):
self.embedding = nn.Embedding(embedding.size(0),
                              embedding.size(1),
                              padding_idx=padding_idx)
self.embedding.weight.data = embedding
self.embedding.weight.data[2:, :] = embedding[2:, :]
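# (only rows from index 2 on receive pretrained vectors; rows 0 and 1,
#  presumably the padding and unknown-word entries, keep their initialization)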
if opt['fix_embeddings']:
    assert opt['tune_partial'] == 0
    for p in self.embedding.parameters():
@@ -40,10 +40,6 @@ def __init__(self, opt, padding_idx=0, embedding=None):
self.embedding = nn.Embedding(opt['vocab_size'],
                              opt['embedding_dim'],
                              padding_idx=padding_idx)
if opt['pos']:
    self.pos_embedding = nn.Embedding(opt['pos_size'], opt['pos_dim'])
if opt['ner']:
    self.ner_embedding = nn.Embedding(opt['ner_size'], opt['ner_dim'])
# Projection for attention weighted question
if opt['use_qemb']:
    self.qemb_match = layers.SeqAttnMatch(opt['embedding_dim'])
@@ -53,9 +49,9 @@ def __init__(self, opt, padding_idx=0, embedding=None):
if opt['use_qemb']:
    doc_input_size += opt['embedding_dim']
if opt['pos']:
    doc_input_size += opt['pos_dim']
    doc_input_size += opt['pos_size']
if opt['ner']:
    doc_input_size += opt['ner_dim']
    doc_input_size += opt['ner_size']

# RNN document encoder
self.doc_rnn = layers.StackedBRNN(
@@ -131,11 +127,9 @@ def forward(self, x1, x1_f, x1_pos, x1_ner, x1_mask, x2, x2_mask):
x2_weighted_emb = self.qemb_match(x1_emb, x2_emb, x2_mask)
drnn_input_list.append(x2_weighted_emb)
if self.opt['pos']:
    x1_pos_emb = self.pos_embedding(x1_pos)
    drnn_input_list.append(x1_pos_emb)
    drnn_input_list.append(x1_pos)
if self.opt['ner']:
    x1_ner_emb = self.ner_embedding(x1_ner)
    drnn_input_list.append(x1_ner_emb)
    drnn_input_list.append(x1_ner)
drnn_input = torch.cat(drnn_input_list, 2)
# Encode document with RNN
doc_hiddens = self.doc_rnn(drnn_input, x1_mask)
