forked from facebookresearch/detectron2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_checkpoint.py
48 lines (39 loc) · 1.62 KB
/
test_checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import unittest
from collections import OrderedDict
import torch
from torch import nn
from detectron2.checkpoint.c2_model_loading import align_and_update_state_dicts
from detectron2.utils.logger import setup_logger
class TestCheckpointer(unittest.TestCase):
def setUp(self):
setup_logger()
def create_complex_model(self):
m = nn.Module()
m.block1 = nn.Module()
m.block1.layer1 = nn.Linear(2, 3)
m.layer2 = nn.Linear(3, 2)
m.res = nn.Module()
m.res.layer2 = nn.Linear(3, 2)
state_dict = OrderedDict()
state_dict["layer1.weight"] = torch.rand(3, 2)
state_dict["layer1.bias"] = torch.rand(3)
state_dict["layer2.weight"] = torch.rand(2, 3)
state_dict["layer2.bias"] = torch.rand(2)
state_dict["res.layer2.weight"] = torch.rand(2, 3)
state_dict["res.layer2.bias"] = torch.rand(2)
return m, state_dict
def test_complex_model_loaded(self):
for add_data_parallel in [False, True]:
model, state_dict = self.create_complex_model()
if add_data_parallel:
model = nn.DataParallel(model)
model_sd = model.state_dict()
align_and_update_state_dicts(model_sd, state_dict)
for loaded, stored in zip(model_sd.values(), state_dict.values()):
# different tensor references
self.assertFalse(id(loaded) == id(stored))
# same content
self.assertTrue(loaded.equal(stored))
if __name__ == "__main__":
unittest.main()