[zero] add warning for ignored parameters (hpcaitech#2446)
1SAA authored Jan 11, 2023
1 parent 3916341 commit 2bfeb24
Showing 2 changed files with 20 additions and 4 deletions.
colossalai/gemini/chunk/utils.py (9 changes: 7 additions & 2 deletions)
@@ -10,13 +10,18 @@
from colossalai.utils import is_ddp_ignored


+def safe_div(a, b):
+    if a == 0:
+        return 0
+    return a / b


def init_chunk_manager(model: nn.Module,
                       init_device: Optional[torch.device] = None,
                       hidden_dim: Optional[int] = None,
                       search_range_mb: Optional[float] = None,
                       min_chunk_size_mb: Optional[float] = None,
                       filter_exlarge_params: Optional[bool] = None) -> ChunkManager:

    kwargs_dict = dict()

    if hidden_dim:
@@ -50,7 +55,7 @@ def init_chunk_manager(model: nn.Module,
    if dist.get_rank() == 0:
        print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
              "used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
-             "total wasted percentage is {:.2f}%".format(100 * wasted_size / (total_size + wasted_size)),
+             "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
              sep='',
              flush=True)
    dist.barrier()
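For context, here is a minimal standalone sketch (not the library code) of what the new helper guards against: if the chunk search ends up tracking nothing, for instance because every parameter is ignored by DDP, both sizes can be zero and the old expression raised ZeroDivisionError while printing the summary, whereas safe_div reports 0 instead.

def safe_div(a, b):
    if a == 0:
        return 0
    return a / b

# With nothing tracked, both totals are zero.
total_size, wasted_size = 0, 0
# Old form: 100 * wasted_size / (total_size + wasted_size) -> ZeroDivisionError
print("total wasted percentage is {:.2f}%".format(
    100 * safe_div(wasted_size, total_size + wasted_size)))
# prints: total wasted percentage is 0.00%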
colossalai/nn/optimizer/zero_optimizer.py (15 changes: 13 additions & 2 deletions)
@@ -1,4 +1,5 @@
import math
+import warnings
from enum import Enum
from typing import Any, Dict, Set, Tuple

@@ -78,8 +79,16 @@ def __init__(self,
        if self.clipping_flag:
            assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"

-        params_list = [p for p in module.parameters() if not is_ddp_ignored(p)]
-        for p, fp32_p in zip(params_list, module.fp32_params):
+        ddp_param_list = []
+        for name, param in module.named_parameters():
+            if is_ddp_ignored(param):
+                if param.requires_grad:
+                    warnings.warn(f"Parameter `{name}` is ignored by DDP but requires gradient! "
+                                  "You should handle its optimizer update by yourself!")
+            else:
+                ddp_param_list.append(param)
+
+        for p, fp32_p in zip(ddp_param_list, module.fp32_params):
            chunk_16 = self.chunk_manager.get_chunk(p)
            if chunk_16 not in self.chunk16_set:
                chunk_16.l2_norm_flag = self.clipping_flag
@@ -290,6 +299,8 @@ def get_range_pair(local_chunk: Chunk, local_param: Parameter):
            fake_params_list = list()

            for param in group['params']:
+                if is_ddp_ignored(param):
+                    continue
                chunk16 = self.chunk_manager.get_chunk(param)
                range_pair = get_range_pair(chunk16, param)
                if range_pair[0] >= range_pair[1]:
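The sketch below shows when the new warning fires, using a toy model and a local stand-in for colossalai.utils.is_ddp_ignored (assumed here to read a _ddp_to_ignore flag on the parameter); none of these names besides the warning text come from the diff above. A parameter that DDP ignores but still requires gradients is reported and left out of the list the optimizer will manage.

import warnings

import torch
import torch.nn as nn


def is_ddp_ignored(param: nn.Parameter) -> bool:
    # Local stand-in for the helper imported from colossalai.utils;
    # assumed to read a boolean flag set on the parameter.
    return getattr(param, '_ddp_to_ignore', False)


class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.scale = nn.Parameter(torch.ones(4))
        self.scale._ddp_to_ignore = True    # pretend DDP skips this parameter


model = ToyModel()
ddp_param_list = []
for name, param in model.named_parameters():
    if is_ddp_ignored(param):
        if param.requires_grad:
            warnings.warn(f"Parameter `{name}` is ignored by DDP but requires gradient! "
                          "You should handle its optimizer update by yourself!")
    else:
        ddp_param_list.append(param)

# Only linear.weight and linear.bias remain; `scale` must be updated manually.
print(len(ddp_param_list))    # 2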
