Skip to content

Commit

Permalink
add learning rate schedulers (pytorch#1370)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiaming-Liu authored and soumith committed May 25, 2017
1 parent 0409b42 commit 630af4d
Show file tree
Hide file tree
Showing 3 changed files with 478 additions and 1 deletion.
18 changes: 18 additions & 0 deletions docs/source/optim.rst
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,21 @@ Algorithms
:members:
.. autoclass:: SGD
:members:

How to adjust Learning Rate
---------------------------

:mod:`torch.optim.lr_scheduler` provides several methods to adjust the learning
rate based on the number of epoches. :class:`torch.optim.lr_scheduler.ReduceLROnPlateau`
allows dynamic learning rate reducing based on some validation measurements.

.. autoclass:: torch.optim.lr_scheduler.LambdaLR
:members:
.. autoclass:: torch.optim.lr_scheduler.StepLR
:members:
.. autoclass:: torch.optim.lr_scheduler.MultiStepLR
:members:
.. autoclass:: torch.optim.lr_scheduler.ExponentialLR
:members:
.. autoclass:: torch.optim.lr_scheduler.ReduceLROnPlateau
:members:
156 changes: 155 additions & 1 deletion test/test_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import torch
import torch.optim as optim
import torch.legacy.optim as old_optim
import torch.nn.functional as F
from torch.optim import SGD
from torch.autograd import Variable
from torch import sparse

from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, ExponentialLR, ReduceLROnPlateau
from common import TestCase, run_tests


Expand Down Expand Up @@ -392,5 +394,157 @@ def test_invalid_param_type(self):
optim.SGD(Variable(torch.randn(5, 5)), lr=3)


class SchedulerTestNet(torch.nn.Module):
def __init__(self):
super(SchedulerTestNet, self).__init__()
self.conv1 = torch.nn.Conv2d(1, 1, 1)
self.conv2 = torch.nn.Conv2d(1, 1, 1)

def forward(self, x):
return self.conv2(F.relu(self.conv1(x)))


class TestLRScheduler(TestCase):
def setUp(self):
self.net = SchedulerTestNet()
self.opt = SGD(
[{'params': self.net.conv1.parameters()}, {'params': self.net.conv2.parameters(), 'lr': 0.5}],
lr=0.05)

def test_step_lr(self):
# lr = 0.05 if epoch < 3
# lr = 0.005 if 30 <= epoch < 6
# lr = 0.0005 if epoch >= 9
single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3
targets = [single_targets, list(map(lambda x: x * 10, single_targets))]
scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
epochs = 10
self._test(scheduler, targets, epochs)

def test_multi_step_lr(self):
# lr = 0.05 if epoch < 2
# lr = 0.005 if 2 <= epoch < 5
# lr = 0.0005 if epoch < 9
# lr = 0.00005 if epoch >= 9
single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3
targets = [single_targets, list(map(lambda x: x * 10, single_targets))]
scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
epochs = 10
self._test(scheduler, targets, epochs)

def test_exp_lr(self):
single_targets = [0.05 * (0.9 ** x) for x in range(10)]
targets = [single_targets, list(map(lambda x: x * 10, single_targets))]
scheduler = ExponentialLR(self.opt, gamma=0.9)
epochs = 10
self._test(scheduler, targets, epochs)

def test_reduce_lr_on_plateau1(self):
for param_group in self.opt.param_groups:
param_group['lr'] = 0.5
targets = [[0.5] * 20]
metrics = [10 - i * 0.0167 for i in range(20)]
scheduler = ReduceLROnPlateau(self.opt, threshold_mode='abs', mode='min',
threshold=0.01, patience=5, cooldown=5)
epochs = 10
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

def test_reduce_lr_on_plateau2(self):
for param_group in self.opt.param_groups:
param_group['lr'] = 0.5
targets = [[0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2]
metrics = [10 - i * 0.0165 for i in range(22)]
scheduler = ReduceLROnPlateau(self.opt, patience=5, cooldown=0, threshold_mode='abs',
mode='min', threshold=0.1)
epochs = 22
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

def test_reduce_lr_on_plateau3(self):
for param_group in self.opt.param_groups:
param_group['lr'] = 0.5
targets = [[0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4]
metrics = [-0.8] * 2 + [-0.234] * 20
scheduler = ReduceLROnPlateau(self.opt, mode='max', patience=5, cooldown=5,
threshold_mode='abs')
epochs = 22
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

def test_reduce_lr_on_plateau4(self):
for param_group in self.opt.param_groups:
param_group['lr'] = 0.5
targets = [[0.5] * 20]
metrics = [1.5 * (1.025 ** i) for i in range(20)] # 1.025 > 1.1**0.25
scheduler = ReduceLROnPlateau(self.opt, mode='max', patience=3,
threshold_mode='rel', threshold=0.1)
epochs = 20
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

def test_reduce_lr_on_plateau5(self):
for param_group in self.opt.param_groups:
param_group['lr'] = 0.5
targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4]
metrics = [1.5 * (1.005 ** i) for i in range(20)]
scheduler = ReduceLROnPlateau(self.opt, mode='max', threshold_mode='rel',
threshold=0.1, patience=5, cooldown=5)
epochs = 20
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

def test_reduce_lr_on_plateau6(self):
for param_group in self.opt.param_groups:
param_group['lr'] = 0.5
targets = [[0.5] * 20]
metrics = [1.5 * (0.85 ** i) for i in range(20)]
scheduler = ReduceLROnPlateau(self.opt, mode='min', threshold_mode='rel',
threshold=0.1)
epochs = 20
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

def test_reduce_lr_on_plateau7(self):
for param_group in self.opt.param_groups:
param_group['lr'] = 0.5
targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4]
metrics = [1] * 7 + [0.6] + [0.5] * 12
scheduler = ReduceLROnPlateau(self.opt, mode='min', threshold_mode='rel',
threshold=0.1, patience=5, cooldown=5)
epochs = 20
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

def test_reduce_lr_on_plateau8(self):
for param_group in self.opt.param_groups:
param_group['lr'] = 0.5
targets = [[0.5] * 6 + [0.4] * 14, [0.5] * 6 + [0.3] * 14]
metrics = [1.5 * (1.005 ** i) for i in range(20)]
scheduler = ReduceLROnPlateau(self.opt, mode='max', threshold_mode='rel', min_lr=[0.4, 0.3],
threshold=0.1, patience=5, cooldown=5)
epochs = 20
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

def test_lambda_lr(self):
self.opt.param_groups[0]['lr'] = 0.05
self.opt.param_groups[1]['lr'] = 0.4
targets = [[0.05 * (0.9 ** x) for x in range(10)], [0.4 * (0.8 ** x) for x in range(10)]]
scheduler = LambdaLR(self.opt,
lr_lambda=[lambda x1: 0.9 ** x1, lambda x2: 0.8 ** x2])
epochs = 10
self._test(scheduler, targets, epochs)

def _test(self, scheduler, targets, epochs=10):
for epoch in range(epochs):
scheduler.step(epoch)
for param_group, target in zip(self.opt.param_groups, targets):
self.assertAlmostEqual(target[epoch], param_group['lr'],
msg='LR is wrong in epoch {}: expected {}, got {}'.format(
epoch, target[epoch], param_group['lr']), delta=1e-5)

def _test_reduce_lr_on_plateau(self, scheduler, targets, metrics, epochs=10, verbose=False):
for epoch in range(epochs):
scheduler.step(metrics[epoch])
if verbose:
print('epoch{}:\tlr={}'.format(epoch, self.opt.param_groups[0]['lr']))
for param_group, target in zip(self.opt.param_groups, targets):
self.assertAlmostEqual(target[epoch], param_group['lr'],
msg='LR is wrong in epoch {}: expected {}, got {}'.format(
epoch, target[epoch], param_group['lr']), delta=1e-5)

if __name__ == '__main__':
run_tests()
Loading

0 comments on commit 630af4d

Please sign in to comment.