Skip to content

Commit

Permalink
inference time mini-optimization low-hanging fruit ty @jxtps for raising: when we are running inference we can apply lm_head on only the very last token
Browse files Browse the repository at this point in the history
  • Loading branch information
karpathy committed Jan 12, 2023
1 parent e21cbf8 commit 8f85b83
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
9 changes: 6 additions & 3 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,15 @@ def forward(self, idx, targets=None):
for block in self.transformer.h:
x = block(x)
x = self.transformer.ln_f(x)
logits = self.lm_head(x)

# if we are given some desired targets also calculate the loss
loss = None
if targets is not None:
# if we are given some desired targets also calculate the loss
logits = self.lm_head(x)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
else:
# inference-time mini-optimization: only forward the lm_head on the very last position
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
loss = None

return logits, loss

Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def get_batch(split):
iter_num = 0
best_val_loss = 1e9

# model init
# model init. TODO: fix bug we should also propagate the correct vocab_size to the model_args
model_args = dict(n_layer = n_layer, n_head = n_head, n_embd = n_embd, block_size = block_size, dropout = dropout)
if init_from == 'scratch':
# init a new model from scratch
Expand Down

0 comments on commit 8f85b83

Please sign in to comment.