From 17a3c685b056922c03113b32fa3b32293affbb1b Mon Sep 17 00:00:00 2001 From: HELSON Date: Wed, 30 Nov 2022 10:40:31 +0800 Subject: [PATCH] [zero] fix unit-tests (#2039) --- tests/components_to_test/utils/executor.py | 7 +-- tests/test_gemini/test_mem_tracer.py | 2 +- tests/test_gemini/update/test_fwd_bwd.py | 10 +++- tests/test_gemini/update/test_optim.py | 69 +++++++++++----------- 4 files changed, 44 insertions(+), 44 deletions(-) diff --git a/tests/components_to_test/utils/executor.py b/tests/components_to_test/utils/executor.py index 0bb98f2775ce..e77152561e6c 100644 --- a/tests/components_to_test/utils/executor.py +++ b/tests/components_to_test/utils/executor.py @@ -1,7 +1,7 @@ import torch -def run_fwd_bwd(model, data, label, criterion, use_init_ctx=False) -> torch.Tensor: +def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor: """run_fwd_bwd run fwd and bwd for the model @@ -10,7 +10,6 @@ def run_fwd_bwd(model, data, label, criterion, use_init_ctx=False) -> torch.Tens data (torch.Tensor): input data label (torch.Tensor): label criterion (Optional[Callable]): a function of criterion - use_init_ctx (bool, optional): whether the model is initialized under the contxt of ColoInitCtx. Defaults to False. Returns: torch.Tensor: loss of fwd @@ -23,8 +22,8 @@ def run_fwd_bwd(model, data, label, criterion, use_init_ctx=False) -> torch.Tens loss = model(data, label) loss = loss.float() - if use_init_ctx: - model.backward(loss) + if optimizer: + optimizer.backward(loss) else: loss.backward() return loss diff --git a/tests/test_gemini/test_mem_tracer.py b/tests/test_gemini/test_mem_tracer.py index af4abc1ecf13..cb95cc7831bd 100644 --- a/tests/test_gemini/test_mem_tracer.py +++ b/tests/test_gemini/test_mem_tracer.py @@ -33,7 +33,7 @@ def run_tracer(rank, world_size, port, use_grad_check=True): data = data.cuda() label = label.cuda() - run_fwd_bwd(model, data, label, criterion, use_init_ctx=False) + run_fwd_bwd(model, data, label, criterion) model._ophook_list[0].print_non_model_data() diff --git a/tests/test_gemini/update/test_fwd_bwd.py b/tests/test_gemini/update/test_fwd_bwd.py index ef2e59e43902..b57f603ef21f 100644 --- a/tests/test_gemini/update/test_fwd_bwd.py +++ b/tests/test_gemini/update/test_fwd_bwd.py @@ -10,6 +10,8 @@ from colossalai.amp import convert_to_apex_amp from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer from colossalai.nn.parallel import ZeroDDP from colossalai.tensor import ProcessGroup from colossalai.testing import parameterize, rerun_if_address_is_in_use @@ -55,6 +57,8 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_ch chunk_manager = ChunkManager(config_dict) gemini_manager = GeminiManager(placement_policy, chunk_manager) model = ZeroDDP(model, gemini_manager, pin_memory=True) + optimizer = HybridAdam(model.parameters(), lr=1e-3) + zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1) pg = ProcessGroup() amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) @@ -71,9 +75,9 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_ch # after bwd param is grad for Gemini, due to the chunk reuse optimization. if i > 0: break - - torch_loss = run_fwd_bwd(torch_model, input_ids.cuda(), label.cuda(), criterion, use_init_ctx=False) - loss = run_fwd_bwd(model, input_ids.cuda(), label.cuda(), criterion, use_init_ctx=True) + input_ids, label = input_ids.cuda(), label.cuda() + torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) + loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) assert torch.equal(torch_loss, loss) diff --git a/tests/test_gemini/update/test_optim.py b/tests/test_gemini/update/test_optim.py index ec6299a3c7a0..89b9b433be70 100644 --- a/tests/test_gemini/update/test_optim.py +++ b/tests/test_gemini/update/test_optim.py @@ -6,6 +6,7 @@ import torch.distributed as dist import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close import colossalai from colossalai.amp import convert_to_apex_amp @@ -20,7 +21,7 @@ from colossalai.utils.model.colo_init_context import ColoInitContext from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import set_seed +from tests.test_tensor.common_utils import debug_print, set_seed def check_param(model: ZeroDDP, torch_model: torch.nn.Module): @@ -35,27 +36,31 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module): assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) - assert torch.allclose(value, temp_zero_value, rtol=1e-3, atol=1e-2), "parameter '{}' has problem.".format(key) + assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-2) # 'gpt2', 'bert', TEST_MODELS = ['gpt2', 'bert'] -# TEST_MODELS = ['simple_net'] +EXAMPLE_MODELS = ['simple_net'] -@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) +@parameterize('placement_policy', ['cuda']) @parameterize('model_name', TEST_MODELS) def exam_model_step(placement_policy, model_name: str): set_seed(42) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + torch_model = model_builder().cuda() + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=128) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) + torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) + with ColoInitContext(device=get_current_device()): model = model_builder() - - torch_model = model_builder().cuda() for torch_p, p in zip(torch_model.parameters(), model.parameters()): - torch_p.data.copy_(p.data) + p.data.copy_(torch_p.data) world_size = torch.distributed.get_world_size() config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) @@ -70,12 +75,7 @@ def exam_model_step(placement_policy, model_name: str): model = ZeroDDP(model, gemini_manager, pin_memory=True) optimizer = HybridAdam(model.parameters(), lr=1e-3) - zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2) - - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) - torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) - torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) - torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) + zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128) model.eval() torch_model.eval() @@ -84,15 +84,13 @@ def exam_model_step(placement_policy, model_name: str): for i, (input_ids, label) in enumerate(train_dataloader): if i > 2: break - + input_ids, label = input_ids.cuda(), label.cuda() zero_optim.zero_grad() torch_optim.zero_grad() - torch_loss = run_fwd_bwd(torch_model, input_ids.cuda(), label.cuda(), criterion, use_init_ctx=False) - loss = run_fwd_bwd(model, input_ids.cuda(), label.cuda(), criterion, use_init_ctx=True) - - assert torch.allclose(torch_loss, loss, rtol=1e-3, atol=1e-2), f"{torch_loss} vs {loss}" - # debug_print([0], zero_logits, torch_logits) + torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) + loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) + assert_close(torch_loss, loss) zero_optim.step() torch_optim.step() @@ -101,31 +99,29 @@ def exam_model_step(placement_policy, model_name: str): @parameterize('placement_policy', ['cuda', 'cpu']) -@parameterize('model_name', TEST_MODELS) +@parameterize('model_name', EXAMPLE_MODELS) def exam_tiny_example(placement_policy, model_name: str): - set_seed(42) + set_seed(2008) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + torch_model = model_builder().cuda() + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=2) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) + torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) + with ColoInitContext(device=get_current_device()): model = model_builder() - - torch_model = model_builder().cuda() for torch_p, p in zip(torch_model.parameters(), model.parameters()): - torch_p.data.copy_(p.data) + p.data.copy_(torch_p.data) chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1) gemini_manager = GeminiManager(placement_policy, chunk_manager) model = ZeroDDP(model, gemini_manager, pin_memory=True) - optimizer = HybridAdam(model.parameters(), lr=1e-3) zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2) - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) - torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) - torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) - torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) - model.eval() torch_model.eval() @@ -134,14 +130,15 @@ def exam_tiny_example(placement_policy, model_name: str): if i > 2: break + input_ids = input_ids.cuda() + label = label.cuda() + zero_optim.zero_grad() torch_optim.zero_grad() - torch_loss = run_fwd_bwd(torch_model, input_ids.cuda(), label.cuda(), criterion, use_init_ctx=False) - loss = run_fwd_bwd(model, input_ids.cuda(), label.cuda(), criterion, use_init_ctx=True) - - assert torch.allclose(torch_loss, loss, rtol=1e-3, atol=1e-2), f"{torch_loss} vs {loss}" - # debug_print([0], zero_logits, torch_logits) + torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) + loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) + assert_close(torch_loss, loss) zero_optim.step() torch_optim.step() @@ -165,4 +162,4 @@ def test_optim(world_size): if __name__ == '__main__': - test_optim(2) + test_optim(1)