From 2bfeb24308aa8c55e7a2c8ea42eb87a680618b50 Mon Sep 17 00:00:00 2001
From: HELSON
Date: Wed, 11 Jan 2023 15:30:09 +0800
Subject: [PATCH] [zero] add warning for ignored parameters (#2446)

---
 colossalai/gemini/chunk/utils.py          |  9 +++++++--
 colossalai/nn/optimizer/zero_optimizer.py | 15 +++++++++++++--
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/colossalai/gemini/chunk/utils.py b/colossalai/gemini/chunk/utils.py
index 883022fe89b8..ebfdee778979 100644
--- a/colossalai/gemini/chunk/utils.py
+++ b/colossalai/gemini/chunk/utils.py
@@ -10,13 +10,18 @@
 from colossalai.utils import is_ddp_ignored
 
 
+def safe_div(a, b):
+    if a == 0:
+        return 0
+    return a / b
+
+
 def init_chunk_manager(model: nn.Module,
                        init_device: Optional[torch.device] = None,
                        hidden_dim: Optional[int] = None,
                        search_range_mb: Optional[float] = None,
                        min_chunk_size_mb: Optional[float] = None,
                        filter_exlarge_params: Optional[bool] = None) -> ChunkManager:
-
     kwargs_dict = dict()
 
     if hidden_dim:
@@ -50,7 +55,7 @@ def init_chunk_manager(model: nn.Module,
     if dist.get_rank() == 0:
         print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
               "used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
-              "total wasted percentage is {:.2f}%".format(100 * wasted_size / (total_size + wasted_size)),
+              "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
               sep='',
               flush=True)
     dist.barrier()
diff --git a/colossalai/nn/optimizer/zero_optimizer.py b/colossalai/nn/optimizer/zero_optimizer.py
index 3dd9d1e93b36..9f761efdb12c 100644
--- a/colossalai/nn/optimizer/zero_optimizer.py
+++ b/colossalai/nn/optimizer/zero_optimizer.py
@@ -1,4 +1,5 @@
 import math
+import warnings
 from enum import Enum
 from typing import Any, Dict, Set, Tuple
 
@@ -78,8 +79,16 @@ def __init__(self,
         if self.clipping_flag:
             assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"
 
-        params_list = [p for p in module.parameters() if not is_ddp_ignored(p)]
-        for p, fp32_p in zip(params_list, module.fp32_params):
+        ddp_param_list = []
+        for name, param in module.named_parameters():
+            if is_ddp_ignored(param):
+                if param.requires_grad:
+                    warnings.warn(f"Parameter `{name}` is ignored by DDP but requires gradient! "
+                                  "You should handle its optimizer update by yourself!")
+            else:
+                ddp_param_list.append(param)
+
+        for p, fp32_p in zip(ddp_param_list, module.fp32_params):
             chunk_16 = self.chunk_manager.get_chunk(p)
             if chunk_16 not in self.chunk16_set:
                 chunk_16.l2_norm_flag = self.clipping_flag
@@ -290,6 +299,8 @@ def get_range_pair(local_chunk: Chunk, local_param: Parameter):
 
         fake_params_list = list()
         for param in group['params']:
+            if is_ddp_ignored(param):
+                continue
             chunk16 = self.chunk_manager.get_chunk(param)
             range_pair = get_range_pair(chunk16, param)
             if range_pair[0] >= range_pair[1]:
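
For readers skimming the patch, below is a minimal, self-contained sketch of the behavior the zero_optimizer.py hunk introduces: DDP-ignored parameters are skipped by the optimizer, and a warning is raised when such a parameter still requires gradients. The is_ddp_ignored stand-in (reading a hypothetical _ddp_to_ignore flag) and the collect_ddp_params helper are simplifications for illustration, not ColossalAI's actual API.

    import warnings

    import torch
    import torch.nn as nn


    def is_ddp_ignored(p: torch.Tensor) -> bool:
        # Hypothetical stand-in: assumes a `_ddp_to_ignore` attribute set by
        # the DDP wrapper; the real check lives in colossalai.utils.
        return getattr(p, "_ddp_to_ignore", False)


    def collect_ddp_params(module: nn.Module):
        # Mirrors the patched loop: keep DDP-managed parameters, and warn when
        # an ignored parameter still requires gradients, since its optimizer
        # update would otherwise be silently skipped.
        ddp_param_list = []
        for name, param in module.named_parameters():
            if is_ddp_ignored(param):
                if param.requires_grad:
                    warnings.warn(f"Parameter `{name}` is ignored by DDP but requires gradient! "
                                  "You should handle its optimizer update by yourself!")
            else:
                ddp_param_list.append(param)
        return ddp_param_list


    if __name__ == "__main__":
        model = nn.Linear(4, 4)
        model.bias._ddp_to_ignore = True      # mark `bias` as DDP-ignored
        params = collect_ddp_params(model)    # emits a warning for `bias`
        print(len(params))                    # 1: only `weight` remains

The design point: ignoring a parameter for DDP is only safe when nothing expects the optimizer to update it, so the patch surfaces a warning rather than an assertion, keeping intentional use cases working while making accidental ones visible.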