Commit

Fix unintended deprecation warning in torch.distributed.optim (pytorch#140889)

We have a module-level deprecation warning for the scripted functional optimizers in `torch/distributed/optim/__init__.py`. However, not all optimizers exposed by the module are scripted functional optimizers, so the warning fires spuriously for users of the non-deprecated ones (e.g. pytorch#139661).

This PR moves the deprecation warning to the `__init__` functions of the deprecated scripted functional optimizers.
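
For illustration (not part of the commit), a minimal sketch of the intended behavior, assuming a PyTorch build with `torch.distributed` available and `_FunctionalAdam` still exported from `torch.distributed.optim`: importing the module alone stays silent, while constructing a deprecated scripted functional optimizer still warns.

```python
import warnings

import torch
import torch.distributed.optim as dist_optim  # import alone no longer triggers the warning

with warnings.catch_warnings(record=True) as captured:
    warnings.simplefilter("always")
    # Constructing a deprecated scripted functional optimizer still warns.
    dist_optim._FunctionalAdam([torch.zeros(1, requires_grad=True)])

assert any(issubclass(w.category, DeprecationWarning) for w in captured)
```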

Pull Request resolved: pytorch#140889
Approved by: https://github.com/d4l3k, https://github.com/kwen2501, https://github.com/XilunWu
yifuwang authored and pytorchmergebot committed Nov 18, 2024
1 parent 137554c commit 3d26c08
Showing 10 changed files with 48 additions and 10 deletions.
10 changes: 0 additions & 10 deletions torch/distributed/optim/__init__.py
@@ -26,16 +26,6 @@
 from .utils import as_functional_optim
 
 
-with warnings.catch_warnings():
-    warnings.simplefilter("always")
-    warnings.warn(
-        "`TorchScript` support for functional optimizers is deprecated "
-        "and will be removed in a future PyTorch release. "
-        "Consider using the `torch.compile` optimizer instead.",
-        DeprecationWarning,
-        stacklevel=2,
-    )
-
 # DistributedOptimizer imports torch.distributed.rpc names, so gate availability
 # based on RPC being available.
 if hasattr(torch._C, "_rpc_init"):
16 changes: 16 additions & 0 deletions torch/distributed/optim/_deprecation_warning.py
@@ -0,0 +1,16 @@
+import warnings
+
+import torch
+
+
+@torch.jit.ignore  # type: ignore[misc]
+def _scripted_functional_optimizer_deprecation_warning(stacklevel: int = 0) -> None:
+    with warnings.catch_warnings():
+        warnings.simplefilter("always")
+        warnings.warn(
+            "`TorchScript` support for functional optimizers is deprecated "
+            "and will be removed in a future PyTorch release. "
+            "Consider using the `torch.compile` optimizer instead.",
+            DeprecationWarning,
+            stacklevel=stacklevel + 2,
+        )
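
For context (not part of the commit): the deprecated functional optimizers are scripted with TorchScript, which is why the helper above carries `@torch.jit.ignore`; the decorator keeps the `warnings.catch_warnings()` machinery out of TorchScript compilation while still running it in eager Python. A minimal sketch of that pattern, using hypothetical names (`_emit_warning`, `_Toy`):

```python
import warnings

import torch


@torch.jit.ignore
def _emit_warning() -> None:
    # Left as a plain Python call; TorchScript does not compile this body.
    warnings.warn("deprecated", DeprecationWarning, stacklevel=3)


@torch.jit.script
class _Toy:  # hypothetical stand-in for a scripted functional optimizer
    def __init__(self, lr: float = 0.1):
        _emit_warning()  # warns at construction time, not at module import time
        self.lr = lr


_Toy()  # emits a DeprecationWarning (subject to the active warning filters)
```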
4 changes: 4 additions & 0 deletions torch/distributed/optim/functional_adadelta.py
@@ -4,6 +4,9 @@
 import torch
 import torch.optim._functional as F
 from torch import Tensor
+from torch.distributed.optim._deprecation_warning import (
+    _scripted_functional_optimizer_deprecation_warning,
+)
 
 
 __all__: List[str] = []
@@ -31,6 +34,7 @@ def __init__(
         maximize: bool = False,
         _allow_empty_param_list: bool = False,
     ):
+        _scripted_functional_optimizer_deprecation_warning(stacklevel=2)
         self.defaults = {
             "lr": lr,
             "rho": rho,
4 changes: 4 additions & 0 deletions torch/distributed/optim/functional_adagrad.py
@@ -4,6 +4,9 @@
 import torch
 import torch.optim._functional as F
 from torch import Tensor
+from torch.distributed.optim._deprecation_warning import (
+    _scripted_functional_optimizer_deprecation_warning,
+)
 
 
 __all__: List[str] = []
@@ -36,6 +39,7 @@ def __init__(
         maximize: bool = False,
         _allow_empty_param_list: bool = False,
     ):
+        _scripted_functional_optimizer_deprecation_warning(stacklevel=2)
         self.defaults = {
             "lr": lr,
             "lr_decay": lr_decay,
4 changes: 4 additions & 0 deletions torch/distributed/optim/functional_adam.py
@@ -4,6 +4,9 @@
 import torch
 import torch.optim._functional as F
 from torch import Tensor
+from torch.distributed.optim._deprecation_warning import (
+    _scripted_functional_optimizer_deprecation_warning,
+)
 
 
 __all__: List[str] = []
@@ -33,6 +36,7 @@ def __init__(
         fused: bool = False,
         _allow_empty_param_list: bool = False,
     ):
+        _scripted_functional_optimizer_deprecation_warning(stacklevel=2)
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= eps:
4 changes: 4 additions & 0 deletions torch/distributed/optim/functional_adamax.py
@@ -4,6 +4,9 @@
 import torch
 import torch.optim._functional as F
 from torch import Tensor
+from torch.distributed.optim._deprecation_warning import (
+    _scripted_functional_optimizer_deprecation_warning,
+)
 
 
 __all__: List[str] = []
@@ -31,6 +34,7 @@ def __init__(
         maximize: bool = False,
         _allow_empty_param_list: bool = False,
     ):
+        _scripted_functional_optimizer_deprecation_warning(stacklevel=2)
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= eps:
4 changes: 4 additions & 0 deletions torch/distributed/optim/functional_adamw.py
@@ -4,6 +4,9 @@
 import torch
 import torch.optim._functional as F
 from torch import Tensor
+from torch.distributed.optim._deprecation_warning import (
+    _scripted_functional_optimizer_deprecation_warning,
+)
 
 
 __all__: List[str] = []
@@ -33,6 +36,7 @@ def __init__(
         fused: bool = False,
         _allow_empty_param_list: bool = False,
     ):
+        _scripted_functional_optimizer_deprecation_warning(stacklevel=2)
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= eps:
4 changes: 4 additions & 0 deletions torch/distributed/optim/functional_rmsprop.py
@@ -4,6 +4,9 @@
 import torch
 import torch.optim._functional as F
 from torch import Tensor
+from torch.distributed.optim._deprecation_warning import (
+    _scripted_functional_optimizer_deprecation_warning,
+)
 
 
 __all__: List[str] = []
@@ -33,6 +36,7 @@ def __init__(
         maximize: bool = False,
         _allow_empty_param_list: bool = False,
     ):
+        _scripted_functional_optimizer_deprecation_warning(stacklevel=2)
         self.defaults = {
             "lr": lr,
             "alpha": alpha,
4 changes: 4 additions & 0 deletions torch/distributed/optim/functional_rprop.py
@@ -4,6 +4,9 @@
 import torch
 import torch.optim._functional as F
 from torch import Tensor
+from torch.distributed.optim._deprecation_warning import (
+    _scripted_functional_optimizer_deprecation_warning,
+)
 
 
 __all__: List[str] = []
@@ -30,6 +33,7 @@ def __init__(
         maximize: bool = False,
         _allow_empty_param_list: bool = False,
     ):
+        _scripted_functional_optimizer_deprecation_warning(stacklevel=2)
         self.defaults = {
             "lr": lr,
         }
4 changes: 4 additions & 0 deletions torch/distributed/optim/functional_sgd.py
@@ -4,6 +4,9 @@
 import torch
 import torch.optim._functional as F
 from torch import Tensor
+from torch.distributed.optim._deprecation_warning import (
+    _scripted_functional_optimizer_deprecation_warning,
+)
 
 
 __all__: List[str] = []
@@ -33,6 +36,7 @@ def __init__(
         fused: bool = False,
         _allow_empty_param_list: bool = False,
     ):
+        _scripted_functional_optimizer_deprecation_warning(stacklevel=2)
         self.defaults = {
             "lr": lr,
             "momentum": momentum,
