Delayed init
thorjohnsen committed May 31, 2020
1 parent 4c54fd2 commit 53cfd8c
Showing 1 changed file with 56 additions and 23 deletions.
79 changes: 56 additions & 23 deletions apex/contrib/optimizers/distributed_fused_lamb.py
@@ -65,30 +65,57 @@ class DistributedFusedLAMB(torch.optim.Optimizer):

def __init__(self, params,
lr=1e-3, bias_correction = True, grad_averaging=True,
betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
weight_decay=0., max_grad_norm=0., amsgrad=False,
adam_w_mode=True, use_nvlamb=False, use_mt=False,
betas=(0.9, 0.999), eps=1e-8,
weight_decay=0., max_grad_norm=0.,
adam_w_mode=True, use_nvlamb=False,
amp_scale_adjustment=1.0, overlap_reductions=True,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
e5m2_allgather=False):
global fused_adam_cuda, distributed_lamb_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
distributed_lamb_cuda = importlib.import_module("distributed_lamb_cuda")

self._amp_scale_adjustment = amp_scale_adjustment

if use_mt:
raise RuntimeError('DistributedFusedLAMB does not support use_mt.')
if amsgrad:
raise RuntimeError('DistributedFusedLAMB does not support the AMSGrad variant.')

defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
grad_averaging=grad_averaging,
max_grad_norm=max_grad_norm)
super(DistributedFusedLAMB, self).__init__(params, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1

self._init_args = {
'lr': lr,
'bias_correction': bias_correction,
'grad_averaging': grad_averaging,
'betas': betas,
'eps': eps,
'weight_decay': weight_decay,
'max_grad_norm': max_grad_norm,
'adam_w_mode': adam_w_mode,
'use_nvlamb': use_nvlamb,
'amp_scale_adjustment': amp_scale_adjustment,
'overlap_reductions': overlap_reductions,
'dwu_group_size': dwu_group_size,
'dwu_num_blocks': dwu_num_blocks,
'dwu_num_chunks': dwu_num_chunks,
'dwu_num_rs_pg': dwu_num_rs_pg,
'dwu_num_ar_pg': dwu_num_ar_pg,
'dwu_num_ag_pg': dwu_num_ag_pg,
'e5m2_allgather': e5m2_allgather}
self._init_done = False

import inspect
assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"

def __init_on_demand(self,
lr=1e-3, bias_correction = True, grad_averaging=True,
betas=(0.9, 0.999), eps=1e-8,
weight_decay=0., max_grad_norm=0.,
adam_w_mode=True, use_nvlamb=False,
amp_scale_adjustment=1.0, overlap_reductions=True,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
e5m2_allgather=False):
global fused_adam_cuda, distributed_lamb_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
distributed_lamb_cuda = importlib.import_module("distributed_lamb_cuda")

self._amp_scale_adjustment = amp_scale_adjustment

self._overflow_buf = torch.cuda.IntTensor([0])
self._has_overflow = False
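
The `no_copy` assert now performed in `__init__` is a capability check: the optimizer refuses to construct against a c10d build whose `reduce_scatter` does not accept the `no_copy` keyword. Below is a small, self-contained sketch of that introspection pattern; the `supports_kwarg` helper and the two stand-in collectives are hypothetical illustrations, not apex or torch APIs.

```python
import inspect

def supports_kwarg(fn, name):
    """True if fn's Python signature exposes a parameter called `name`."""
    spec = inspect.getfullargspec(fn)
    return name in spec.args or name in spec.kwonlyargs

# Toy stand-ins for a collective with and without the optional argument.
def reduce_scatter_old(output, input_list, group=None):
    pass

def reduce_scatter_new(output, input_list, group=None, no_copy=False):
    pass

assert not supports_kwarg(reduce_scatter_old, 'no_copy')
assert supports_kwarg(reduce_scatter_new, 'no_copy')
```
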
@@ -362,8 +389,10 @@ def __packed_chunkify(p):
self._reductions_works = [None]*self._num_blocks
self._allgather_works = [None]*self._num_blocks

import inspect
assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
def _first_time_init(self):
if not self._init_done:
self.__init_on_demand(**self._init_args)
self._init_done = True

def set_is_accumulation_step(self, is_accumulation_step):
self._is_accumulation_step = is_accumulation_step
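
The change above replaces eager construction with deferred construction: `__init__` now only records its arguments in `self._init_args`, and the expensive buffer and process-group setup in `__init_on_demand` runs once, the first time `_first_time_init` is reached. A minimal, standalone sketch of that pattern follows; the names `HeavyOptimizer` and `_expensive_setup` are hypothetical and this is not the actual apex implementation.

```python
class HeavyOptimizer:
    def __init__(self, params, lr=1e-3, dwu_num_blocks=4):
        # Cheap path: only stash constructor arguments and defer the real work.
        self._init_args = {'lr': lr, 'dwu_num_blocks': dwu_num_blocks}
        self._init_done = False
        self._params = list(params)

    def _expensive_setup(self, lr, dwu_num_blocks):
        # Placeholder for flattening parameters, allocating fused buffers,
        # creating process groups, etc.
        self._blocks = [None] * dwu_num_blocks
        self._lr = lr

    def _first_time_init(self):
        # Guard: run the deferred setup exactly once.
        if not self._init_done:
            self._expensive_setup(**self._init_args)
            self._init_done = True

    def step(self):
        self._first_time_init()
        # ... actual parameter update would go here ...
```
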
@@ -466,6 +495,7 @@ def _pipeline_step(self):
self.L2_grad_norm,
max_grad_norm)
upd_norm = self.__compute_contrib_update_norm()
print(self.L2_grad_norm,max_grad_norm,param_norm,upd_norm)
multi_tensor_applier(self.multi_tensor_lamb_update_weights,
self._overflow_buf,
self._contrib_update_weights_tensor_list, # u, p, p_copy
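
The hunk above collects the per-parameter weight norm (`param_norm`), the update norm (`upd_norm`), and the global gradient norm used for clipping before handing them to the fused weight-update kernel. In LAMB these norms determine the per-layer trust ratio that scales the learning rate. A minimal scalar sketch of that relationship, assuming the conventional LAMB/NVLAMB formulation rather than the exact kernel math:

```python
def lamb_trust_ratio(param_norm, update_norm):
    # Per-layer trust ratio: ||w|| / ||update||, falling back to 1.0
    # when either norm is zero (e.g. freshly initialized tensors).
    if param_norm > 0.0 and update_norm > 0.0:
        return param_norm / update_norm
    return 1.0

# effective_lr = lr * lamb_trust_ratio(param_norm, upd_norm)
```
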
@@ -482,19 +512,20 @@ def _flatten_grad_mt(self, scale):
multi_tensor_applier(
amp_C.multi_tensor_scale,
self._overflow_buf,
list(zip(self._grads_fp16)),
list(zip(*self._grads_fp16)),
scale)
self._grads_fp16 = []
if len(self._grads_fp32) > 0:
self._overflow_buf.zero_()
multi_tensor_applier(
amp_C.multi_tensor_scale,
self._overflow_buf,
list(zip(self._grads_fp32)),
list(zip(*self._grads_fp32)),
scale)
self._grads_fp32 = []
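
The `list(zip(self._grads_fp16))` to `list(zip(*self._grads_fp16))` change matters because `multi_tensor_applier(amp_C.multi_tensor_scale, ...)` expects a list of tensor lists (all sources, then all destinations), not a list of 1-tuples. A small sketch of the difference, using plain Python strings in place of tensors:

```python
# Each entry pairs a gradient with its destination buffer (placeholders here).
pairs = [('g0', 'buf0'), ('g1', 'buf1'), ('g2', 'buf2')]

# zip(pairs) wraps each pair in a 1-tuple -- the wrong shape for the kernel:
# [(('g0', 'buf0'),), (('g1', 'buf1'),), (('g2', 'buf2'),)]
wrong = list(zip(pairs))

# zip(*pairs) transposes the pairs into one list of sources and one of destinations:
# [('g0', 'g1', 'g2'), ('buf0', 'buf1', 'buf2')]
right = list(zip(*pairs))

print(wrong)
print(right)
```
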

def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
self._first_time_init()
if not self._is_accumulation_step:
# handle overlapped reductions
if param.dtype == torch.float16:
@@ -518,6 +549,7 @@ def complete_reductions(self):
"""Complete reductions if full pipeline is not selected or overlap is not allowed.
"""

self._first_time_init()
if self._last_step:
# zero out gradients that have not been completed yet
for param_i, grad_generated in enumerate(self._grads_generated):
@@ -545,10 +577,11 @@ def step(self, closure=None):

# assume the same step across all groups for now to simplify things
# a per-parameter step could easily be supported by making it a tensor, or by passing a list into the kernel
if 'step' in self._param_group:
self._param_group['step'] += 1
else:
self._param_group['step'] = 1
for param_group in self.param_groups:
if 'step' in param_group:
param_group['step'] += 1
else:
param_group['step'] = 1

self._pipeline_step()
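
The step counter is now incremented on every parameter group rather than on a single cached group. In Adam/LAMB-style optimizers this per-group count feeds the bias-correction factors `1 - beta1**step` and `1 - beta2**step`. A minimal sketch of how such a count is typically consumed when bias correction is enabled; the helper name and usage are illustrative, not the exact arguments the fused apex kernel receives.

```python
def bias_correction_factors(param_group):
    beta1, beta2 = param_group['betas']
    step = param_group['step']
    # Standard Adam-style bias corrections, applied before the update
    # (and, in LAMB, before the trust ratio is computed).
    return 1.0 - beta1 ** step, 1.0 - beta2 ** step

# Hypothetical usage after the per-group increment shown above:
# for group in optimizer.param_groups:
#     bc1, bc2 = bias_correction_factors(group)
```
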
