[optim] use lerp whenever possible (pytorch#104796)
This is a better copy (with fixes) of pytorch#104781.

Test plan:
CI will pass once pytorch#104784 is landed

Pull Request resolved: pytorch#104796
Approved by: https://github.com/albanD
janeyx99 authored and pytorchmergebot committed Jul 8, 2023
1 parent 5da4745 commit fbe2a7e
Showing 7 changed files with 12 additions and 21 deletions.
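The core change is the same in every optimizer touched here: an in-place exponential moving average written as mul_ followed by add_ becomes a single lerp_, using the identity beta1 * x + (1 - beta1) * g == x + (1 - beta1) * (g - x). A minimal sketch of that equivalence (illustrative only, not part of the commit):

import torch

beta1 = 0.9
grad = torch.randn(1000)
exp_avg_old = torch.randn(1000)
exp_avg_new = exp_avg_old.clone()

# Before: two in-place ops per first-moment update.
exp_avg_old.mul_(beta1).add_(grad, alpha=1 - beta1)

# After: one in-place op; lerp_(end, weight) computes self + weight * (end - self).
exp_avg_new.lerp_(grad, 1 - beta1)

torch.testing.assert_close(exp_avg_old, exp_avg_new)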
3 changes: 0 additions & 3 deletions test/profiler/test_memory_profiler.py
@@ -1519,7 +1519,6 @@ def id_for_testing(key):
 create OPTIMIZER_STATE 22(v0) 1024 kB
 create OPTIMIZER_STATE 23(v0) 1024 kB
 increment_version OPTIMIZER_STATE 18(v0) 128 kB
-increment_version OPTIMIZER_STATE 18(v1) 128 kB
 increment_version OPTIMIZER_STATE 19(v0) 128 kB
 increment_version OPTIMIZER_STATE 19(v1) 128 kB
 create ??? 24(v0) 128 kB
@@ -1528,7 +1527,6 @@ def id_for_testing(key):
 increment_version ??? 25(v0) 128 kB
 increment_version PARAMETER 0(v0) 128 kB
 increment_version OPTIMIZER_STATE 20(v0) 2 kB
-increment_version OPTIMIZER_STATE 20(v1) 2 kB
 increment_version OPTIMIZER_STATE 21(v0) 2 kB
 increment_version OPTIMIZER_STATE 21(v1) 2 kB
 create ??? 26(v0) 2 kB
@@ -1538,7 +1536,6 @@ def id_for_testing(key):
 destroy ??? 25(v1) 128 kB
 increment_version PARAMETER 1(v0) 2 kB
 increment_version OPTIMIZER_STATE 22(v0) 1024 kB
-increment_version OPTIMIZER_STATE 22(v1) 1024 kB
 increment_version OPTIMIZER_STATE 23(v0) 1024 kB
 increment_version OPTIMIZER_STATE 23(v1) 1024 kB
 create ??? 28(v0) 1024 kB
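The deleted lines in the expected trace above fall out of the rewrite: mul_ followed by add_ bumps a buffer's version counter twice per step, while a single lerp_ bumps it once, so tensors 18, 20, and 22 (presumably the exp_avg buffers) now record only one increment_version event each, whereas 19, 21, and 23 keep both because exp_avg_sq still uses mul_ + addcmul_. A rough sketch of that effect, assuming the internal Tensor._version counter bumps once per in-place op as usual:

import torch

t, g = torch.zeros(4), torch.ones(4)

v = t._version
t.mul_(0.9).add_(g, alpha=0.1)   # two in-place ops
print(t._version - v)            # expected: 2

t2 = torch.zeros(4)
v = t2._version
t2.lerp_(g, 0.1)                 # one in-place op
print(t2._version - v)           # expected: 1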
5 changes: 2 additions & 3 deletions torch/optim/adam.py
@@ -358,7 +358,7 @@ def _single_tensor_adam(params: List[Tensor],
 param = torch.view_as_real(param)

 # Decay the first and second moment running average coefficient
-exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+exp_avg.lerp_(grad, 1 - beta1)
 exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)

 if capturable or differentiable:
@@ -467,8 +467,7 @@ def _multi_tensor_adam(params: List[Tensor],
 device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)

 # Decay the first and second moment running average coefficient
-torch._foreach_mul_(device_exp_avgs, beta1)
-torch._foreach_add_(device_exp_avgs, device_grads, alpha=1 - beta1)
+torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1)

 torch._foreach_mul_(device_exp_avg_sqs, beta2)
 torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, 1 - beta2)
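The multi-tensor (foreach) paths collapse the _foreach_mul_ / _foreach_add_ pair into a single _foreach_lerp_ call. A quick equivalence check, assuming a build where torch._foreach_lerp_ accepts a scalar weight as used in this diff:

import torch

beta1 = 0.9
grads = [torch.randn(8) for _ in range(3)]
avgs_old = [torch.randn(8) for _ in range(3)]
avgs_new = [a.clone() for a in avgs_old]

# Before: two foreach kernels per step.
torch._foreach_mul_(avgs_old, beta1)
torch._foreach_add_(avgs_old, grads, alpha=1 - beta1)

# After: a single foreach lerp.
torch._foreach_lerp_(avgs_new, grads, 1 - beta1)

for a, b in zip(avgs_old, avgs_new):
    torch.testing.assert_close(a, b)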
5 changes: 2 additions & 3 deletions torch/optim/adamax.py
@@ -266,7 +266,7 @@ def _single_tensor_adamax(
 exp_inf = torch.view_as_real(exp_inf)

 # Update biased first moment estimate.
-exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+exp_avg.lerp_(grad, 1 - beta1)
 # Update the exponentially weighted infinity norm.
 norm_buf = torch.cat(
 [exp_inf.mul_(beta2).unsqueeze(0), grad.abs().add_(eps).unsqueeze_(0)], 0
@@ -321,8 +321,7 @@ def _multi_tensor_adamax(
 grouped_grads = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay)

 # Update biased first moment estimate.
-torch._foreach_mul_(grouped_exp_avgs, beta1)
-torch._foreach_add_(grouped_exp_avgs, grouped_grads, alpha=1 - beta1)
+torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1)

 # Update the exponentially weighted infinity norm.
 torch._foreach_mul_(grouped_exp_infs, beta2)
5 changes: 2 additions & 3 deletions torch/optim/adamw.py
@@ -401,7 +401,7 @@ def _single_tensor_adamw(
 param.mul_(1 - lr * weight_decay)

 # Decay the first and second moment running average coefficient
-exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+exp_avg.lerp_(grad, 1 - beta1)
 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

 if capturable or differentiable:
@@ -517,8 +517,7 @@ def _multi_tensor_adamw(
 torch._foreach_mul_(device_params, 1 - lr * weight_decay)

 # Decay the first and second moment running average coefficient
-torch._foreach_mul_(device_exp_avgs, beta1)
-torch._foreach_add_(device_exp_avgs, device_grads, alpha=1 - beta1)
+torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1)

 torch._foreach_mul_(device_exp_avg_sqs, beta2)
 torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, 1 - beta2)
5 changes: 2 additions & 3 deletions torch/optim/nadam.py
@@ -249,7 +249,7 @@ def _single_tensor_nadam(params: List[Tensor],
 mu_product *= mu

 # decay the first and second moment running average coefficient
-exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+exp_avg.lerp_(grad, 1 - beta1)
 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
 denom = exp_avg_sq.div(bias_correction2).sqrt()

@@ -309,8 +309,7 @@ def _multi_tensor_nadam(params: List[Tensor],
 grouped_grads = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay)

 # Decay the first and second moment running average coefficient
-torch._foreach_mul_(grouped_exp_avgs, beta1)
-torch._foreach_add_(grouped_exp_avgs, grouped_grads, alpha=1 - beta1)
+torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1)

 torch._foreach_mul_(grouped_exp_avg_sqs, beta2)
 torch._foreach_addcmul_(grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2)
5 changes: 2 additions & 3 deletions torch/optim/radam.py
@@ -264,7 +264,7 @@ def _single_tensor_radam(
 grad = grad.add(param, alpha=weight_decay)

 # Decay the first and second moment running average coefficient
-exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+exp_avg.lerp_(grad, 1 - beta1)
 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

 # correcting bias for the first moving moment
@@ -337,8 +337,7 @@ def _multi_tensor_radam(
 grouped_grads = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay)

 # Decay the first and second moment running average coefficient
-torch._foreach_mul_(grouped_exp_avgs, beta1)
-torch._foreach_add_(grouped_exp_avgs, grouped_grads, alpha=1 - beta1)
+torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1)

 torch._foreach_mul_(grouped_exp_avg_sqs, beta2)
 torch._foreach_addcmul_(grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2)
5 changes: 2 additions & 3 deletions torch/optim/rmsprop.py
@@ -283,7 +283,7 @@ def _single_tensor_rmsprop(
 grad_avg = grad_avgs[i]
 if is_complex_param:
 grad_avg = torch.view_as_real(grad_avg)
-grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
+grad_avg.lerp_(grad, 1 - alpha)
 avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_()
 else:
 avg = square_avg.sqrt()
@@ -348,8 +348,7 @@ def _view_complex_as_real(tensor_list):

 if centered:
 grouped_grad_avgs = _view_complex_as_real(grouped_grad_avgs)
-torch._foreach_mul_(grouped_grad_avgs, alpha)
-torch._foreach_add_(grouped_grad_avgs, grouped_grads, alpha=1 - alpha)
+torch._foreach_lerp_(grouped_grad_avgs, grouped_grads, 1 - alpha)
 avg = torch._foreach_addcmul(grouped_square_avgs, grouped_grad_avgs, grouped_grad_avgs, value=-1)
 torch._foreach_sqrt_(avg)
 torch._foreach_add_(avg, eps)
