Commit

Merge pull request huggingface#1162 from huggingface/xlnet-bias
XLNet bias fix on resize embeddings (cf huggingface#1124)
thomwolf authored Sep 2, 2019
2 parents 7f52243 + ea86bef commit 0287d26
Showing 1 changed file with 8 additions and 0 deletions.
pytorch_transformers/modeling_utils.py: 8 additions, 0 deletions
@@ -331,6 +331,14 @@ def _tie_or_clone_weights(self, first_module, second_module):
        else:
            first_module.weight = second_module.weight

        if hasattr(first_module, 'bias') and first_module.bias is not None:
            first_module.bias.data = torch.nn.functional.pad(
                first_module.bias.data,
                (0, first_module.weight.shape[0] - first_module.bias.shape[0]),
                'constant',
                0
            )

    def resize_token_embeddings(self, new_num_tokens=None):
        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
            Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
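For readers skimming the diff, here is a minimal standalone sketch of what the added padding does. The hidden size and vocabulary sizes are toy illustration values; only the torch.nn.functional.pad call mirrors the lines added above.

import torch

# Toy sizes for illustration only.
hidden_size, old_num_tokens, new_num_tokens = 8, 10, 12

lm_head = torch.nn.Linear(hidden_size, old_num_tokens, bias=True)
# Pretend the tied weight has already been resized to the new vocabulary size,
# while the bias still has the old length.
lm_head.weight = torch.nn.Parameter(torch.zeros(new_num_tokens, hidden_size))

# Same call as in _tie_or_clone_weights above: append zeros on the right so the
# bias length matches the first dimension of the (already resized) weight.
lm_head.bias.data = torch.nn.functional.pad(
    lm_head.bias.data,
    (0, lm_head.weight.shape[0] - lm_head.bias.shape[0]),
    'constant',
    0
)
assert lm_head.bias.shape[0] == new_num_tokens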

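A hedged end-to-end sketch of the situation the fix targets, assuming the standard pytorch_transformers API of this release; the checkpoint name and added tokens are illustrative choices, not taken from the commit or the referenced issue. Tying shares the resized embedding weight with XLNet's output layer (lm_loss), but its bias kept the old vocabulary length, which is the mismatch the padding above removes.

from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
model = XLNetLMHeadModel.from_pretrained('xlnet-base-cased')

# Grow the vocabulary, then resize the model's embeddings to match.
tokenizer.add_tokens(['<new_token_1>', '<new_token_2>'])
model.resize_token_embeddings(len(tokenizer))

# With the padding above, the tied output bias (XLNet's lm_loss layer) now has
# one entry per token in the enlarged vocabulary.
assert model.lm_loss.bias.shape[0] == len(tokenizer)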