tasks_classification_task_test.py
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
import shutil
import tempfile
import unittest
from test.generic.config_utils import get_fast_test_task_config, get_test_task_config
from test.generic.utils import compare_model_state, compare_samples, compare_states

import torch
from classy_vision.dataset import build_dataset
from classy_vision.generic.distributed_util import is_distributed_training_run
from classy_vision.generic.util import get_checkpoint_dict
from classy_vision.hooks import CheckpointHook, LossLrMeterLoggingHook
from classy_vision.losses import ClassyLoss, build_loss, register_loss
from classy_vision.models import build_model
from classy_vision.optim import build_optimizer
from classy_vision.tasks import ClassificationTask, build_task
from classy_vision.trainer import LocalTrainer


@register_loss("test_stateful_loss")
class TestStatefulLoss(ClassyLoss):
    """A loss with a learnable parameter, used to verify that loss state is
    included in the task's classy state (see test_get_classy_state_on_loss)."""

    def __init__(self, in_plane):
        super(TestStatefulLoss, self).__init__()
        self.alpha = torch.nn.Parameter(torch.Tensor(in_plane, 2))
        # use the in-place initializer; torch.nn.init.xavier_normal is deprecated
        torch.nn.init.xavier_normal_(self.alpha)

    @classmethod
    def from_config(cls, config) -> "TestStatefulLoss":
        return cls(in_plane=config["in_plane"])

    def forward(self, output, target):
        value = output.matmul(self.alpha)
        loss = torch.mean(torch.abs(value))
        return loss
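
# Usage sketch (nothing assumed beyond what the tests below already exercise):
# because the loss is registered under "test_stateful_loss", build_loss can
# construct it directly from a config dict, e.g.
#
#   loss = build_loss({"name": "test_stateful_loss", "in_plane": 256})
#   assert isinstance(loss, TestStatefulLoss)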


class TestClassificationTask(unittest.TestCase):
    def _compare_model_state(self, model_state_1, model_state_2, check_heads=True):
        compare_model_state(self, model_state_1, model_state_2, check_heads)

    def _compare_samples(self, sample_1, sample_2):
        compare_samples(self, sample_1, sample_2)

    def _compare_states(self, state_1, state_2, check_heads=True):
        compare_states(self, state_1, state_2)

    def setUp(self):
        # create a base directory to write checkpoints to
        self.base_dir = tempfile.mkdtemp()

    def tearDown(self):
        # delete all the temporary data created
        shutil.rmtree(self.base_dir)

    def test_build_task(self):
        config = get_test_task_config()
        task = build_task(config)
        self.assertTrue(isinstance(task, ClassificationTask))

    def test_hooks_config_builds_correctly(self):
        config = get_test_task_config()
        config["hooks"] = [{"name": "loss_lr_meter_logging"}]
        task = build_task(config)
        self.assertTrue(len(task.hooks) == 1)
        self.assertTrue(isinstance(task.hooks[0], LossLrMeterLoggingHook))

    def test_get_state(self):
        """Tests that a task assembled through the fluent setters and a task
        built via build_task can both be prepared from the same config."""
        config = get_test_task_config()
        loss = build_loss(config["loss"])
        task = (
            ClassificationTask()
            .set_num_epochs(1)
            .set_loss(loss)
            .set_model(build_model(config["model"]))
            .set_optimizer(build_optimizer(config["optimizer"]))
        )
        for phase_type in ["train", "test"]:
            dataset = build_dataset(config["dataset"][phase_type])
            task.set_dataset(dataset, phase_type)

        task.prepare()

        task = build_task(config)
        task.prepare()

    def test_synchronize_losses_non_distributed(self):
        """
        Tests that synchronize losses has no side effects in a non-distributed setting.
        """
        test_config = get_fast_test_task_config()
        task = build_task(test_config)
        task.prepare()

        old_losses = copy.deepcopy(task.losses)
        task.synchronize_losses()
        self.assertEqual(old_losses, task.losses)

    def test_synchronize_losses_when_losses_empty(self):
        config = get_fast_test_task_config()
        task = build_task(config)
        task.prepare()

        task.set_use_gpu(torch.cuda.is_available())

        # Losses should be empty when creating task
        self.assertEqual(len(task.losses), 0)

        task.synchronize_losses()

    def test_checkpointing(self):
        """
        Tests checkpointing by running train_steps to make sure the train_steps
        run the same way after loading from a checkpoint.
        """
        config = get_fast_test_task_config()
        task = build_task(config).set_hooks([LossLrMeterLoggingHook()])
        task_2 = build_task(config).set_hooks([LossLrMeterLoggingHook()])

        task.set_use_gpu(torch.cuda.is_available())

        # prepare the tasks for the right device
        task.prepare()

        # test in both train and test mode
        for _ in range(2):
            task.advance_phase()

            # set task's state as task_2's checkpoint
            task_2._set_checkpoint_dict(get_checkpoint_dict(task, {}, deep_copy=True))
            task_2.prepare()

            # task 2 should have the same state
            self._compare_states(task.get_classy_state(), task_2.get_classy_state())

            # this tests that both states' iterators return the same samples
            sample = next(task.get_data_iterator())
            sample_2 = next(task_2.get_data_iterator())
            self._compare_samples(sample, sample_2)

            # test that the train step runs the same way on both states
            # and the loss remains the same
            task.train_step()
            task_2.train_step()
            self._compare_states(task.get_classy_state(), task_2.get_classy_state())

    def test_final_train_checkpoint(self):
        """Test that a train phase checkpoint with a where of 1.0 can be loaded"""
        config = get_fast_test_task_config()
        task = build_task(config).set_hooks(
            [CheckpointHook(self.base_dir, {}, phase_types=["train"])]
        )
        task_2 = build_task(config)

        task.set_use_gpu(torch.cuda.is_available())

        trainer = LocalTrainer()
        trainer.train(task)

        self.assertAlmostEqual(task.where, 1.0, delta=1e-3)

        # set task_2's state as task's final train checkpoint
        task_2.set_checkpoint(self.base_dir)
        task_2.prepare()

        # we should be able to train the task
        trainer.train(task_2)

    def test_test_only_checkpointing(self):
        """
        Tests checkpointing by running train_steps to make sure the
        train_steps run the same way after loading from a training
        task checkpoint on a test_only task.
        """
        train_config = get_fast_test_task_config()
        train_config["num_epochs"] = 10
        test_config = get_fast_test_task_config()
        test_config["test_only"] = True
        train_task = build_task(train_config).set_hooks([LossLrMeterLoggingHook()])
        test_only_task = build_task(test_config).set_hooks([LossLrMeterLoggingHook()])

        # prepare the training task for the right device
        train_task.prepare()

        # train the task to produce a final checkpointable state
        trainer = LocalTrainer()
        trainer.train(train_task)

        # use the train task's state as the test-only task's checkpoint
        test_only_task._set_checkpoint_dict(
            get_checkpoint_dict(train_task, {}, deep_copy=True)
        )
        test_only_task.prepare()
        test_state = test_only_task.get_classy_state()

        # We expect the phase idx to be different for a test only task
        self.assertEqual(test_state["phase_idx"], -1)

        # We expect that test only state is test, no matter what train state is
        self.assertFalse(test_state["train"])

        # Num updates should be 0
        self.assertEqual(test_state["num_updates"], 0)

        # train_phase_idx should be -1
        self.assertEqual(test_state["train_phase_idx"], -1)

        # Verify task will run
        trainer = LocalTrainer()
        trainer.train(test_only_task)

    def test_test_only_task(self):
        """
        Tests the task in test mode by running train_steps
        to make sure the train_steps run as expected on a
        test_only task
        """
        test_config = get_fast_test_task_config()
        test_config["test_only"] = True

        # delete train dataset
        del test_config["dataset"]["train"]

        test_only_task = build_task(test_config).set_hooks([LossLrMeterLoggingHook()])
        test_only_task.prepare()
        test_state = test_only_task.get_classy_state()

        # We expect that test only state is test, no matter what train state is
        self.assertFalse(test_state["train"])

        # Num updates should be 0
        self.assertEqual(test_state["num_updates"], 0)

        # Verify task will run
        trainer = LocalTrainer()
        trainer.train(test_only_task)

    def test_train_only_task(self):
        """
        Tests that the task runs when only a train dataset is specified.
        """
        test_config = get_fast_test_task_config()

        # delete the test dataset from the config
        del test_config["dataset"]["test"]

        task = build_task(test_config).set_hooks([LossLrMeterLoggingHook()])
        task.prepare()

        # verify that the task can still be trained
        trainer = LocalTrainer()
        trainer.train(task)

    @unittest.skipUnless(torch.cuda.is_available(), "This test needs a GPU to run")
    def test_checkpointing_different_device(self):
        """Tests that a checkpoint created on one device (CPU or GPU) can be
        loaded and trained on the other."""
        config = get_fast_test_task_config()
        task = build_task(config)
        task_2 = build_task(config)

        for use_gpu in [True, False]:
            task.set_use_gpu(use_gpu)
            task.prepare()

            # set task's state as task_2's checkpoint
            task_2._set_checkpoint_dict(get_checkpoint_dict(task, {}, deep_copy=True))

            # we should be able to run the trainer using state from a different device
            trainer = LocalTrainer()
            task_2.set_use_gpu(not use_gpu)
            trainer.train(task_2)

    @unittest.skipUnless(
        is_distributed_training_run(), "This test needs a distributed run"
    )
    def test_get_classy_state_on_loss(self):
        """Tests that a loss with learnable parameters contributes its state to
        the task's classy state."""
        config = get_fast_test_task_config()
        config["loss"] = {"name": "test_stateful_loss", "in_plane": 256}
        task = build_task(config)
        task.prepare()
        self.assertIn("alpha", task.get_classy_state()["loss"])
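

# Optional convenience entry point; these tests are normally collected via
# unittest/pytest discovery.
if __name__ == "__main__":
    unittest.main()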