forked from tianzhi0549/FCOS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_feature_extractors.py
executable file
·93 lines (76 loc) · 3 KB
/
test_feature_extractors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import unittest
import copy
import torch
# Import these modules so that the feature extractors get registered
from maskrcnn_benchmark.modeling.backbone import build_backbone # NoQA
from maskrcnn_benchmark.modeling.roi_heads.roi_heads import build_roi_heads # NoQA
from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.config import cfg as g_cfg
from utils import load_config
# Per-extractor config overrides, keyed by extractor name. When a name is
# present here its value is passed to `load_config`; every other extractor is
# built from a deep copy of the default config.
FEATURE_EXTRACTORS_CFGS = {
}
# Per-extractor input-channel overrides, keyed by extractor name. Extractors
# that are not listed here are built with the default channel count chosen in
# `_test_feature_extractors`.
FEATURE_EXTRACTORS_INPUT_CHANNELS = {
    # Original note: in_channels was not used, load through config.
    # The value given here also sets the channel count of the random input
    # tensor that the test feeds to this extractor.
    "ResNet50Conv5ROIFeatureExtractor": 1024,
}
def _test_feature_extractors(
    self, extractors, overwrite_cfgs, overwrite_in_channels
):
    """Build every given feature extractor and run one forward pass on it.

    Args:
        self: the ``unittest.TestCase`` instance whose assertion methods are
            used to report failures.
        extractors: mapping of extractor name to a builder that is called as
            ``builder(cfg, in_channels)``.
        overwrite_cfgs: mapping of extractor name to the argument handed to
            ``load_config``; names that are absent use a deep copy of the
            global default config.
        overwrite_in_channels: mapping of extractor name to the number of
            input channels; names that are absent use 64.

    Each extractor is exercised inside its own ``subTest`` so that a failure
    is reported together with the extractor's name and does not prevent the
    remaining extractors from being checked.
    """
    # An empty registry would make the loop below pass vacuously.
    self.assertGreater(len(extractors), 0)

    in_channels_default = 64

    for name, builder in extractors.items():
        with self.subTest(extractor=name):
            print('Testing {}...'.format(name))
            if name in overwrite_cfgs:
                cfg = load_config(overwrite_cfgs[name])
            else:
                # Use default config if config file is not specified
                cfg = copy.deepcopy(g_cfg)

            in_channels = overwrite_in_channels.get(
                name, in_channels_default)

            fe = builder(cfg, in_channels)
            self.assertIsNotNone(
                getattr(fe, 'out_channels', None),
                'Need to provide out_channels for feature extractor {}'.format(name)
            )

            # One random feature map holding a batch of N images.
            N, C_in, H, W = 2, in_channels, 24, 32
            features = torch.rand([N, C_in, H, W], dtype=torch.float32)
            bboxes = [[1, 1, 10, 10], [5, 5, 8, 8], [2, 2, 3, 4]]
            img_size = [384, 512]
            box_list = BoxList(bboxes, img_size, "xyxy")

            # The same box list is reused for each of the N images.
            out = fe([features], [box_list] * N)

            # Expect one output row per (image, box) pair and the advertised
            # number of output channels.
            self.assertEqual(
                out.shape[:2],
                torch.Size([N * len(bboxes), fe.out_channels])
            )
class TestFeatureExtractors(unittest.TestCase):
    """Run the shared extractor check on the box, keypoint and mask registries."""

    def _check_registry(self, extractors):
        """Run the shared check on ``extractors`` with this module's overrides."""
        _test_feature_extractors(
            self,
            extractors,
            FEATURE_EXTRACTORS_CFGS,
            FEATURE_EXTRACTORS_INPUT_CHANNELS,
        )

    def test_roi_box_feature_extractors(self):
        """Make sure roi box feature extractors run."""
        self._check_registry(registry.ROI_BOX_FEATURE_EXTRACTORS)

    def test_roi_keypoints_feature_extractors(self):
        """Make sure roi keypoints feature extractors run."""
        self._check_registry(registry.ROI_KEYPOINT_FEATURE_EXTRACTORS)

    def test_roi_mask_feature_extractors(self):
        """Make sure roi mask feature extractors run."""
        self._check_registry(registry.ROI_MASK_FEATURE_EXTRACTORS)
if __name__ == "__main__":
    # Run the tests in this module when the file is executed directly.
    unittest.main()