Skip to content

Commit

Permalink
Fixed Weight Decay Regularization in Adam
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelauli authored and myleott committed Jan 22, 2018
1 parent 66d9fcf commit ee36a6f
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 2 deletions.
5 changes: 3 additions & 2 deletions fairseq/multiprocessing_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

from fairseq import nccl, utils
from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
from fairseq.nag import NAG
from fairseq.optim.nag import NAG
from fairseq.optim.adam import Adam


class MultiprocessingTrainer(MultiprocessingEventLoop):
Expand Down Expand Up @@ -95,7 +96,7 @@ def _build_optimizer(self):
'betas': eval(self.args.adam_betas),
'weight_decay': self.args.weight_decay,
}
return torch.optim.Adam(self.model.parameters(), **self._override_optim_state)
return Adam(self.model.parameters(), **self._override_optim_state)
elif self.args.optimizer == 'nag':
self._override_optim_state = {
'lr': self.args.lr[0],
Expand Down
103 changes: 103 additions & 0 deletions fairseq/optim/adam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import math
import torch
from torch.optim.optimizer import Optimizer


class Adam(Optimizer):
"""Implements Adam algorithm.
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""

def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, amsgrad=False):
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad)
super(Adam, self).__init__(params, defaults)

def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
amsgrad = group['amsgrad']

state = self.state[p]

# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p.data)

exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
if amsgrad:
max_exp_avg_sq = state['max_exp_avg_sq']
beta1, beta2 = group['betas']

state['step'] += 1

# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = max_exp_avg_sq.sqrt().add_(group['eps'])
else:
denom = exp_avg_sq.sqrt().add_(group['eps'])

bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

if group['weight_decay'] != 0:
p.data.add_(-group['weight_decay'], p.data)

p.data.addcdiv_(-step_size, exp_avg, denom)

return loss
File renamed without changes.

0 comments on commit ee36a6f

Please sign in to comment.