diff --git a/model.py b/model.py
index bf178298b9..f18996c6be 100644
--- a/model.py
+++ b/model.py
@@ -277,7 +277,7 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
             logits = logits[:, -1, :] / temperature
             # optionally crop the logits to only the top k options
             if top_k is not None:
-                v, _ = torch.topk(logits, top_k)
+                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                 logits[logits < v[:, [-1]]] = -float('Inf')
             # apply softmax to convert logits to (normalized) probabilities
             probs = F.softmax(logits, dim=-1)
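
Note (not part of the patch): a minimal standalone sketch of why the clamp matters. torch.topk raises a RuntimeError when k exceeds the size of the last dimension, so min(top_k, logits.size(-1)) makes an oversized top_k safe. The batch size, vocab size, and top_k values below are toy numbers chosen for illustration.

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 50)          # (batch, vocab_size) -- toy values
    top_k = 200                          # deliberately larger than the vocab size

    k = min(top_k, logits.size(-1))      # clamp, as in the patched line
    v, _ = torch.topk(logits, k)         # k-th largest value per row is v[:, -1]
    logits[logits < v[:, [-1]]] = -float('Inf')  # mask everything below the k-th value

    probs = F.softmax(logits, dim=-1)    # renormalize over the surviving logits
    idx_next = torch.multinomial(probs, num_samples=1)

Without the clamp, the same code with top_k = 200 fails at the torch.topk call; with it, top-k sampling degrades gracefully to sampling over the full vocabulary.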