
Commit

Add type hints
marianne-omnius committed Jul 5, 2024
1 parent c5fca72 commit db24178
Showing 5 changed files with 65 additions and 54 deletions.
20 changes: 11 additions & 9 deletions schedulefree/adamw_schedulefree.py
@@ -3,8 +3,10 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+from typing import Union, Tuple, Optional
 import torch
 import torch.optim
+from torch.optim.optimizer import ParamsT
 import math
 
 class AdamWScheduleFree(torch.optim.Optimizer):
@@ -42,15 +44,15 @@ class AdamWScheduleFree(torch.optim.Optimizer):
             usage (default True if supported in your PyTorch version).
     """
     def __init__(self,
-                 params,
-                 lr=0.0025,
-                 betas=(0.9, 0.999),
-                 eps=1e-8,
-                 weight_decay=0,
-                 warmup_steps=0,
-                 r=0.0,
-                 weight_lr_power=2.0,
-                 foreach=hasattr(torch, "_foreach_mul_")
+                 params: ParamsT,
+                 lr: Union[float, torch.Tensor] = 0.0025,
+                 betas: Tuple[float, float] = (0.9, 0.999),
+                 eps: float = 1e-8,
+                 weight_decay: float = 0,
+                 warmup_steps: int = 0,
+                 r: float = 0.0,
+                 weight_lr_power: float = 2.0,
+                 foreach: Optional[bool] = hasattr(torch, "_foreach_mul_")
                  ):
 
         defaults = dict(lr=lr,
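For orientation, a minimal usage sketch of the newly annotated constructor (not part of this commit; it assumes the package exports AdamWScheduleFree at the top level and uses a toy model purely for illustration). It shows the two call styles the new params: ParamsT annotation covers: a plain iterable of tensors and a list of param-group dicts.

import torch
from schedulefree import AdamWScheduleFree  # assumed top-level export

model = torch.nn.Linear(10, 2)

# 1) A plain iterable of parameters satisfies ParamsT.
opt = AdamWScheduleFree(model.parameters(), lr=0.0025)

# 2) So does a list of param-group dicts with per-group settings.
opt = AdamWScheduleFree(
    [{"params": [model.weight], "weight_decay": 0.01},
     {"params": [model.bias], "weight_decay": 0.0}],
    lr=0.0025,
)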
21 changes: 12 additions & 9 deletions schedulefree/adamw_schedulefree_closure.py
@@ -3,7 +3,9 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+from typing import Union, Tuple, Optional
 import torch.optim
+from torch.optim.optimizer import ParamsT
 import math
 
 class AdamWScheduleFreeClosure(torch.optim.Optimizer):
@@ -39,15 +41,16 @@ class AdamWScheduleFreeClosure(torch.optim.Optimizer):
             Should be significantly faster, but will have higher peak memory
             usage (default True if supported in your PyTorch version).
     """
-    def __init__(self, params,
-                 lr=0.0025,
-                 betas=(0.9, 0.999),
-                 eps=1e-8,
-                 weight_decay=0,
-                 warmup_steps=0,
-                 r=0,
-                 weight_lr_power=2.0,
-                 foreach=hasattr(torch, "_foreach_mul_")
+    def __init__(self,
+                 params: ParamsT,
+                 lr: Union[float, torch.Tensor] = 0.0025,
+                 betas: Tuple[float, float] = (0.9, 0.999),
+                 eps: float = 1e-8,
+                 weight_decay: float = 0,
+                 warmup_steps: int = 0,
+                 r: float = 0,
+                 weight_lr_power: float = 2.0,
+                 foreach: Optional[bool] = hasattr(torch, "_foreach_mul_")
                  ):
         defaults = dict(lr=lr,
                         betas=betas,
24 changes: 13 additions & 11 deletions schedulefree/adamw_schedulefree_reference.py
@@ -3,10 +3,11 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+from typing import Union, Tuple, Optional
 import torch
 import torch.optim
+from torch.optim.optimizer import ParamsT
 import math
-import torch.distributed as dist
 
 class AdamWScheduleFreeReference(torch.optim.Optimizer):
     r"""
@@ -42,16 +43,17 @@ class AdamWScheduleFreeReference(torch.optim.Optimizer):
         decay_at_z (bool): Apply weight decay calculated at the z sequence
             instead of the y sequence (default False).
     """
-    def __init__(self, params,
-                 lr=0.0025,
-                 betas=(0.9, 0.999),
-                 eps=1e-8,
-                 weight_decay=0,
-                 warmup_steps=0,
-                 r=0,
-                 weight_lr_power=2,
-                 decay_at_z=False,
-                 foreach=False, # Ignored
+    def __init__(self,
+                 params: ParamsT,
+                 lr: Union[float, torch.Tensor] = 0.0025,
+                 betas: Tuple[float, float] = (0.9, 0.999),
+                 eps: float = 1e-8,
+                 weight_decay: float = 0,
+                 warmup_steps: int = 0,
+                 r: float = 0,
+                 weight_lr_power: float = 2,
+                 decay_at_z: bool = False,
+                 foreach: Optional[bool] = False, # Ignored
                  ):
 
         defaults = dict(lr=lr,
18 changes: 10 additions & 8 deletions schedulefree/sgd_schedulefree.py
@@ -3,8 +3,10 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+from typing import Union, Optional
 import torch
 import torch.optim
+from torch.optim.optimizer import ParamsT
 
 class SGDScheduleFree(torch.optim.Optimizer):
     r"""
@@ -38,14 +40,14 @@ class SGDScheduleFree(torch.optim.Optimizer):
             usage (default True if supported in your PyTorch version).
     """
     def __init__(self,
-                 params,
-                 lr=1.0,
-                 momentum=0.9,
-                 weight_decay=0,
-                 warmup_steps=0,
-                 r=0.0,
-                 weight_lr_power=2,
-                 foreach=hasattr(torch, "_foreach_mul_"),
+                 params: ParamsT,
+                 lr: Union[float, torch.Tensor] = 1.0,
+                 momentum: float = 0.9,
+                 weight_decay: float = 0,
+                 warmup_steps: int = 0,
+                 r: float = 0.0,
+                 weight_lr_power: float = 2,
+                 foreach: Optional[bool] = hasattr(torch, "_foreach_mul_"),
                  ):
         if lr < 0.0:
             raise ValueError("Invalid learning rate: {}".format(lr))
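A second sketch (again not from the diff): the widened lr: Union[float, torch.Tensor] annotation advertises that a scalar tensor may be passed as the learning rate, in the style of recent torch.optim optimizers. Whether a tensor lr is fully exercised at runtime depends on the optimizer internals, so treat this purely as an illustration of the signature.

import torch
from schedulefree import SGDScheduleFree  # assumed top-level export

model = torch.nn.Linear(10, 2)

opt_float = SGDScheduleFree(model.parameters(), lr=1.0)                  # plain float
opt_tensor = SGDScheduleFree(model.parameters(), lr=torch.tensor(1.0))   # 0-dim scalar tensor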
36 changes: 19 additions & 17 deletions schedulefree/sgd_schedulefree_closure.py
@@ -3,31 +3,33 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+from typing import Union, Optional
 import torch
 import torch.optim
+from torch.optim.optimizer import ParamsT
 
 class SGDScheduleFreeClosure(torch.optim.Optimizer):
     r"""
     Schedule-Free SGD
-    As the name suggests, no scheduler is needed with this optimizer.
+    As the name suggests, no scheduler is needed with this optimizer.
     To add warmup, rather than using a learning rate schedule you can just
     set the warmup_steps parameter.
-    This "closure" version requires that step be called with a closure, see
-    https://pytorch.org/docs/stable/optim.html#optimizer-step-closure.
+    This "closure" version requires that step be called with a closure, see
+    https://pytorch.org/docs/stable/optim.html#optimizer-step-closure.
     Arguments:
-        params (iterable):
-            Iterable of parameters to optimize or dicts defining
+        params (iterable):
+            Iterable of parameters to optimize or dicts defining
             parameter groups.
-        lr (float):
+        lr (float):
             Learning rate parameter (default 1.0)
         momentum (float): momentum factor, must be between 0 and 1 exclusive
             (default: 0.9)
-        weight_decay (float):
+        weight_decay (float):
             Weight decay, i.e. a L2 penalty (default: 0).
         warmup_steps (int): Enables a linear learning rate warmup (default 0).
-        r (float): Use polynomial weighting in the average
+        r (float): Use polynomial weighting in the average
            with power r (default 0).
         weight_lr_power (float): During warmup, the weights in the average will
             be equal to lr raised to this power. Set to 0 for no weighting
@@ -37,14 +39,14 @@ class SGDScheduleFreeClosure(torch.optim.Optimizer):
             usage (default True if supported in your PyTorch version).
     """
     def __init__(self,
-                 params,
-                 lr=1.0,
-                 weight_decay=0,
-                 momentum=0.9,
-                 warmup_steps=0,
-                 r=0.0,
-                 weight_lr_power=2.0,
-                 foreach=hasattr(torch, "_foreach_mul_")
+                 params: ParamsT,
+                 lr: Union[float, torch.Tensor] = 1.0,
+                 weight_decay: float = 0,
+                 momentum: float = 0.9,
+                 warmup_steps: int = 0,
+                 r: float = 0.0,
+                 weight_lr_power: float = 2.0,
+                 foreach: Optional[bool] = hasattr(torch, "_foreach_mul_"),
                  ):
         if lr < 0.0:
             raise ValueError("Invalid learning rate: {}".format(lr))
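Finally, a sketch of the closure-style step that the docstring above points to (see the linked PyTorch optimizer-step-closure docs). The model, data, and loss are placeholders; the point is only that step() receives a closure that re-evaluates the loss and calls backward().

import torch
from schedulefree import SGDScheduleFreeClosure  # assumed top-level export

model = torch.nn.Linear(10, 2)
opt = SGDScheduleFreeClosure(model.parameters(), lr=1.0)

x, y = torch.randn(8, 10), torch.randn(8, 2)

def closure():
    # Re-evaluate the loss so the optimizer can use it inside step().
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

opt.step(closure)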
