forked from NVIDIA/apex
Add fused mixed precision lamb optimizer. (NVIDIA#1237)
* Add fused mixed precision lamb optimizer.
* Fix device usage in constructor.
* Fix sending param_group tensor state to device.
* Remove unneeded device set.
1 parent aa756ce · commit 3c8f516
Showing 7 changed files with 1,072 additions and 6 deletions.
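For orientation, here is a minimal, hypothetical construction sketch (not part of the commit). The constructor arguments and the reduced_precision_dtype behavior come from the new optimizer file shown below; the apex.optimizers import path and the toy model are assumptions.

import torch
from apex.optimizers import FusedMixedPrecisionLamb  # assumed export path

# Keep the model parameters in reduced precision; the optimizer maintains
# fp32 master copies and fp32 momentum/velocity state internally.
model = torch.nn.Linear(1024, 1024).cuda().half()
optimizer = FusedMixedPrecisionLamb(
    model.parameters(),
    lr=1e-3,
    weight_decay=0.01,
    max_grad_norm=1.0,
    reduced_precision_dtype=torch.float16,
)

# 'lr' and 'step' live on the GPU as tensors so the step can run without host syncs.
print(optimizer.param_groups[0]['lr'], optimizer.param_groups[0]['step'])

The main new file added by the commit follows.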
@@ -0,0 +1,256 @@
import torch
from copy import deepcopy
from itertools import chain
from collections import defaultdict, abc as container_abcs

from apex.multi_tensor_apply import multi_tensor_applier


class FusedMixedPrecisionLamb(torch.optim.Optimizer):

    def __init__(self, params, lr=1e-3, step=0, bias_correction=True,
                 betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
                 amsgrad=False, adam_w_mode=True,
                 grad_averaging=True, max_grad_norm=1.0, use_nvlamb=False,
                 reduced_precision_dtype=None):
        if amsgrad:
            raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')

        # The learning rate (lr) and optimizer step (step) should be located on device
        # in order to facilitate device-sync-free execution
        defaults = dict(lr=torch.tensor(lr, dtype=torch.float32),
                        step=torch.tensor([step], dtype=torch.int),
                        bias_correction=bias_correction,
                        betas=betas, eps=eps, weight_decay=weight_decay,
                        grad_averaging=grad_averaging,
                        max_grad_norm=max_grad_norm)
        tensor_state = ['lr', 'step']
        super(FusedMixedPrecisionLamb, self).__init__(params, defaults)

        device = self.param_groups[0]['params'][0].device

        for idx, group in enumerate(self.param_groups):
            for item in tensor_state:
                self.param_groups[idx][item] = group[item].to(device=device)

        if multi_tensor_applier.available:
            import amp_C
            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm_mp
            # Skip buffer
            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=device)
            self.multi_tensor_lamb = amp_C.multi_tensor_lamb_mp
        else:
            raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')

        # Mixed Precision support
        self.reduced_precision_dtype = reduced_precision_dtype
        self.param_groups_full_precision = []

        self._step_supports_amp_scaling = True
        self.adam_w_mode = 1 if adam_w_mode else 0
        self.use_nvlamb = use_nvlamb

    # This method is overridden from the parent class because there is not a way to override
    # the nested function cast() that copies a saved piece of state to the device without
    # redundantly doing the copy.
    def load_state_dict(self, state_dict):
        r"""Loads the optimizer state.

        Args:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)
        # Validate the state_dict
        groups = self.param_groups
        saved_groups = state_dict['param_groups']

        if len(groups) != len(saved_groups):
            raise ValueError("loaded state dict has a different number of "
                             "parameter groups")
        param_lens = (len(g['params']) for g in groups)
        saved_lens = (len(g['params']) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError("loaded state dict contains a parameter group "
                             "that doesn't match the size of optimizer's group")

        # Update the state
        id_map = {old_id: p for old_id, p in
                  zip(chain.from_iterable((g['params'] for g in saved_groups)),
                      chain.from_iterable((g['params'] for g in groups)))}

        def cast(param, value):
            r"""Make a deep copy of value, casting all tensors to device of param."""
            if isinstance(value, torch.Tensor):
                # The original version cast the saved value to the params' dtype.
                # This doesn't work for mixed precision Lamb, where the momentum and
                # velocity are expected to be in full precision while the params are
                # in reduced precision.
                value = value.to(value.device)
                return value
            elif isinstance(value, dict):
                return {k: cast(param, v) for k, v in value.items()}
            elif isinstance(value, container_abcs.Iterable):
                return type(value)(cast(param, v) for v in value)
            else:
                return value

        # Copy state assigned to params (and cast tensors to appropriate types).
        # State that is not assigned to params is copied as is (needed for
        # backward compatibility).
        state = defaultdict(dict)
        for k, v in state_dict['state'].items():
            if k in id_map:
                param = id_map[k]
                state[param] = cast(param, v)
            else:
                state[k] = v

        # Update parameter groups, setting their 'params' value
        def update_group(group, new_group):
            new_group['params'] = group['params']
            return new_group
        param_groups = [
            update_group(g, ng) for g, ng in zip(groups, saved_groups)]
        self.__setstate__({'state': state, 'param_groups': param_groups})

    def _setup_full_precision_params(self):
        for i, pg in enumerate(self.param_groups):
            param_list = pg['params']
            self.param_groups_full_precision.append({
                'params': [
                    p.clone().detach().to(dtype=torch.float32)
                    if (self.reduced_precision_dtype is not None) and (p.dtype == self.reduced_precision_dtype)
                    else None
                    for p in param_list
                ],
            })

    # add_param_group() is overridden because default items can be tensors. The
    # parent version does not clone the default item, so two param groups could
    # accidentally point to the same default tensor even though their values are
    # allowed to differ across groups.
    def add_param_group(self, param_group):
        super().add_param_group(param_group)
        for name, default in self.defaults.items():
            if isinstance(default, torch.Tensor):
                self.param_groups[len(self.param_groups) - 1][name] = default.clone()

    @torch.no_grad()
    def step(self, closure=None, grad_scaler=None):
        loss = None
        if closure is not None:
            loss = closure()

        # The full precision params are set up in the first step of the optimizer
        # instead of in the constructor because the full precision params would get
        # out of sync with the model params if DDP syncs the model params across
        # devices after the optimizer is constructed.
        if len(self.param_groups_full_precision) == 0:
            self._setup_full_precision_params()

        # create separate grad lists for params
        grad_list = []
        for gid, group in enumerate(self.param_groups):
            for pid, p in enumerate(group['params']):
                assert group['params'][0].dtype == p.dtype, \
                    "Error: Parameters are not of the identical type: {} != {}".format(
                        group['params'][0].dtype, p.dtype)
                if p.grad is None:
                    continue
                grad_list.append(p.grad)

        # Overflow check of gradients
        device = self.param_groups[0]["params"][0].device
        found_inf = (
            grad_scaler._check_inf_per_device(self)[device]
            if grad_scaler is not None else torch.zeros((1,), device=device)
        )
        self._dummy_overflow_buf.copy_(found_inf)

        # Get unscale scale factor
        scale, inv_scale = None, None
        if grad_scaler:
            scale = grad_scaler._get_scale_async()
            inv_scale = scale.double().reciprocal().float()
        else:
            scale = torch.ones((1,), device=device)
            inv_scale = torch.ones((1,), device=device)

        # grad_norm is of scaled gradients.
        # So, multiply `max_grad_norm` by scale.
        max_grad_norm = self.defaults['max_grad_norm'] * scale
        grad_norm = multi_tensor_applier(
            self.multi_tensor_l2norm,
            self._dummy_overflow_buf,
            [grad_list],
            False,
        )[0]

        # Run LAMB optimization math
        for gid, (group, group_full) in enumerate(zip(self.param_groups, self.param_groups_full_precision)):
            bias_correction = 1 if group['bias_correction'] else 0
            beta1, beta2 = group['betas']
            grad_averaging = 1 if group['grad_averaging'] else 0

            # Assume the same step across the group for now to simplify things.
            # A per-parameter step could easily be supported by making it a tensor,
            # or by passing a list into the kernel.
            group['step'] += (self._dummy_overflow_buf != 1).to(torch.int)

            state_lists = [ [],  # (0) grads
                            [],  # (1) params
                            [],  # (2) momentum state
                            [],  # (3) velocity state
                          ]
            if self.reduced_precision_dtype is not None:
                state_lists.append([])  # (4) params reduced_dtype

            for p, p_full in zip(group['params'], group_full['params']):
                if p.grad is None:
                    continue
                assert not p.grad.is_sparse

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    dtype = p.dtype
                    if self.reduced_precision_dtype is not None and p.dtype == self.reduced_precision_dtype:
                        dtype = torch.float32
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data, dtype=dtype)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=dtype)

                if self.reduced_precision_dtype is not None:
                    state_lists[0].append(p.grad.data)
                    state_lists[1].append(p_full.data)
                    state_lists[2].append(state['exp_avg'])
                    state_lists[3].append(state['exp_avg_sq'])
                    state_lists[4].append(p.data)
                else:
                    state_lists[0].append(p.grad.data)
                    state_lists[1].append(p.data)
                    state_lists[2].append(state['exp_avg'])
                    state_lists[3].append(state['exp_avg_sq'])

            multi_tensor_applier(
                self.multi_tensor_lamb,
                self._dummy_overflow_buf,
                state_lists,
                group['lr'],
                beta1,
                beta2,
                group['eps'],
                group['step'],
                bias_correction,
                group['weight_decay'],
                grad_averaging,
                self.adam_w_mode,
                grad_norm,
                max_grad_norm,
                self.use_nvlamb,
                found_inf,
                inv_scale)

        return loss
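To sketch how step() is meant to interact with the AMP machinery (continuing the hypothetical example above; not part of the commit): because the optimizer sets _step_supports_amp_scaling and step() accepts a grad_scaler argument, GradScaler.step() on PyTorch versions that forward a grad_scaler kwarg to such optimizers hands the scaler through, so unscaling and the overflow skip happen inside the fused kernel. The loop below is an untested illustration; num_steps and the synthetic data are placeholders.

scaler = torch.cuda.amp.GradScaler()

for _ in range(num_steps):
    optimizer.zero_grad(set_to_none=True)
    x = torch.randn(32, 1024, device='cuda', dtype=torch.float16)
    loss = model(x).float().pow(2).mean()  # toy loss
    scaler.scale(loss).backward()          # gradients stay scaled here
    scaler.step(optimizer)                 # forwards grad_scaler=scaler to step() above
    scaler.update()

Note that no separate scaler.unscale_() call is needed: as the code above shows, step() queries the scaler for found_inf and inv_scale, uses the overflow flag to gate the step-count increment, and passes both values into the fused kernel.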