
Commit

Bug fix in update norm calculation
thorjohnsen committed May 30, 2020
1 parent fb2d0f4 commit 56650eb
Showing 1 changed file with 6 additions and 5 deletions.
apex/contrib/optimizers/distributed_fused_lamb.py: 6 additions & 5 deletions
@@ -245,6 +245,7 @@ def __packed_chunkify(p):
         self._contrib_model_param_for_norm_fp16 = []
         self._contrib_model_param_for_norm_fp32 = []
         self._contrib_model_param_for_norm_is_fp16 = []
+        self._model_param_is_contrib = [False]*self._model_params_num
         self._contrib_group_properties = []
         for shard_id in range(self._group_size):
             for block_id in range(self._num_blocks):
@@ -267,6 +268,7 @@ def __packed_chunkify(p):
                     else:
                         self._packed_flat_to_model_params_fp32.append( (new_param_packed_fragment, model_param_fragment) )
                     if shard_id == self._rank_in_group:
+                        self._model_param_is_contrib[param_i] = True
                         # copy model parameters into master buffer
                         master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                         opti_state_m_fragment = self._fp32_m_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
@@ -427,8 +429,8 @@ def __compute_contrib_param_norm(self):
             gnorm_fp16 = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp16], True)[1]
             gnorm_fp32 = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp32], True)[1]
             gnorm = torch.empty(size=[self._contrib_model_param_for_norm_num], dtype=torch.bool, device='cuda')
-            gnorm.masked_scatter(self._contrib_model_param_for_norm_is_fp16, gnorm_fp16)
-            gnorm.masked_scatter(self._contrib_model_param_for_norm_is_fp32, gnorm_fp32)
+            gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp16, gnorm_fp16)
+            gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp32, gnorm_fp32)
         elif self._contrib_model_param_for_norm_fp16 is not None:
             gnorm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp16], True)[1]
         elif self._contrib_model_param_for_norm_fp32 is not None:
@@ -438,10 +440,9 @@ def __compute_contrib_update_norm(self):
         l2_norm = torch.zeros(size=[self._model_params_num], dtype=torch.float32, device='cuda')
         local_contrib_l2_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_update_frag_for_norm], True)[1] ** 2
-        contrib_l2_norm = l2_norm[self._contrib_min_param_i:self._contrib_max_param_i+1]
-        contrib_l2_norm.copy_(local_contrib_l2_norm)
+        l2_norm.masked_scatter_(self._model_param_is_contrib, local_contrib_l2_norm)
         torch.distributed.all_reduce(l2_norm, group=self._ag_pg[0])
-        return l2_norm
+        return l2_norm.masked_select(self._model_param_is_contrib)

     def _pipeline_step(self):
         # Call step kernel once per step
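Below is a minimal, self-contained sketch of the mask-based bookkeeping this commit introduces. The old code assumed the parameters contributed by a rank occupied the contiguous index range self._contrib_min_param_i..self._contrib_max_param_i and copied the local norms into that slice; the new code scatters them through a per-parameter boolean mask and returns only the contributed entries with masked_select. The sizes, values, and CPU placement below are invented for illustration; the real optimizer feeds CUDA tensors produced by multi_tensor_l2norm and sums l2_norm across ranks with torch.distributed.all_reduce.

    import torch

    num_params = 6                                                   # total model parameters
    is_contrib = torch.tensor([0, 1, 1, 0, 1, 0], dtype=torch.bool)  # parameters owned by this rank
    local_sq_norms = torch.tensor([4.0, 9.0, 16.0])                  # squared L2 norms of the owned update fragments

    # Scatter the local values into a full-size buffer so every rank reduces over
    # the same layout, regardless of which (possibly non-contiguous) parameters it owns.
    l2_norm = torch.zeros(num_params, dtype=torch.float32)
    l2_norm.masked_scatter_(is_contrib, local_sq_norms)

    # In the real code the buffer is summed across the process group here:
    # torch.distributed.all_reduce(l2_norm, group=...)

    # Keep only the entries this rank contributed, in parameter order.
    print(l2_norm.masked_select(is_contrib))  # tensor([ 4.,  9., 16.])

The same in-place versus out-of-place distinction drives the change in __compute_contrib_param_norm: Tensor.masked_scatter returns a new tensor and leaves gnorm untouched, so the fp16 and fp32 norms were previously discarded, while masked_scatter_ writes them into gnorm directly.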
