# forked from apple/corenet — test_common_utils.py (48 lines / 37 loc, 1.47 KB)
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
import os
import torch
import torch.distributed as dist
from corenet.utils.common_utils import unwrap_model_fn
def check_models(
    original_unwrapped_model: torch.nn.Module, model_after_unwrapping: torch.nn.Module
) -> None:
    """Assert that the unwrapped model matches the original, layer by layer.

    Both models are expected to be indexable containers (e.g. ``nn.Sequential``);
    layers are compared by their ``repr``, which captures type and configuration.
    """
    for layer_idx, original_layer in enumerate(original_unwrapped_model):
        # An unwrapped model must remain indexable, and each layer must
        # describe itself identically to the corresponding original layer.
        assert repr(model_after_unwrapping[layer_idx]) == repr(original_layer)
def test_unwrap_model_fn():
    """Test that ``unwrap_model_fn`` recovers the original model from
    ``DataParallel`` and ``DistributedDataParallel`` wrappers.

    Layers are compared by ``repr`` via :func:`check_models`.
    """
    dummy_model = torch.nn.Sequential(
        torch.nn.Linear(10, 20),
        torch.nn.Linear(20, 40),
    )

    # test DataParallel wrapping
    wrapped_model_dp = torch.nn.DataParallel(dummy_model)
    unwrapped_model_dp = unwrap_model_fn(wrapped_model_dp)
    check_models(dummy_model, unwrapped_model_dp)

    # Initialize a single-process distributed environment so DDP can wrap
    # the model (gloo backend works on CPU-only hosts).
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "1234"
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    try:
        # test DDP wrapping
        wrapped_model_ddp = torch.nn.parallel.DistributedDataParallel(dummy_model)
        unwrapped_model_ddp = unwrap_model_fn(wrapped_model_ddp)
        check_models(dummy_model, unwrapped_model_ddp)
    finally:
        # Tear down the process group even if an assertion above fails;
        # otherwise a leaked group breaks any later test in this process
        # that re-initializes distributed state.
        dist.destroy_process_group()