Skip to content

Commit

Permalink
Add RAdamScheduleFree and RAdamScheduleFreeClosure
Browse files Browse the repository at this point in the history
  • Loading branch information
nhamanasu committed Nov 28, 2024
1 parent 0daa68b commit bc573d3
Show file tree
Hide file tree
Showing 5 changed files with 516 additions and 2 deletions.
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@ Authors: Aaron Defazio, Xingyu (Alice) Yang, Harsh Mehta, Konstantin Mishchenko,

``` pip install schedulefree ```

Primary implementations are `SGDScheduleFree` and `AdamWScheduleFree`. We also have `AdamWScheduleFreeReference` and `SGDScheduleFreeReference` versions which have a simplified implementation, but which use more memory. To combine with other optimizers, use the experimental ScheduleFreeWrapper version.
We provide several Schedule-Free optimizer implementations:
- `SGDScheduleFree` and `SGDScheduleFreeReference`: Schedule-free variants of SGD
- `AdamWScheduleFree` and `AdamWScheduleFreeReference`: Schedule-free variants of AdamW
- `RAdamScheduleFree`: Schedule-free variant of RAdam, which eliminates the need for both learning rate scheduling and warmup setting
- Experimental `ScheduleFreeWrapper` to combine with other optimizers

`ScheduleFreeReference` versions have a simplified implementation, but which use more memory. There are also `ScheduleFreeClosure` versions which can be used with PyTorch's optimizer step closures.

A [Jax implementation](https://optax.readthedocs.io/en/latest/api/contrib.html#schedule-free) is availiable as part of Optax.

Expand Down
2 changes: 2 additions & 0 deletions schedulefree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from .radam_schedulefree_closure import RAdamScheduleFreeClosure
from .radam_schedulefree import RAdamScheduleFree
from .adamw_schedulefree_closure import AdamWScheduleFreeClosure
from .adamw_schedulefree import AdamWScheduleFree
from .adamw_schedulefree_reference import AdamWScheduleFreeReference
Expand Down
235 changes: 235 additions & 0 deletions schedulefree/radam_schedulefree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple, Union, Optional, Iterable, Dict, Callable, Any
from typing_extensions import TypeAlias
import torch
import torch.optim
try:
from torch.optim.optimizer import ParamsT
except ImportError:
ParamsT : TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]
import math

class RAdamScheduleFree(torch.optim.Optimizer):
r"""
Schedule-Free RAdam
Neither warmup hyperparameter nor scheduler is needed with this optimizer.
This optimizer requires that .train() and .eval() be called before the
beginning of training and evaluation respectively. The optimizer should
also be placed in eval mode when saving checkpoints.
Arguments:
params (iterable):
Iterable of parameters to optimize or dicts defining
parameter groups.
lr (float):
Learning rate parameter (default 0.0025)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999)).
eps (float):
Term added to the denominator outside of the root operation to
improve numerical stability. (default: 1e-8).
weight_decay (float):
Weight decay, i.e. a L2 penalty (default: 0).
r (float): Use polynomial weighting in the average
with power r (default 0).
weight_lr_power (float): During warmup, the weights in the average will
be equal to lr raised to this power. Set to 0 for no weighting
(default 2.0).
foreach (bool): Use a foreach-backed implementation of the optimizer.
Should be significantly faster, but will have higher peak memory
usage (default True if supported in your PyTorch version).
silent_sgd_phase (bool): If True, the optimizer will not use the first SGD phase of RAdam.
This means that the optimizer will not update model parameters during the early training
steps (e.g., < 5 when β_2 = 0.999), but just update the momentum values of the optimizer.
This helps stabilize training by ensuring smoother warmup behavior and more reliable
calculation of the moving average coefficient (`ckp1`). Recommended to set to True
(default True).
"""

def __init__(self,
params: ParamsT,
lr: Union[float, torch.Tensor] = 0.0025,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0,
r: float = 0.0,
weight_lr_power: float = 2.0,
foreach: Optional[bool] = hasattr(torch, "_foreach_mul_"),
silent_sgd_phase: bool = True
):

defaults = dict(lr=lr,
betas=betas,
eps=eps,
r=r,
k=0,
train_mode=False,
weight_sum=0.0,
lr_max=-1.0,
scheduled_lr=0.0,
weight_lr_power=weight_lr_power,
weight_decay=weight_decay,
foreach=foreach,
silent_sgd_phase=silent_sgd_phase)
super().__init__(params, defaults)

