Skip to content

Commit

Permalink
use prescaling for collective (NVIDIA#1157)
Browse files Browse the repository at this point in the history
  • Loading branch information
seryilmaz authored Sep 2, 2021
1 parent 1cb9c5c commit 2e98baa
Showing 1 changed file with 46 additions and 4 deletions.
50 changes: 46 additions & 4 deletions apex/contrib/optimizers/distributed_fused_lamb.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(self, params,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0, fused_norm=False,
e5m2_allgather=False, verbose=False, clip_after_ar=True,
full_ar=False):
full_ar=False, fuse_scale=False):
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
grad_averaging=grad_averaging,
Expand Down Expand Up @@ -121,6 +121,7 @@ def __init__(self, params,
self._verbose = verbose
self._clip_after_ar = clip_after_ar
self._full_ar = full_ar
self._fuse_scale = fuse_scale
self._L2_grad_norm = None
self._fused_norm = fused_norm
self._current_process_group = c10d._get_default_group()
Expand Down Expand Up @@ -544,6 +545,17 @@ def _get_flush_block(self):

return flush_block

def _full_all_reduce_scale(self, block_id, scale):
works = [None]*self._num_chunks

for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
ar_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(ar_stream):
works[chunk_id] = torch.distributed.all_reduce(self._flat_grads_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True,op=torch.distributed.make_nccl_premul_sum((scale,)))
self._reductions_works[block_id] = works

def _full_all_reduce(self, block_id):
works = [None]*self._num_chunks

Expand All @@ -555,6 +567,29 @@ def _full_all_reduce(self, block_id):
works[chunk_id] = torch.distributed.all_reduce(self._flat_grads_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works

def _reduce_scatter_and_all_reduce_scale(self, block_id, scale):
# Reduction within each node
# Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
# The output format is the same as the fp32 master parameters
works = [None]*self._num_chunks
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
rs_stream.wait_stream(torch.cuda.current_stream())
rs_stream.wait_stream(self._l2_grad_norm_st)
with torch.cuda.stream(rs_stream):
works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True,op=torch.distributed.make_nccl_premul_sum((scale,)))

# Reduction across nodes for each rank
if self._num_groups > 1:
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
with torch.cuda.stream(ar_stream):
works[chunk_id].wait()
works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works

def _reduce_scatter_and_all_reduce(self, block_id):
# Reduction within each node
# Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
Expand Down Expand Up @@ -620,12 +655,19 @@ def _pipeline_block_reductions(self, block_id):
tmp = torch.cat(((self._one), (coeff)))
index = (coeff+1>coeff).int()
scale = tmp.index_select(0, index).half()/self._world_size
self._flat_grads.mul_(scale)
if not self._fuse_scale:
self._flat_grads.mul_(scale)

if self._full_ar:
self._full_all_reduce(block_id)
if self._fuse_scale:
self._full_all_reduce_scale(block_id, scale)
else:
self._full_all_reduce(block_id)
else:
self._reduce_scatter_and_all_reduce(block_id)
if self._fuse_scale:
self._reduce_scatter_and_all_reduce_scale(block_id, scale)
else:
self._reduce_scatter_and_all_reduce(block_id)

if block_id == 0:
for block_id in range(self._num_blocks):
Expand Down

0 comments on commit 2e98baa

Please sign in to comment.