Adding Whole Word Masking
jacobdevlin-google committed May 31, 2019
1 parent d66a146 commit 0fce551
Showing 2 changed files with 96 additions and 19 deletions.
54 changes: 52 additions & 2 deletions README.md
@@ -1,10 +1,56 @@
# BERT

**\*\*\*\*\* New May 31st, 2019: Whole Word Masking Models \*\*\*\*\***

This is a release of several new models which are the result of an improvement
in the pre-processing code.

In the original pre-processing code, we randomly select WordPiece tokens to
mask. For example:

`Input Text: the man jumped up , put his basket on phil ##am ##mon ' s head`
`Original Masked Input: [MASK] man [MASK] up , put his [MASK] on phil
[MASK] ##mon ' s head`
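For context, the `phil ##am ##mon` split above is produced by the repository's WordPiece tokenizer. A minimal sketch of reproducing such a split with `tokenization.FullTokenizer` (the vocab path below is a placeholder for a downloaded BERT vocabulary):

```python
import tokenization  # tokenization.py from this repository

# Placeholder path: point this at the vocab.txt shipped with an uncased BERT
# checkpoint. Words outside the vocabulary, such as "philammon", are split
# into WordPieces ("phil", "##am", "##mon").
tokenizer = tokenization.FullTokenizer(
    vocab_file="uncased_L-24_H-1024_A-16/vocab.txt", do_lower_case=True)

tokens = tokenizer.tokenize(
    "the man jumped up, put his basket on philammon's head")
print(tokens)  # e.g. [..., "phil", "##am", "##mon", "'", "s", "head"]
```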

The new technique is called Whole Word Masking. In this case, we always mask
*all* of the tokens corresponding to a word at once. The overall masking
rate remains the same.

`Whole Word Masked Input: the man [MASK] up , put his basket on [MASK] [MASK]
[MASK] ' s head`

The training is identical -- we still predict each masked WordPiece token
independently. The improvement comes from the fact that the original prediction
task was too 'easy' for words that had been split into multiple WordPieces.

This can be enabled during data generation by passing the flag
`--do_whole_word_mask=True` to `create_pretraining_data.py`.
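As a rough illustration of the grouping (a simplified sketch, not the
repository's implementation; the real logic, including the unchanged
80%/10%/10% replacement rule, is in `create_pretraining_data.py`):

```python
import random


def whole_word_mask(tokens, masking_rate=0.15, seed=12345):
  """Simplified whole word masking: mask every WordPiece of a chosen word."""
  rng = random.Random(seed)

  # Group token indexes into whole words; a "##" continuation piece always
  # joins the previous word.
  words = []
  for i, token in enumerate(tokens):
    if token.startswith("##") and words:
      words[-1].append(i)
    else:
      words.append([i])

  num_to_mask = max(1, int(round(len(tokens) * masking_rate)))
  rng.shuffle(words)

  output, masked = list(tokens), 0
  for word in words:
    # Skip a word if masking all of its pieces would exceed the budget.
    if masked + len(word) > num_to_mask:
      continue
    for i in word:
      output[i] = "[MASK]"
    masked += len(word)
  return output


tokens = "the man jumped up , put his basket on phil ##am ##mon ' s head".split()
print(" ".join(whole_word_mask(tokens, masking_rate=0.25)))
```

The sketch always substitutes `[MASK]`; the actual data generation still keeps
or randomizes a fraction of the selected positions, exactly as before.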

Pre-trained models with Whole Word Masking are linked below. The data and
training were otherwise identical, and the models have identical structure and
vocab to the original models. We only include BERT-Large models. When using
these models, please make it clear in the paper that you are using the Whole
Word Masking variant of BERT-Large.

* **[`BERT-Large, Uncased (Whole Word Masking)`](https://storage.googleapis.com/bert_models/2019_05_30/wwm_uncased_L-24_H-1024_A-16.zip)**:
24-layer, 1024-hidden, 16-heads, 340M parameters

* **[`BERT-Large, Cased (Whole Word Masking)`](https://storage.googleapis.com/bert_models/2019_05_30/wwm_cased_L-24_H-1024_A-16.zip)**:
24-layer, 1024-hidden, 16-heads, 340M parameters

Model | SQuAD 1.1 F1/EM | MultiNLI Accuracy
---------------------------------------- | :-------------: | :----------------:
BERT-Large, Uncased (Original) | 91.0/84.3 | 86.05
BERT-Large, Uncased (Whole Word Masking) | 92.8/86.7 | 87.07
BERT-Large, Cased (Original) | 91.5/84.8 | 86.09
BERT-Large, Cased (Whole Word Masking) | 92.9/86.7 | 86.46

**\*\*\*\*\* New February 7th, 2019: TfHub Module \*\*\*\*\***

BERT has been uploaded to [TensorFlow Hub](https://tfhub.dev). See
`run_classifier_with_tfhub.py` for an example of how to use the TF Hub module,
or run an example in the browser on
[Colab](https://colab.sandbox.google.com/github/google-research/bert/blob/master/predicting_movie_reviews_with_bert_on_tf_hub.ipynb).
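For reference, a minimal sketch of loading the module with the TF 1.x
`tensorflow_hub` API, along the lines of `run_classifier_with_tfhub.py` (the
module URL and sequence length below are illustrative choices):

```python
import tensorflow as tf
import tensorflow_hub as hub

BERT_MODULE_URL = "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1"
MAX_SEQ_LENGTH = 128

# Integer inputs produced by the BERT tokenizer and padding logic.
input_ids = tf.placeholder(tf.int32, [None, MAX_SEQ_LENGTH], name="input_ids")
input_mask = tf.placeholder(tf.int32, [None, MAX_SEQ_LENGTH], name="input_mask")
segment_ids = tf.placeholder(tf.int32, [None, MAX_SEQ_LENGTH], name="segment_ids")

bert_module = hub.Module(BERT_MODULE_URL, trainable=True)
bert_outputs = bert_module(
    inputs=dict(
        input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids),
    signature="tokens",
    as_dict=True)

# [batch_size, hidden_size] sentence-level vector and
# [batch_size, seq_length, hidden_size] per-token outputs, respectively.
pooled_output = bert_outputs["pooled_output"]
sequence_output = bert_outputs["sequence_output"]
```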

**\*\*\*\*\* New November 23rd, 2018: Un-normalized multilingual model + Thai +
Mongolian \*\*\*\*\***
@@ -225,6 +271,10 @@ using your own script.)**

The links to the models are here (right-click, 'Save link as...' on the name):

* **[`BERT-Large, Uncased (Whole Word Masking)`](https://storage.googleapis.com/bert_models/2019_05_30/wwm_uncased_L-24_H-1024_A-16.zip)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Large, Cased (Whole Word Masking)`](https://storage.googleapis.com/bert_models/2019_05_30/wwm_cased_L-24_H-1024_A-16.zip)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Uncased`](https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip)**:
12-layer, 768-hidden, 12-heads, 110M parameters
* **[`BERT-Large, Uncased`](https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-24_H-1024_A-16.zip)**:
61 changes: 44 additions & 17 deletions create_pretraining_data.py
@@ -42,6 +42,10 @@
"Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")

flags.DEFINE_bool(
"do_whole_word_mask", False,
"Whether to use whole word masking rather than per-WordPiece masking.")

flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")

flags.DEFINE_integer("max_predictions_per_seq", 20,
@@ -343,7 +347,20 @@ def create_masked_lm_predictions(tokens, masked_lm_prob,
for (i, token) in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
continue
    # Whole Word Masking means that we mask all of the WordPieces
    # corresponding to an original word at once. When a word has been split
    # into WordPieces, the first token does not have any marker and any
    # subsequent tokens are prefixed with ##. So whenever we see a ## token,
    # we append it to the previous set of word indexes.
    #
    # Note that Whole Word Masking does *not* change the training code
    # at all -- we still predict each WordPiece independently, softmaxed
    # over the entire vocabulary.
    if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and
        token.startswith("##")):
      cand_indexes[-1].append(i)
    else:
      cand_indexes.append([i])
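    # For example, with --do_whole_word_mask=True the token sequence
    #   ["[CLS]", "phil", "##am", "##mon", "'", "s", "head", "[SEP]"]
    # yields cand_indexes of [[1, 2, 3], [4], [5], [6]], so the three pieces
    # of "philammon" are always selected (or skipped) together. Without the
    # flag, every index stays in its own single-element list.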

rng.shuffle(cand_indexes)

@@ -354,29 +371,39 @@ def create_masked_lm_predictions(tokens, masked_lm_prob,

masked_lms = []
covered_indexes = set()
  for index_set in cand_indexes:
    if len(masked_lms) >= num_to_predict:
      break
    # If adding a whole-word mask would exceed the maximum number of
    # predictions, then just skip this candidate.
    if len(masked_lms) + len(index_set) > num_to_predict:
      continue
    is_any_index_covered = False
    for index in index_set:
      if index in covered_indexes:
        is_any_index_covered = True
        break
    if is_any_index_covered:
      continue
    for index in index_set:
      covered_indexes.add(index)

      masked_token = None
      # 80% of the time, replace with [MASK]
      if rng.random() < 0.8:
        masked_token = "[MASK]"
      else:
        # 10% of the time, keep original
        if rng.random() < 0.5:
          masked_token = tokens[index]
        # 10% of the time, replace with random word
        else:
          masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]

      output_tokens[index] = masked_token

      masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
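      # Net effect per selected position: [MASK] with probability 0.8, kept
      # unchanged with probability 0.2 * 0.5 = 0.1, and replaced by a random
      # vocabulary token with probability 0.2 * 0.5 = 0.1 (the same 80/10/10
      # scheme as the original per-WordPiece masking).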
assert len(masked_lms) <= num_to_predict
masked_lms = sorted(masked_lms, key=lambda x: x.index)

masked_lm_positions = []
