updated apex.contrib.optimizers.FP16_Optimizer and FusedSGD (NVIDIA#657)
kexinyu authored and mcarilli committed Dec 18, 2019
1 parent 4ad9b3b commit c19ee27
Showing 2 changed files with 255 additions and 78 deletions.
122 changes: 44 additions & 78 deletions apex/contrib/optimizers/fp16_optimizer.py
@@ -1,26 +1,20 @@
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from apex.multi_tensor_apply import multi_tensor_applier

class FP16_Optimizer(object):
"""
:class:`FP16_Optimizer` is a cutdown version of apex.fp16_utils.FP16_Optimizer.
Designed only to wrap apex.optimizers.FusedAdam.
Designed only to wrap apex.contrib.optimizers.FusedAdam, FusedSGD.
Refer to apex.fp16_utils documents for more information.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = apex.optimizers.FusedAdam(model.parameters())
# Name the FP16_Optimizer instance to replace the existing optimizer
# (recommended but not required):
optimizer = apex.contrib.optimizers.FusedSGD(model.parameters())
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
# loss.backward() becomes:
optimizer.backward(loss)
...
Example with dynamic loss scaling::
...
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
# optional arg to control dynamic loss scaling behavior
@@ -35,41 +29,36 @@ def __init__(self,
dynamic_loss_args=None,
verbose=True):

print("\nfp16_optimizer is designed to only work with apex.optimizers, and will be removed in future")
print("\nThis fp16_optimizer is designed to only work with apex.contrib.optimizers.*")
print("To update, use updated optimizers with AMP.")
# The fused optimizer does all the work. We need this layer for two reasons:
# 1. maintain the same user API as apex.fp16_utils
# 2. keep common logic here in case we need to add a new fused optimizer later

# differences from apex.fp16_utils:
# - assume all model params in fp16
# - assume all params require grad
# - flattened by groups, not keeping state. TODO: remove state explicitly?
# - master grad and unflat master weight never exist. TODO: a way to save out unflat master?
if not torch.cuda.is_available():
raise SystemError("Cannot use fp16 without CUDA.")
self.optimizer = init_optimizer

# param flattened by groups
self.fp16_groups = []
self.fp16_groups_flat = []
self.fp32_groups_flat = []

# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
# push this group to the list before modifying it
self.fp16_groups.append(param_group['params'])
# init fp16 weight buffer, flattened
self.fp16_groups_flat.append(_flatten_dense_tensors([p.clone().detach() for p in self.fp16_groups[i]]))
# set model fp16 weight to slices of flattened buffer
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
for p,q in zip(self.fp16_groups[i], updated_params):
p.data = q.data
# init master weight, flattened
self.fp32_groups_flat.append(self.fp16_groups_flat[i].clone().float().detach())
# modify optimizer to have flat master weight
self.fp32_groups_flat[i].requires_grad = True # keep this in case internal optimizer uses it
param_group['params'] = [self.fp32_groups_flat[i]]
self.fp16_groups = [] # model params
self.fp32_groups = [] # master weights

# iterate over param_groups
for param_group in self.optimizer.param_groups:
fp16_group = []
fp32_group = []
for p in param_group['params']:
fp16_group.append(p)
fp32_group.append(p.clone().float().detach())
self.fp16_groups.append(fp16_group)
self.fp32_groups.append(fp32_group)
param_group['params'] = fp32_group

if multi_tensor_applier.available:
import amp_C
self.overflow_buf = torch.cuda.IntTensor([0])
self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm
else:
raise RuntimeError('FP16_Optimizer requires cuda extensions')

# we may have a way of fusing dynamic scale. Not supported for now
if dynamic_loss_scale:
@@ -102,70 +91,47 @@ def zero_grad(self, set_grads_to_None=True):
p.grad.detach_()
p.grad.zero_()

def _compute_grad_norm(self, fp16_grads_flat, norm_type=2):
"""
Compute the fp16 grad norm for later clipping (fused with the update).
Internally accumulated in fp32.
A NaN check is also fused in. Possibly other reductions are needed for the grad.
Args:
fp16_grads_flat (tensor): fp16 grad flattened
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the current fp16 gradients (viewed as a single vector).
Returns -1 if the most recently computed fp16 gradients overflowed
"""
# TODO: not the most efficient; involves a copy to CPU and a sync
# only the 2-norm is supported for now
# for torch versions <= 1.0.1, torch.norm with dtype will fail, so fall back to a cast
try:
norm = float(torch.norm(fp16_grads_flat, 2.0, dtype=torch.float32))
except TypeError as err:
norm = float(torch.norm(fp16_grads_flat.float(), 2.0))
if norm == float('inf') or norm == -float('inf') or norm != norm:
return -1
else:
return norm
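Aside (not part of this diff): the removed helper's overflow convention is simple enough to state on its own. A minimal sketch, with an illustrative name, of the -1-on-overflow rule it implemented is below; in the new step() the same signal comes instead from the overflow buffer written by multi_tensor_l2norm.

import math

def norm_or_overflow(norm):
    # Mirror of the removed helper's convention: return -1 when the fp16
    # gradient norm comes back as inf or NaN (i.e. the gradients overflowed).
    return -1.0 if (math.isinf(norm) or math.isnan(norm)) else norm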

def step(self, closure=None):
"""
Closures are not supported.
"""
# First compute the norm for all groups so we know if there is an overflow
grads_groups_flat = []
fp16_grads = []
norm_groups = []
skip = False
for i, group in enumerate(self.fp16_groups):
grads_groups_flat.append(_flatten_dense_tensors([p.grad for p in group]))
norm_groups.append(self._compute_grad_norm(grads_groups_flat[i]))
if norm_groups[i] == -1: #TODO: early break
skip = True

for group in self.fp16_groups:
fp16_grad = []
for i, p in enumerate(group):
fp16_grad.append(p.grad)
fp16_grads.append(fp16_grad)

# NaN check
self.overflow_buf.zero_()
for fp16_grad in fp16_grads:
if len(fp16_grad) > 0:
norm, norm_per_tensor = multi_tensor_applier(self.multi_tensor_l2norm,
self.overflow_buf,
[fp16_grad], True)
norm_groups.append(norm)
if self.overflow_buf.item() != 0:
skip = True

if skip:
self._update_scale(skip)
return

# norm is in fact norm*cur_scale
self.optimizer.step(grads=[[g] for g in grads_groups_flat],
output_params=[[p] for p in self.fp16_groups_flat],
self.optimizer.step(grads=fp16_grads,
output_params=self.fp16_groups,
scale=self.cur_scale,
grad_norms=norm_groups)

# TODO: we probably don't need this? just to be safe
for i in range(len(norm_groups)):
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
for p,q in zip(self.fp16_groups[i], updated_params):
p.data = q.data

self._update_scale(False)
return
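_update_scale, called above, is defined outside the hunks shown in this diff. A minimal standalone sketch of the dynamic loss-scaling policy such wrappers typically implement (assumed behavior, not taken from this commit; scale_window, min_scale and good_steps are illustrative names):

def update_loss_scale(cur_scale, overflow, good_steps,
                      scale_window=2000, min_scale=1.0):
    # Assumed policy: back off after an overflow, grow again after a
    # window of consecutive clean steps.
    if overflow:
        return max(cur_scale / 2.0, min_scale), 0
    good_steps += 1
    if good_steps % scale_window == 0:
        cur_scale *= 2.0
    return cur_scale, good_steps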

def backward(self, loss):
"""
:attr:`backward` performs the following steps:
1. fp32_loss = loss.float()
2. scaled_loss = fp32_loss*loss_scale
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
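The body of backward is folded out of the diff; a minimal sketch matching the three steps listed in its docstring (illustrative only, using the same cur_scale value that step passes to the inner optimizer):

def scaled_backward(loss, loss_scale):
    # 1. cast the loss to fp32, 2. multiply by the loss scale, 3. backprop so
    # the model's fp16 leaves accumulate scaled gradients in their .grad attributes.
    scaled_loss = loss.float() * loss_scale
    scaled_loss.backward()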
211 changes: 211 additions & 0 deletions apex/contrib/optimizers/fused_sgd.py
@@ -0,0 +1,211 @@
import types
import torch
from torch.optim.optimizer import Optimizer, required

from apex.multi_tensor_apply import multi_tensor_applier

class FusedSGD(Optimizer):
r"""Implements stochastic gradient descent (optionally with momentum).
This version of fused SGD implements 2 fusions.
* Fusion of the SGD update's elementwise operations
* A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
:class:`apex.contrib.optimizers.FusedSGD` should be used without AMP.
:class:`apex.contrib.optimizers.FusedSGD` only works in the case where all parameters require grad.
Nesterov momentum is based on the formula from
`On the importance of initialization and momentum in deep learning`__.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float): learning rate
momentum (float, optional): momentum factor (default: 0)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
dampening (float, optional): dampening for momentum (default: 0)
nesterov (bool, optional): enables Nesterov momentum (default: False)
Example:
model = ...
model.half()
optimizer = apex.contrib.optimizers.FusedSGD(model.parameters())
# wrap with FP16_Optimizer
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
optimizer.zero_grad()
...
optimizer.backward(loss)
optimizer.step()
__ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
.. note::
The implementation of SGD with Momentum/Nesterov subtly differs from
Sutskever et al. and implementations in some other frameworks.
Considering the specific case of Momentum, the update can be written as
.. math::
v = \rho * v + g \\
p = p - lr * v
where p, g, v and :math:`\rho` denote the parameters, gradient,
velocity, and momentum respectively.
This is in contrast to Sutskever et al. and
other frameworks which employ an update of the form
.. math::
v = \rho * v + lr * g \\
p = p - v
The Nesterov version is analogously modified.
"""

def __init__(self, params, lr=required, momentum=0, dampening=0,
weight_decay=0, nesterov=False,
wd_after_momentum=False,
materialize_master_grads=True):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
weight_decay=weight_decay, nesterov=nesterov)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super(FusedSGD, self).__init__(params, defaults)

self.wd_after_momentum = wd_after_momentum

if multi_tensor_applier.available:
import amp_C
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self.multi_tensor_sgd = amp_C.multi_tensor_sgd
else:
raise RuntimeError('apex.contrib.optimizers.FusedSGD requires cuda extensions')

def __setstate__(self, state):
super(FusedSGD, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('nesterov', False)

def get_momentums(self, params):
momentums = []
first_run = True
for p in params:
param_state = self.state[p]
# torch.optim.SGD initializes momentum in the main loop; we have
# to do it here, and track whether or not we've done so, so that
# momentum application can be skipped in the main kernel.
if 'momentum_buffer' not in param_state:
first_run = True
buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
momentums.append(buf)
else:
first_run = False
momentums.append(param_state['momentum_buffer'])
return momentums, first_run

def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
grads (list of tensors, optional): weight gradients to use for the
optimizer update. If gradients have type torch.half, parameters
are expected to be in type torch.float. (default: None)
output_params (list of tensors, optional): A reduced precision copy
of the updated weights written out in addition to the regular
updated weights. Must be of the same type as the gradients. (default: None)
scale (float, optional): factor to divide gradient tensor values
by before applying to weights. (default: 1)
"""
if hasattr(self, "_amp_stash"):
raise RuntimeError('apex.contrib.optimizers.FusedSGD should not be used with AMP.')

loss = None
if closure is not None:
loss = closure()

if grads is None:
raise RuntimeError('apex.contrib.optimizers.FusedSGD must be wrapped \
with apex.contrib.optimizers.FP16_Optimizer \
which provides grads.')
# backward compatibility
# assuming a flat list/generator of parameters means a single group
elif isinstance(grads, types.GeneratorType):
grads_group = [grads]
elif type(grads[0])!=list:
grads_group = [grads]
else:
grads_group = grads

if output_params is None:
raise RuntimeError('apex.contrib.optimizers.FusedSGD must be wrapped \
with apex.contrib.optimizers.FP16_Optimizer \
which provides output_params.')
elif isinstance(output_params, types.GeneratorType):
output_params_group = [output_params]
elif type(output_params[0])!=list:
output_params_group = [output_params]
else:
output_params_group = output_params
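The branches above just normalize grads and output_params into one list per parameter group. A standalone sketch of that backward-compatibility rule (illustrative helper name):

import types

def to_groups(x):
    # A generator or a flat list is treated as a single parameter group;
    # a list of lists is already grouped and used as-is.
    if isinstance(x, types.GeneratorType):
        return [x]
    if not isinstance(x[0], list):
        return [x]
    return x

FP16_Optimizer.step in the first file already passes list-of-lists (one inner list per group), so these branches mainly matter for older callers.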

for group, grads_this_group, output_params_this_group in zip(self.param_groups,
grads_group,
output_params_group):
if grads_this_group is None or output_params_this_group is None:
raise RuntimeError('apex.contrib.optimizers.FusedSGD only works \
when all parameters require grad.')

weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
nesterov = group['nesterov']
lr = group['lr']

first_runs = [True, True]

# output_params_this_group: original weights (either fp16 or fp32)
# group['params']: master weights (fp32)

# grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy
# fp32, fp32, fp32, No
fp32_grads = [g for (p, g) in zip(output_params_this_group, grads_this_group) if p.dtype == torch.float32]
fp32_params = [p2 for (p1, p2) in zip(output_params_this_group, group['params']) if p1.dtype == torch.float32]
fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
fp32_set = [fp32_grads, fp32_params, fp32_momentums]

# fp16, fp32, fp32, Yes
fp16_grads = [g for (p, g) in zip(output_params_this_group, grads_this_group) if p.dtype == torch.float16]
fp32_from_fp16_params = [p2 for (p1, p2) in zip(output_params_this_group, group['params']) if p1.dtype == torch.float16]
fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
fp16_params = [p1 for (p1, p2) in zip(output_params_this_group, group['params']) if p1.dtype == torch.float16]
fp16_set = [fp16_grads, fp32_from_fp16_params, fp32_from_fp16_momentums, fp16_params]

launch_sets = [fp16_set, fp32_set]

for launch_set, first_run in zip(launch_sets, first_runs):
assert len(launch_set[0]) == len(launch_set[1])
assert len(launch_set[0]) == len(launch_set[2])
if len(launch_set[0]) > 0:
multi_tensor_applier(
self.multi_tensor_sgd,
self._dummy_overflow_buf,
launch_set,
weight_decay,
momentum,
dampening,
lr,
nesterov,
first_run,
self.wd_after_momentum,
1.0/scale)

return loss
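Conceptually, for the fp16 launch set each fused call updates the fp32 master weight and its fp32 momentum buffer from the unscaled gradient, then mirrors the result back into the fp16 model weight, which is the requires_fp16_model_copy / output_params role noted in the comments above. An unfused per-parameter sketch of that effect (illustrative only; weight decay, dampening, and Nesterov omitted for brevity):

import torch

def fp16_set_update(fp16_p, fp32_p, fp16_grad, momentum_buf, lr, momentum, scale):
    # Unscale the fp16 gradient in fp32, apply the momentum update to the
    # fp32 master weight, then write the updated value back to the fp16 copy.
    grad = fp16_grad.float().mul_(1.0 / scale)
    momentum_buf.mul_(momentum).add_(grad)
    fp32_p.data.add_(momentum_buf, alpha=-lr)
    fp16_p.data.copy_(fp32_p.data)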
