Skip to content

Commit

Permalink
[zero] fix state_dict and load_state_dict for ddp ignored parameters (h…
Browse files Browse the repository at this point in the history
…pcaitech#2443)

* [ddp] add is_ddp_ignored

[ddp] rename to is_ddp_ignored

* [zero] fix state_dict and load_state_dict

* fix bugs

* [zero] update unit test for ZeroDDP
  • Loading branch information
1SAA authored Jan 11, 2023
1 parent 2731531 commit 5521af7
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 6 deletions.
22 changes: 18 additions & 4 deletions colossalai/nn/parallel/data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def __init__(self,
assert isinstance(p, ColoParameter)

if is_ddp_ignored(p):
p.data = p.data.half()
p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
continue

fp32_data = p.data.float()
Expand Down Expand Up @@ -451,8 +451,14 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."

param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
# TODO: (HELSON) deal with ddp ignored parameters
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
ddp_param_list = []
for name, param in self.named_parameters():
if is_ddp_ignored(param):
# deal with ddp ignored parameters
destination[prefix + name] = param if keep_vars else param.detach()
else:
ddp_param_list.append((name, param))
for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params):
if p is not None:
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
record_parameter = param_to_save_data[fp32_p]
Expand Down Expand Up @@ -588,8 +594,16 @@ def load(param_name, dest_tensor, copy_func):
def load_fp32_parameter(chunk_slice, data):
chunk_slice.copy_(data.flatten())

ddp_param_list = []
for name, param in self.named_parameters():
if is_ddp_ignored(param):
# deal with ddp ignored parameters
load(name, param, param.copy_)
else:
ddp_param_list.append((name, param))

fp32_to_name = dict()
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params):
if p is not None:
fp32_to_name[fp32_p] = name

Expand Down
12 changes: 10 additions & 2 deletions tests/test_gemini/update/test_zeroddp_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.testing import assert_close

import colossalai
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
Expand All @@ -17,6 +18,13 @@
from tests.test_tensor.common_utils import debug_print, set_seed


def ignore_the_first_parameter(model: torch.nn.Module):
for name, param in model.named_parameters():
print(f"parameter `{name}` is set ignored")
ZeroDDP.set_params_to_ignore([param])
return


@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
@parameterize('model_name', ['gpt2', 'bert'])
Expand Down Expand Up @@ -47,7 +55,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
for key, value in torch_dict.items():
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)


@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
Expand Down Expand Up @@ -84,7 +92,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
for key, value in torch_dict.items():
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)


def run_dist(rank, world_size, port):
Expand Down

0 comments on commit 5521af7

Please sign in to comment.