Distributed LAMB fixes (NVIDIA#1007)
* add flag for DistributedAdam: step_support_amp_scaling

Co-authored-by: Kexin Yu <[email protected]>
3 people authored Dec 4, 2020
1 parent 3fe10b5 commit 8a80d47
Showing 3 changed files with 44 additions and 24 deletions.
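For context on the commit message's first bullet: PyTorch's torch.cuda.amp.GradScaler decides whether to run its own unscale/inf-check pass before optimizer.step() by looking for a boolean attribute on the optimizer (spelled _step_supports_amp_scaling in the PyTorch releases I am aware of; the bullet presumably refers to the same mechanism on apex's DistributedFusedAdam). A rough sketch of that hand-off, not the actual GradScaler source:

    # Sketch (assumption): how a grad scaler defers loss-scale handling to an
    # optimizer that advertises native support for scaled gradients.
    def scaler_step(scaler, optimizer, *args, **kwargs):
        if getattr(optimizer, "_step_supports_amp_scaling", False):
            # The optimizer unscales gradients and handles inf/nan itself
            # (for the fused distributed optimizers, inside the step kernel),
            # so the scaler skips its unscale_() pass.
            return optimizer.step(*args, **kwargs)
        scaler.unscale_(optimizer)
        # The real GradScaler also skips the step when an inf/nan was found; elided here.
        return optimizer.step(*args, **kwargs)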
3 changes: 1 addition & 2 deletions apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp
@@ -12,8 +12,7 @@ void multi_tensor_lamb_compute_update_term_cuda(
     at::Tensor per_tensor_epsilon,
     const int mode,
     at::Tensor per_tensor_decay,
-    const float global_grad_norm,
-    const float max_global_grad_norm);
+    const float grad_scale);
 
 void multi_tensor_lamb_update_weights_cuda(
     int chunk_size,
29 changes: 12 additions & 17 deletions apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu
@@ -120,8 +120,7 @@ struct DistOptLAMBStage1Functor
     const MATH_T* per_tensor_epsilon,
     adamMode_t mode,
     const MATH_T* per_tensor_decay,
-    const MATH_T global_grad_norm,
-    const MATH_T max_global_grad_norm)
+    const float grad_scale)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
@@ -132,15 +131,13 @@
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];
 
-    MATH_T clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : (MATH_T) 1.0;
-
     MATH_T beta1 = per_tensor_beta1[tensor_num];
-    MATH_T beta2 = per_tensor_beta1[tensor_num];
-    MATH_T beta3 = per_tensor_beta1[tensor_num];
+    MATH_T beta2 = per_tensor_beta2[tensor_num];
+    MATH_T beta3 = 1 - beta1;
     MATH_T beta1_correction, beta2_correction;
     if (per_tensor_bias_correction[tensor_num] == 1) {
-      beta1_correction = 1 - pow(beta1, (MATH_T) step);
-      beta2_correction = 1 - pow(beta2, (MATH_T) step);
+      beta1_correction = 1 - pow(beta1, step);
+      beta2_correction = 1 - pow(beta2, step);
     } else {
       beta1_correction = (MATH_T) 1.0;
       beta2_correction = (MATH_T) 1.0;
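For readers skimming the hunk above: before this commit the kernel loaded beta2 and beta3 from per_tensor_beta1, so the second moment decayed with the wrong coefficient; beta2 now comes from per_tensor_beta2 and beta3 is derived as 1 - beta1. The gradient is also divided by a host-supplied grad_scale instead of a norm clipped inside the kernel. A per-element sketch of the corrected moment math, written in Python for readability (ILP vectorization, MATH_T precision handling and the AdamW decay path are deliberately elided):

    def lamb_moments_element(g, p, m, v, beta1, beta2, eps, step, decay,
                             grad_scale, l2_mode=True, bias_correction=True):
        """Sketch of the per-element math in DistOptLAMBStage1Functor after this commit."""
        beta3 = 1.0 - beta1            # fixed: previously also read from per_tensor_beta1
        scaled_grad = g / grad_scale   # fixed: previously divided by clipped_global_grad_norm
        if l2_mode:                    # MOMENT_MODE_0: L2 regularization folded into the gradient
            scaled_grad = scaled_grad + decay * p
        m = beta1 * m + beta3 * scaled_grad
        v = beta2 * v + (1.0 - beta2) * scaled_grad * scaled_grad  # beta2 now really is betas[1]
        c1 = 1.0 - beta1 ** step if bias_correction else 1.0
        c2 = 1.0 - beta2 ** step if bias_correction else 1.0
        return m, v, (m / c1) / ((v / c2) ** 0.5 + eps)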
@@ -207,7 +204,7 @@ struct DistOptLAMBStage1Functor
         for(int ii = 0; ii < ILP; ii++)
         {
           if (mode == MOMENT_MODE_0) {
-            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
+            MATH_T scaled_grad = r_g[ii] / grad_scale;
             // L2 on scaled grad
             scaled_grad = scaled_grad + decay*r_p[ii];
             r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
@@ -218,7 +215,7 @@
             r_p[ii] = next_m_unbiased / denom;
           }
           else {
-            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
+            MATH_T scaled_grad = r_g[ii] / grad_scale;
             r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
             r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
             MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
@@ -277,7 +274,7 @@ struct DistOptLAMBStage1Functor
         for(int ii = 0; ii < ILP; ii++)
         {
           if (mode == MOMENT_MODE_0) {
-            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
+            MATH_T scaled_grad = r_g[ii] / grad_scale;
             // L2 on scaled grad
             scaled_grad = scaled_grad + decay*r_p[ii];
             r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
@@ -288,7 +285,7 @@
             r_p[ii] = next_m_unbiased / denom;
           }
           else {
-            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
+            MATH_T scaled_grad = r_g[ii] / grad_scale;
             r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
             r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
             MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
@@ -346,7 +343,7 @@ struct DistOptLAMBStage2Functor
       {
         MATH_T param_norm = per_tensor_param_norm[tensor_num];
         MATH_T update_norm = per_tensor_update_norm[tensor_num];
-        ratio = (update_norm != (MATH_T) 0.0 && param_norm != (MATH_T) 0.0) ? learning_rate * (param_norm / update_norm) : learning_rate;
+        ratio = (update_norm != 0.0 && param_norm != 0.0) ? learning_rate * (param_norm / update_norm) : learning_rate;
       }
 
       MATH_T* update = (MATH_T*)tl.addresses[0][tensor_loc];
@@ -434,8 +431,7 @@ void multi_tensor_lamb_compute_update_term_cuda(
   at::Tensor per_tensor_epsilon,
   const int mode,
   at::Tensor per_tensor_decay,
-  const float global_grad_norm,
-  const float max_global_grad_norm)
+  const float grad_scale)
 {
   using namespace at;
 
@@ -456,8 +452,7 @@
           per_tensor_epsilon.DATA_PTR<scalar_t_2>(),
           (adamMode_t) mode,
           per_tensor_decay.DATA_PTR<scalar_t_2>(),
-          (scalar_t_2) global_grad_norm,
-          (scalar_t_2) max_global_grad_norm); )))
+          grad_scale); )))
 
   AT_CUDA_CHECK(cudaGetLastError());
 }
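Taken together, these kernel hunks replace the in-kernel clipping term (clipped_global_grad_norm = global_grad_norm / max_global_grad_norm when the norm exceeds the limit, else 1) with a single precomputed divisor, grad_scale, so the host can fold the AMP loss scale and the clipping factor into one value. A hedged sketch of the divisor the new Python code computes, mirroring the _pipeline_step change shown further down (the 1e-6 term matches that code):

    import math

    def combined_scale(global_scale, raw_grad_norm, max_grad_norm, clip=True):
        """Sketch: one divisor that both unscales AMP gradients and clips by global norm.

        raw_grad_norm is the L2 norm of the still-loss-scaled gradients; dividing a raw
        gradient by the returned value is equivalent to unscaling by global_scale and
        then clipping to max_grad_norm, which the kernel used to do internally.
        """
        scale = global_scale
        if clip and max_grad_norm > 0 and math.isfinite(raw_grad_norm):
            clip_ratio = max_grad_norm / (raw_grad_norm / global_scale + 1e-6)
            scale = global_scale / min(1.0, clip_ratio)
        return scale

    # Example: loss scale 1024, unscaled grad norm 5.0, max_grad_norm 1.0:
    # combined_scale(1024.0, 5.0 * 1024.0, 1.0) ~= 5120.0, so g_raw / 5120 == (g_raw / 1024) / 5.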
36 changes: 31 additions & 5 deletions apex/contrib/optimizers/distributed_fused_lamb.py
@@ -56,6 +56,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             (default: 1.0)
         use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
             weight decay parameter (default: False)
+        clip_grad_norm (boolean, optional): whether to handle gradient clipping
+            (default: True)
 
     .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
         https://arxiv.org/abs/1904.00962
@@ -67,7 +69,7 @@ def __init__(self, params,
                  lr=1e-3, bias_correction = True, grad_averaging=True,
                  betas=(0.9, 0.999), eps=1e-8,
                  weight_decay=0., max_grad_norm=0.,
-                 adam_w_mode=True, use_nvlamb=False,
+                 adam_w_mode=True, use_nvlamb=False, clip_grad_norm=True,
                  amp_scale_adjustment=1.0, overlap_reductions=True,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                  dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
@@ -89,6 +91,7 @@ def __init__(self, params,
             'max_grad_norm': max_grad_norm,
             'adam_w_mode': adam_w_mode,
             'use_nvlamb': use_nvlamb,
+            'clip_grad_norm': clip_grad_norm,
             'amp_scale_adjustment': amp_scale_adjustment,
             'overlap_reductions': overlap_reductions,
             'dwu_group_size': dwu_group_size,
@@ -107,7 +110,7 @@ def __first_step_init__(self,
                  lr=1e-3, bias_correction = True, grad_averaging=True,
                  betas=(0.9, 0.999), eps=1e-8,
                  weight_decay=0., max_grad_norm=0.,
-                 adam_w_mode=True, use_nvlamb=False,
+                 adam_w_mode=True, use_nvlamb=False, clip_grad_norm=True,
                  amp_scale_adjustment=1.0, overlap_reductions=True,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                  dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
@@ -127,9 +130,11 @@ def __first_step_init__(self,
 
         self._adam_w_mode = 1 if adam_w_mode else 0
         self._use_nvlamb = use_nvlamb
+        self._clip_grad_norm = clip_grad_norm
         self._is_accumulation_step = False
         self._last_step = False
         self._overlap_reductions = overlap_reductions
+        self._global_scale = None
         self._num_blocks = dwu_num_blocks
         self._num_chunks = dwu_num_chunks
         self._e5m2_allgather = e5m2_allgather
@@ -468,17 +473,30 @@ def __compute_contrib_update_norm(self):
         local_contrib_l2_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_update_frag_for_norm], True)[1] ** 2
         l2_norm.masked_scatter_(self._model_param_is_contrib, local_contrib_l2_norm)
         torch.distributed.all_reduce(l2_norm, group=self._ag_pg[0])
+        l2_norm = torch.sqrt(l2_norm)
         return l2_norm.masked_select(self._model_param_is_contrib)
 
     def _pipeline_step(self):
+        # If self._clip_grad_norm is False, we assume gradient clipping already
+        # happened outside the optimizer and self._global_scale has already
+        # been set to the combined scale, i.e. it's no longer the current loss
+        # scale used by the loss scaler.
+        # For model parallelism cases in which we need to get global gradient
+        # norm via all-reduce outside the optimizer to do the clipping.
+        combined_scale = self.global_scale
+        max_grad_norm = self.defaults['max_grad_norm']
+        global_grad_norm = self.L2_grad_norm
+        if self._clip_grad_norm and max_grad_norm > 0 and math.isfinite(global_grad_norm):
+            combined_scale = max_grad_norm / (global_grad_norm / self.global_scale + 1e-6)
+            combined_scale = self.global_scale / min(1, combined_scale)
+
         # Call step kernel once per step
         # Call all-gather once per step
         with torch.cuda.stream(self._completion_st):
             for block_id in range(self._num_blocks):
                 for chunk_id in range(self._num_chunks):
                     self._reductions_works[block_id][chunk_id].wait()
             param_norm = self.__compute_contrib_param_norm()
-            max_grad_norm = self.defaults['max_grad_norm']
             multi_tensor_applier(self.multi_tensor_lamb_compute_update_term,
                                  self._overflow_buf,
                                  self._contrib_compute_update_term_tensor_list, # g, p, m, v, u
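The comment block added at the top of _pipeline_step describes the clip_grad_norm=False escape hatch: when clipping has to happen outside the optimizer (e.g. a model-parallel setup where the true global norm needs an extra all-reduce), the caller folds clipping and loss scale into one value and hands it over via set_global_scale, and the optimizer then only divides by it. A hedged sketch of such a caller; the model-parallel process group, the way the norm pieces are combined, and the helper name are assumptions about the surrounding training code, not part of this diff:

    import math
    import torch
    import torch.distributed as dist

    def set_combined_scale(optimizer, loss_scale, max_grad_norm, mp_group):
        # Assumes optimizer = DistributedFusedLAMB(..., clip_grad_norm=False) and that
        # summing squared per-rank norms over mp_group yields the true global norm
        # (whether that holds depends on how parameters are sharded).
        local_sq = float(optimizer.L2_grad_norm) ** 2
        sq = torch.tensor([local_sq], device='cuda')
        dist.all_reduce(sq, group=mp_group)          # the "all-reduce outside the optimizer"
        global_grad_norm = sq.sqrt().item()          # norm of the still-loss-scaled gradients
        combined = loss_scale
        if max_grad_norm > 0 and math.isfinite(global_grad_norm):
            clip_ratio = max_grad_norm / (global_grad_norm / loss_scale + 1e-6)
            combined = loss_scale / min(1.0, clip_ratio)
        optimizer.set_global_scale(combined)         # optimizer then skips its own clipping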
@@ -490,8 +508,7 @@ def _pipeline_step(self):
                                  self._contrib_epsilon,
                                  self._adam_w_mode,
                                  self._contrib_weight_decay,
-                                 self.L2_grad_norm,
-                                 max_grad_norm)
+                                 combined_scale)
             upd_norm = self.__compute_contrib_update_norm()
             multi_tensor_applier(self.multi_tensor_lamb_update_weights,
                                  self._overflow_buf,
@@ -537,6 +554,15 @@ def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, para
                     self._pipeline_block_reductions(block_id)
                 flush_block = self._get_flush_block()
 
+    def set_global_scale(self, global_scale):
+        """Set global scale.
+        """
+        self._global_scale = global_scale
+
+    @property
+    def global_scale(self):
+        return self._global_scale
+
     @property
     def L2_grad_norm(self):
         torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
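With clip_grad_norm left at its default of True, the only new obligation on the training loop is to pass the current loss scale to the optimizer each iteration via the new set_global_scale (1.0 if no loss scaling is used), so the fused kernel can unscale and clip in a single pass. A minimal hedged sketch of that loop; the model, the data, and the loss-scaler object with a .loss_scale attribute are stand-ins for whatever the surrounding training script actually uses:

    from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB

    # Assumes torch.distributed is already initialized and `model`, `loss_fn`,
    # `batches` and `scaler` (with a .loss_scale float) exist in the script.
    optimizer = DistributedFusedLAMB(model.parameters(), lr=6e-3,
                                     max_grad_norm=1.0,      # used since clip_grad_norm=True
                                     clip_grad_norm=True)

    for inputs, targets in batches:
        loss = loss_fn(model(inputs), targets) * scaler.loss_scale  # manually scaled loss
        loss.backward()
        optimizer.set_global_scale(scaler.loss_scale)  # new in this commit
        optimizer.step()
        optimizer.zero_grad()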