@torch.no_grad()
def eval(self):
for group in self.param_groups:
train_mode = group["train_mode"]
beta1, _ = group["betas"]
if train_mode:
for p in group["params"]:
state = self.state[p]
if "z" in state:
# Set p to x
p.lerp_(end=state["z"].to(p.device), weight=1 - 1 / beta1)
group["train_mode"] = False

@torch.no_grad()
def train(self):
for group in self.param_groups:
train_mode = group["train_mode"]
beta1, _ = group["betas"]
if not train_mode:
for p in group["params"]:
state = self.state[p]
if "z" in state:
# Set p to y
p.lerp_(end=state["z"].to(p.device), weight=1 - beta1)
group["train_mode"] = True

@torch.no_grad()
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
if not self.param_groups[0]["train_mode"]:
raise Exception(
"Optimizer was not in train mode when step is called. "
"Please insert .train() and .eval() calls on the "
"optimizer. See documentation for details."
)

loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
eps = group["eps"]
beta1, beta2 = group["betas"]
decay = group["weight_decay"]
silent_sgd_phase = group["silent_sgd_phase"]
k = group["k"] # current steps
step = k + 1
r = group['r']
weight_lr_power = group['weight_lr_power']

beta2_t = beta2**step
bias_correction2 = 1 - beta2_t

# maximum length of the approximated SMA
rho_inf = 2 / (1 - beta2) - 1
# compute the length of the approximated SMA
rho_t = rho_inf - 2 * step * beta2_t / bias_correction2
rect = (
((rho_t - 4) * (rho_t - 2) * rho_inf / ((rho_inf - 4) * (rho_inf - 2) * rho_t)) ** 0.5
if rho_t > 4.0
else float(not silent_sgd_phase)
)

lr = group["lr"] * rect
group["scheduled_lr"] = lr # For logging purposes

lr_max = group["lr_max"] = max(lr, group["lr_max"])

weight = (step**r) * (lr_max**weight_lr_power)
weight_sum = group["weight_sum"] = group["weight_sum"] + weight

try:
ckp1 = weight / weight_sum
except ZeroDivisionError:
ckp1 = 0

adaptive_y_lr = lr * (beta1 * (1 - ckp1) - 1)
active_p = [p for p in group["params"] if p.grad is not None]

for p in active_p:
if "z" not in self.state[p]:
self.state[p]["z"] = torch.clone(p, memory_format=torch.preserve_format)
self.state[p]["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

if group["foreach"] and len(active_p) > 0:
y, grad, exp_avg_sq, z = zip(
*[(p, p.grad, self.state[p]["exp_avg_sq"], self.state[p]["z"]) for p in active_p]
)

# Decay the first and second moment running average coefficient
torch._foreach_mul_(exp_avg_sq, beta2)
torch._foreach_addcmul_(exp_avg_sq, grad, grad, value=1 - beta2)

if rho_t > 4.0:
# Adam step
denom = torch._foreach_div(exp_avg_sq, bias_correction2)
torch._foreach_sqrt_(denom)
torch._foreach_add_(denom, eps)

# Normalize grad in-place for memory efficiency
torch._foreach_div_(grad, denom)

# Weight decay calculated at y
if decay != 0:
torch._foreach_add_(grad, y, alpha=decay)

# These operations update y in-place,
# without computing x explicitly.
torch._foreach_lerp_(y, z, weight=ckp1)
torch._foreach_add_(y, grad, alpha=adaptive_y_lr)

# z step
torch._foreach_sub_(z, grad, alpha=lr)
else:
for p in active_p:
y = p # Notation to match theory
grad = p.grad

state = self.state[p]

z = state["z"]
exp_avg_sq = state["exp_avg_sq"]

exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

if rho_t > 4.0:
# Adam step
denom = exp_avg_sq.div(bias_correction2).sqrt_().add_(eps)

# Reuse grad buffer for memory efficiency
grad_normalized = grad.div_(denom)
else:
# Fall back to SGD (or nothing)
grad_normalized = grad

# Weight decay calculated at y
if decay != 0:
grad_normalized.add_(y, alpha=decay)

# These operations update y in-place,
# without computing x explicitly.
y.lerp_(end=z, weight=ckp1)
y.add_(grad_normalized, alpha=adaptive_y_lr)

# z step
z.sub_(grad_normalized, alpha=lr)

group["k"] = k + 1
return loss
Loading

0 comments on commit bc573d3

Please sign in to comment.