diff --git a/model.py b/model.py
index a10c850058..f934ef198d 100644
--- a/model.py
+++ b/model.py
@@ -8,6 +8,7 @@
 """

 import math
+import inspect
 from dataclasses import dataclass

 import torch
@@ -307,7 +308,10 @@ def configure_optimizers(self, weight_decay, learning_rate, betas):
             {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
             {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
         ]
-        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
+        # new PyTorch nightly has a new 'fused' option for AdamW that is much faster
+        extra_args = dict(fused=True) if 'fused' in inspect.signature(torch.optim.AdamW).parameters else dict()
+        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
+
         return optimizer

     @torch.no_grad()
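
For context, below is a minimal standalone sketch of the feature-detection idea the patch uses: inspect the constructor of torch.optim.AdamW for a 'fused' parameter and pass fused=True only when it is present. The helper name make_adamw and the hyperparameter values are illustrative assumptions, not part of the repo; the extra on-GPU guard is also an assumption, since the fused kernel targets CUDA tensors.

import inspect

import torch


def make_adamw(params, lr=6e-4, betas=(0.9, 0.95), weight_decay=0.1):
    # hypothetical helper, not from model.py; shows the same detection pattern
    params = list(params)
    # 'fused' appears in the AdamW signature only on newer PyTorch builds,
    # so inspect the constructor instead of parsing version strings
    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    # assumption: only enable the fused path when all params are CUDA tensors
    use_fused = fused_available and all(p.is_cuda for p in params)
    extra_args = dict(fused=True) if use_fused else dict()
    return torch.optim.AdamW(params, lr=lr, betas=betas, weight_decay=weight_decay, **extra_args)


if __name__ == "__main__":
    # toy parameter just to exercise the helper; in model.py this would be optim_groups
    p = torch.nn.Parameter(torch.randn(4, 4))
    opt = make_adamw([p])
    print(type(opt).__name__, opt.defaults.get('fused', 'not supported'))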