forked from tianzhi0549/FCOS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcheckpoint.py
118 lines (100 loc) · 4.54 KB
/
checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from collections import OrderedDict
import os
from tempfile import TemporaryDirectory
import unittest
import torch
from torch import nn
from maskrcnn_benchmark.utils.model_serialization import load_state_dict
from maskrcnn_benchmark.utils.checkpoint import Checkpointer
class TestCheckpointer(unittest.TestCase):
    """Round-trip tests for ``Checkpointer`` and ``load_state_dict``.

    Every scenario is exercised with and without an ``nn.DataParallel``
    wrapper on either side, so the saved and the loaded parameter names may
    differ by a prefix.
    """

    def create_model(self):
        """Return a small two-layer ``nn.Sequential`` model."""
        return nn.Sequential(nn.Linear(2, 3), nn.Linear(3, 1))

    def create_complex_model(self):
        """Return a nested module and a state dict to load into it.

        The state dict holds freshly drawn random tensors whose shapes match
        the module's linear layers. Its first layer is keyed ``layer1.*``
        while the module nests that layer under ``block1``, so the keys do
        not match the module's own names exactly.

        Returns:
            A ``(module, state_dict)`` tuple.
        """
        m = nn.Module()
        m.block1 = nn.Module()
        m.block1.layer1 = nn.Linear(2, 3)
        m.layer2 = nn.Linear(3, 2)
        m.res = nn.Module()
        m.res.layer2 = nn.Linear(3, 2)

        state_dict = OrderedDict()
        state_dict["layer1.weight"] = torch.rand(3, 2)
        state_dict["layer1.bias"] = torch.rand(3)
        state_dict["layer2.weight"] = torch.rand(2, 3)
        state_dict["layer2.bias"] = torch.rand(2)
        state_dict["res.layer2.weight"] = torch.rand(2, 3)
        state_dict["res.layer2.bias"] = torch.rand(2)

        return m, state_dict

    def _model_pairs(self):
        """Return labelled ``(label, trained, fresh)`` model triples.

        The four entries cover every combination of plain and
        ``nn.DataParallel``-wrapped models on the saving and loading side.
        """
        return [
            ("plain -> plain", self.create_model(), self.create_model()),
            (
                "parallel -> plain",
                nn.DataParallel(self.create_model()),
                self.create_model(),
            ),
            (
                "plain -> parallel",
                self.create_model(),
                nn.DataParallel(self.create_model()),
            ),
            (
                "parallel -> parallel",
                nn.DataParallel(self.create_model()),
                nn.DataParallel(self.create_model()),
            ),
        ]

    def _assert_copied(self, first, second):
        """Assert two tensor iterables match pairwise without sharing objects.

        Args:
            first: iterable of tensors.
            second: iterable of tensors expected to hold the same values.
        """
        first = list(first)
        second = list(second)
        # zip() stops at the shorter input, so compare the lengths first;
        # otherwise a missing tensor would let the loop pass vacuously.
        self.assertEqual(len(first), len(second))
        for left, right in zip(first, second):
            # different tensor references
            self.assertIsNot(left, right)
            # same content
            self.assertTrue(left.equal(right))

    def test_from_last_checkpoint_model(self):
        """A fresh checkpointer on the same folder finds and loads the save."""
        # test that loading works even if they differ by a prefix
        for label, trained_model, fresh_model in self._model_pairs():
            with self.subTest(models=label):
                with TemporaryDirectory() as f:
                    checkpointer = Checkpointer(
                        trained_model, save_dir=f, save_to_disk=True
                    )
                    checkpointer.save("checkpoint_file")

                    # in the same folder
                    fresh_checkpointer = Checkpointer(fresh_model, save_dir=f)
                    self.assertTrue(fresh_checkpointer.has_checkpoint())
                    self.assertEqual(
                        fresh_checkpointer.get_checkpoint_file(),
                        os.path.join(f, "checkpoint_file.pth"),
                    )
                    fresh_checkpointer.load()

                self._assert_copied(
                    trained_model.parameters(), fresh_model.parameters()
                )

    def test_from_name_file_model(self):
        """A checkpointer on another folder loads a save given by path."""
        # test that loading works even if they differ by a prefix
        for label, trained_model, fresh_model in self._model_pairs():
            with self.subTest(models=label):
                with TemporaryDirectory() as f:
                    checkpointer = Checkpointer(
                        trained_model, save_dir=f, save_to_disk=True
                    )
                    checkpointer.save("checkpoint_file")

                    # on different folders
                    with TemporaryDirectory() as g:
                        fresh_checkpointer = Checkpointer(fresh_model, save_dir=g)
                        self.assertFalse(fresh_checkpointer.has_checkpoint())
                        self.assertEqual(
                            fresh_checkpointer.get_checkpoint_file(), ""
                        )
                        fresh_checkpointer.load(
                            os.path.join(f, "checkpoint_file.pth")
                        )

                self._assert_copied(
                    trained_model.parameters(), fresh_model.parameters()
                )

    def test_complex_model_loaded(self):
        """``load_state_dict`` fills a nested model, wrapped or not."""
        for add_data_parallel in [False, True]:
            with self.subTest(add_data_parallel=add_data_parallel):
                model, state_dict = self.create_complex_model()
                if add_data_parallel:
                    model = nn.DataParallel(model)

                load_state_dict(model, state_dict)
                # Values are compared positionally: this relies on the model's
                # state dict iterating in the same order as ``state_dict``.
                self._assert_copied(
                    model.state_dict().values(), state_dict.values()
                )
# Run the test suite only when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()