
Commit

tie the weights of lm_head.weight and transformer.wte.weight, i.e. the last linear layer of the decoder and the token embeddings.
karpathy committed Jan 14, 2023
1 parent 32b4f08 commit 7c82885
Showing 1 changed file with 13 additions and 3 deletions.
model.py: 16 changes (13 additions & 3 deletions)
@@ -115,9 +115,10 @@ def __init__(self, config):
             ln_f = nn.LayerNorm(config.n_embd),
         ))
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        self.lm_head.weight = self.transformer.wte.weight # https://paperswithcode.com/method/weight-tying
 
-        # report number of parameters (note we don't count the decoder parameters in lm_head)
-        n_params = sum(p.numel() for p in self.transformer.parameters())
+        # report number of parameters
+        n_params = sum(p.numel() for p in self.parameters())
         print("number of parameters: %.2fM" % (n_params/1e6,))
 
     def forward(self, idx, targets=None):
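
A minimal sketch (not part of the commit; assumes PyTorch and hypothetical sizes) of what the tying and the new parameter count do: after the assignment, lm_head.weight and transformer.wte.weight are the same Parameter object, and parameters() deduplicates shared tensors, so the old transformer-only count is no longer needed to avoid double counting.

    import torch.nn as nn

    n_embd, vocab_size = 8, 50257  # hypothetical sizes for illustration
    wte = nn.Embedding(vocab_size, n_embd)
    lm_head = nn.Linear(n_embd, vocab_size, bias=False)
    lm_head.weight = wte.weight  # weight tying: both names now refer to one Parameter

    assert lm_head.weight is wte.weight  # same object, not a copy
    # parameters() deduplicates shared tensors, so the tied matrix is counted once
    tied = nn.ModuleDict(dict(wte=wte, lm_head=lm_head))
    n_params = sum(p.numel() for p in tied.parameters())
    assert n_params == vocab_size * n_embd
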
@@ -156,8 +157,9 @@ def crop_block_size(self, block_size):
             block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
 
     @classmethod
-    def from_pretrained(cls, model_type, override_args):
+    def from_pretrained(cls, model_type, override_args=None):
         assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
+        override_args = override_args or {} # default to empty dict
         # only dropout can be overridden see more notes below
         assert all(k == 'dropout' for k in override_args)
         from transformers import GPT2LMHeadModel
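
Making override_args optional follows the usual pattern of defaulting to None and normalizing it inside the function instead of using a mutable default argument. A usage sketch, assuming this repo's model.py is importable and the transformers package is installed:

    from model import GPT

    model = GPT.from_pretrained('gpt2')                     # override_args now defaults to {}
    model = GPT.from_pretrained('gpt2', dict(dropout=0.0))  # only 'dropout' may be overridden
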
@@ -235,6 +237,14 @@ def configure_optimizers(self, weight_decay, learning_rate, betas):
                     # weights of blacklist modules will NOT be weight decayed
                     no_decay.add(fpn)
 
+        # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they
+        # will appear in the no_decay and decay sets respectively after the above.
+        # In addition, because named_parameters() doesn't return duplicates, it
+        # will only return the first occurrence, keyed by 'transformer.wte.weight', below.
+        # so let's manually remove 'lm_head.weight' from the decay set. This way the tied
+        # tensor is optimized via transformer.wte.weight only, and is not weight decayed.
+        decay.remove('lm_head.weight')
+
         # validate that we considered every parameter
         param_dict = {pn: p for pn, p in self.named_parameters()}
         inter_params = decay & no_decay
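
A small sketch (not from the repo; toy sizes) of the situation that comment describes: the module walk above files the tied tensor under both names, but named_parameters() deduplicates shared tensors and keeps only the first name, so dropping the lm_head alias from decay makes the two views agree.

    import torch.nn as nn

    class Toy(nn.Module):  # toy module with a tied embedding / head, hypothetical sizes
        def __init__(self):
            super().__init__()
            self.transformer = nn.ModuleDict(dict(wte=nn.Embedding(10, 4)))
            self.lm_head = nn.Linear(4, 10, bias=False)
            self.lm_head.weight = self.transformer.wte.weight  # weight tying

    m = Toy()
    # named_parameters() skips the duplicate and keeps only the first name
    assert [n for n, _ in m.named_parameters()] == ['transformer.wte.weight']

    # the loops above would put the Embedding name in no_decay and the Linear alias in decay;
    # removing the alias leaves exactly the names that named_parameters() reports
    decay, no_decay = {'lm_head.weight'}, {'transformer.wte.weight'}
    decay.remove('lm_head.weight')
    assert set(pn for pn, _ in m.named_parameters()) == decay | no_decay
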
