fix bug... if topk > vocab_size, torch.topk will throw error
karpathy committed Jan 14, 2023
1 parent 57735f5 commit 91d0251
Showing 1 changed file with 1 addition and 1 deletion: model.py
@@ -277,7 +277,7 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
 logits = logits[:, -1, :] / temperature
 # optionally crop the logits to only the top k options
 if top_k is not None:
-    v, _ = torch.topk(logits, top_k)
+    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
     logits[logits < v[:, [-1]]] = -float('Inf')
 # apply softmax to convert logits to (normalized) probabilities
 probs = F.softmax(logits, dim=-1)
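The one-line fix clamps k to the number of logits, so torch.topk never receives a k larger than the vocabulary size. Below is a minimal, self-contained sketch (not part of the commit; vocab_size and top_k are illustrative values chosen so that top_k > vocab_size) reproducing the patched cropping step:

import torch
import torch.nn.functional as F

vocab_size = 10                       # illustrative tiny vocabulary
top_k = 50                            # deliberately larger than vocab_size
logits = torch.randn(1, vocab_size)   # (batch, vocab) logits for the last position

# Unpatched: torch.topk(logits, top_k) raises a RuntimeError here,
# since k=50 exceeds the size of the last dimension (10).
# Patched: k is clamped, so topk degenerates to "keep everything".
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
# v[:, [-1]] is the smallest retained logit; mask out everything below it
logits[logits < v[:, [-1]]] = -float('Inf')

probs = F.softmax(logits, dim=-1)                   # masked entries get probability 0
idx_next = torch.multinomial(probs, num_samples=1)  # sample one token id

Clamping rather than raising keeps generate() forgiving: asking for more candidates than exist simply falls back to sampling from the full distribution.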
