Commit

use the new fused AdamW from pytorch nightly, if available
karpathy committed Feb 3, 2023
1 parent 7d44bdf commit e170e40
Showing 1 changed file with 5 additions and 1 deletion.
model.py: 5 additions & 1 deletion
@@ -8,6 +8,7 @@
"""

import math
import inspect
from dataclasses import dataclass

import torch
@@ -307,7 +308,10 @@ def configure_optimizers(self, weight_decay, learning_rate, betas):
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
# new PyTorch nightly has a new 'fused' option for AdamW that is much faster
extra_args = dict(fused=True) if 'fused' in inspect.signature(torch.optim.AdamW).parameters else dict()
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)

return optimizer

@torch.no_grad()
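For context, a minimal standalone sketch of the feature-detection pattern this commit introduces: probe torch.optim.AdamW's signature with inspect.signature and pass fused=True only when the installed PyTorch accepts it. The toy model, learning rate, and betas below are illustrative stand-ins (nanoGPT builds optim_groups from its own decay/no_decay parameter split), and the CUDA check is an extra safety guard beyond what this commit does, since the fused kernels target GPU tensors.

import inspect
import torch

# Illustrative setup only; not part of the commit.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(128, 128).to(device)
optim_groups = [{"params": model.parameters(), "weight_decay": 0.1}]

# Older PyTorch releases do not accept a 'fused' keyword at all, so probe the
# AdamW constructor's signature rather than hard-coding the argument.
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
# Extra guard (not in the commit): the fused kernels are written for CUDA
# tensors, so only request them when the parameters live on the GPU.
use_fused = fused_available and device == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()

optimizer = torch.optim.AdamW(optim_groups, lr=6e-4, betas=(0.9, 0.95), **extra_args)
print(f"using fused AdamW: {use_fused}")

If 'fused' is absent (a pre-nightly PyTorch), extra_args stays empty and the call falls back to the standard AdamW constructor unchanged.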
