Skip to content

Commit

Permalink
RoBERTa embeddings are no longer a type of BERT embeddings (allenai#4771
Browse files Browse the repository at this point in the history
)

* RoBERTa embeddings are no longer a type of BERT embeddings

* Changelog
  • Loading branch information
dirkgr authored Nov 7, 2020
1 parent 23f0a8a commit 92a844a
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
Previously, we would compute gradients from the top of the transformer, after aggregation from
wordpieces to tokens, which gives results that are not very informative. Now, we compute gradients
with respect to the embedding layer, and aggregate wordpieces to tokens separately.
- Fixed the heuristics for finding embedding layers in the case of RoBERTa. An update in the
`transformers` library broke our old heuristic.


## [v1.2.0](https://github.com/allenai/allennlp/releases/tag/v1.2.0) - 2020-10-29
Expand Down
3 changes: 3 additions & 0 deletions allennlp/nn/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1772,6 +1772,7 @@ def find_embedding_layer(model: torch.nn.Module) -> torch.nn.Module:
from transformers.modeling_gpt2 import GPT2Model
from transformers.modeling_bert import BertEmbeddings
from transformers.modeling_albert import AlbertEmbeddings
from transformers.modeling_roberta import RobertaEmbeddings
from allennlp.modules.text_field_embedders.text_field_embedder import TextFieldEmbedder
from allennlp.modules.text_field_embedders.basic_text_field_embedder import (
BasicTextFieldEmbedder,
Expand All @@ -1781,6 +1782,8 @@ def find_embedding_layer(model: torch.nn.Module) -> torch.nn.Module:
for module in model.modules():
if isinstance(module, BertEmbeddings):
return module.word_embeddings
if isinstance(module, RobertaEmbeddings):
return module.word_embeddings
if isinstance(module, AlbertEmbeddings):
return module.word_embeddings
if isinstance(module, GPT2Model):
Expand Down

0 comments on commit 92a844a

Please sign in to comment.