-
Notifications
You must be signed in to change notification settings - Fork 86
/
Copy pathtest_backbones.py
executable file
·55 lines (44 loc) · 1.87 KB
/
test_backbones.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import unittest
import copy
import torch
# import modules to to register backbones
from maskrcnn_benchmark.modeling.backbone import build_backbone # NoQA
from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.config import cfg as g_cfg
from utils import load_config
# Backbone name -> config file handed to ``load_config`` when that backbone is
# tested. Any registered backbone missing from this mapping is tested with a
# deep copy of the default global config instead.
# NOTE(review): "R-152-FPN" reuses the R-101 FPN config file; confirm this is
# intentional (e.g. because no dedicated R-152 config is available).
BACKBONE_CFGS = dict(
    [
        ("R-50-FPN", "e2e_faster_rcnn_R_50_FPN_1x.yaml"),
        ("R-101-FPN", "e2e_faster_rcnn_R_101_FPN_1x.yaml"),
        ("R-152-FPN", "e2e_faster_rcnn_R_101_FPN_1x.yaml"),
        ("R-50-FPN-RETINANET", "retinanet/retinanet_R-50-FPN_1x.yaml"),
        ("R-101-FPN-RETINANET", "retinanet/retinanet_R-101-FPN_1x.yaml"),
    ]
)
class TestBackbones(unittest.TestCase):
    """Smoke tests for every backbone registered in ``registry.BACKBONES``."""

    def test_build_backbones(self):
        """Make sure every registered backbone builds and runs a forward pass.

        For each ``(name, builder)`` pair in ``registry.BACKBONES``:

        * build the backbone from the config file listed for ``name`` in
          ``BACKBONE_CFGS`` (loaded via ``load_config``), or from a deep copy
          of the global default ``cfg`` when ``name`` is not listed;
        * check that the backbone exposes a non-None ``out_channels``;
        * run it on a random float32 batch of shape ``(2, 3, 224, 256)`` and
          check that it returns at least one output and that every output's
          first two dimensions are ``(N, backbone.out_channels)``.
        """
        self.assertGreater(len(registry.BACKBONES), 0)
        # Dummy input shape shared by every backbone.
        N, C_in, H, W = 2, 3, 224, 256
        for name, backbone_builder in registry.BACKBONES.items():
            # subTest reports each failing backbone separately instead of
            # aborting the whole loop at the first broken one.
            with self.subTest(backbone=name):
                print('Testing {}...'.format(name))
                if name in BACKBONE_CFGS:
                    cfg = load_config(BACKBONE_CFGS[name])
                else:
                    # Use default config if config file is not specified;
                    # deep-copied, presumably so a builder cannot mutate the
                    # shared global cfg.
                    cfg = copy.deepcopy(g_cfg)
                backbone = backbone_builder(cfg)
                # make sure the backbone has `out_channels`
                self.assertIsNotNone(
                    getattr(backbone, 'out_channels', None),
                    'Need to provide out_channels for backbone {}'.format(name)
                )
                images = torch.rand([N, C_in, H, W], dtype=torch.float32)
                # Only output shapes are checked, so skip building an
                # autograd graph for the forward pass.
                with torch.no_grad():
                    outs = list(backbone(images))
                # Without this, an empty output would make the shape loop
                # below pass vacuously.
                self.assertGreater(
                    len(outs), 0,
                    'Backbone {} returned no outputs'.format(name)
                )
                for cur_out in outs:
                    self.assertEqual(
                        cur_out.shape[:2],
                        torch.Size([N, backbone.out_channels])
                    )
if __name__ == "__main__":
    # When executed as a script, collect and run the TestCase classes defined
    # in this module (module="__main__" is unittest.main's default).
    unittest.main(module="__main__")