-
Notifications
You must be signed in to change notification settings - Fork 86
/
Copy pathtest_predictors.py
executable file
·98 lines (80 loc) · 3.14 KB
/
test_predictors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import unittest
import copy
import torch
# import modules to register predictors
from maskrcnn_benchmark.modeling.backbone import build_backbone # NoQA
from maskrcnn_benchmark.modeling.roi_heads.roi_heads import build_roi_heads # NoQA
from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.config import cfg as g_cfg
from utils import load_config
# Per-predictor config overrides, keyed by predictor name. The value is
# handed to load_config(); predictors without an entry are tested with a
# copy of the default config.
PREDICTOR_CFGS = {}

# Per-predictor input channel overrides, keyed by predictor name.
# Predictors without an entry are built with the default channel count
# chosen inside _test_predictors.
PREDICTOR_INPUT_CHANNELS = {}
def _test_predictors(
    self, predictors, overwrite_cfgs, overwrite_in_channels,
    hwsize,
):
    """Build each predictor, run it once on random data, and yield the result.

    This is a generator, so nothing in it (the non-empty check included)
    executes until the caller starts iterating.

    Args:
        self: the calling test case; only its ``assertGreater`` is used.
        predictors: mapping of predictor name -> builder. Each builder is
            called as ``builder(cfg, in_channels)`` and must return a
            callable that accepts the input tensor.
        overwrite_cfgs: mapping of predictor name -> value passed to
            ``load_config``. Names without an entry get a deep copy of
            the global default config.
        overwrite_in_channels: mapping of predictor name -> number of
            input channels. Names without an entry use 64.
        hwsize: height and width of the random input tensor.

    Yields:
        A ``(input, output, cfg)`` tuple per predictor, where ``input`` is
        the random float32 tensor of shape ``(2, in_channels, hwsize,
        hwsize)`` fed to the predictor, ``output`` is whatever the
        predictor returned, and ``cfg`` is the config it was built with.
    """
    self.assertGreater(len(predictors), 0)

    default_channels = 64
    batch_size = 2

    for name, build in predictors.items():
        print('Testing {}...'.format(name))

        # Use the default config if a config file is not specified.
        cfg = (
            load_config(overwrite_cfgs[name])
            if name in overwrite_cfgs
            else copy.deepcopy(g_cfg)
        )

        channels = overwrite_in_channels.get(name, default_channels)
        predictor = build(cfg, channels)

        features = torch.rand(
            [batch_size, channels, hwsize, hwsize], dtype=torch.float32)
        yield features, predictor(features), cfg
class TestPredictors(unittest.TestCase):
    """Smoke tests checking that every registered ROI predictor runs.

    Each test iterates one predictor registry, feeds a random input through
    every predictor in it, and checks the output shapes against the config
    the predictor was built with.
    """

    def _run_registry(self, predictors, hwsize):
        """Iterate ``_test_predictors`` with the module-level overrides."""
        return _test_predictors(
            self,
            predictors,
            PREDICTOR_CFGS,
            PREDICTOR_INPUT_CHANNELS,
            hwsize=hwsize,
        )

    def test_roi_box_predictors(self):
        """Make sure roi box predictors run

        The output must be a pair ``(scores, bbox_deltas)``: scores has one
        row per input sample and one column per box-head class, and
        bbox_deltas has the same number of rows and four columns per score
        column.
        """
        runs = self._run_registry(registry.ROI_BOX_PREDICTOR, 1)
        for features, output, cfg in runs:
            self.assertEqual(len(output), 2)
            scores, bbox_deltas = output

            num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES
            self.assertEqual(scores.shape[1], num_classes)
            self.assertEqual(scores.shape[0], features.shape[0])
            self.assertEqual(scores.shape[0], bbox_deltas.shape[0])
            self.assertEqual(scores.shape[1] * 4, bbox_deltas.shape[1])

    def test_roi_keypoints_predictors(self):
        """Make sure roi keypoint predictors run

        The output must keep the input's first dimension and have
        ``ROI_KEYPOINT_HEAD.NUM_CLASSES`` entries along its second one.
        """
        runs = self._run_registry(registry.ROI_KEYPOINT_PREDICTOR, 14)
        for features, output, cfg in runs:
            self.assertEqual(output.shape[0], features.shape[0])

            num_classes = cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_CLASSES
            self.assertEqual(output.shape[1], num_classes)

    def test_roi_mask_predictors(self):
        """Make sure roi mask predictors run

        The output must keep the input's first dimension and have
        ``ROI_BOX_HEAD.NUM_CLASSES`` entries along its second one.
        """
        runs = self._run_registry(registry.ROI_MASK_PREDICTOR, 14)
        for features, output, cfg in runs:
            self.assertEqual(output.shape[0], features.shape[0])

            num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES
            self.assertEqual(output.shape[1], num_classes)
# Run the whole test module when this file is executed directly.
if __name__ == '__main__':
    unittest.main()