forked from tianzhi0549/FCOS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test_rpn_heads.py
executable file
·62 lines (50 loc) · 1.95 KB
/
test_rpn_heads.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import unittest
import copy
import torch
# import modules to register rpn heads
from maskrcnn_benchmark.modeling.backbone import build_backbone # NoQA
from maskrcnn_benchmark.modeling.rpn.rpn import build_rpn # NoQA
from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.config import cfg as g_cfg
from utils import load_config
# overwrite configs if specified, otherwise default config is used
# Maps a registry.RPN_HEADS name to the config file handed to load_config();
# heads that have no entry here are built from a deep copy of the default cfg.
RPN_CFGS = dict()
class TestRPNHeads(unittest.TestCase):
    """Smoke tests for every head registered in ``registry.RPN_HEADS``."""

    def test_build_rpn_heads(self):
        """Build each registered RPN head and check its per-level output shapes.

        Every head is built with ``builder(cfg, in_channels, num_anchors)`` and
        fed the same random feature map repeated for ``num_levels`` levels. The
        head must return ``(logits, bbox_reg)`` with one tensor per level:
        logits of shape ``(N, num_anchors, H, W)`` and box regressions of shape
        ``(N, num_anchors * 4, H, W)``.
        """
        # Guard against the registering imports silently doing nothing, in
        # which case the loop below would test no heads at all.
        self.assertGreater(len(registry.RPN_HEADS), 0)

        in_channels = 64
        num_anchors = 10
        num_levels = 3
        N, H, W = 2, 24, 32
        expected_logits_shape = torch.Size([N, num_anchors, H, W])
        expected_bbox_shape = torch.Size([N, num_anchors * 4, H, W])

        for name, builder in registry.RPN_HEADS.items():
            # subTest tags any failure with the head's registry name and keeps
            # going, so one broken head does not hide failures in the others.
            with self.subTest(head=name):
                if name in RPN_CFGS:
                    cfg = load_config(RPN_CFGS[name])
                else:
                    # Use default config if config file is not specified; the
                    # deep copy keeps the shared global cfg untouched.
                    cfg = copy.deepcopy(g_cfg)
                rpn = builder(cfg, in_channels, num_anchors)

                features = torch.rand(
                    [N, in_channels, H, W], dtype=torch.float32
                )
                out = rpn([features] * num_levels)

                self.assertEqual(len(out), 2)
                logits, bbox_reg = out
                # Exactly one output per input level: fewer would otherwise
                # surface as an IndexError, more would pass unnoticed.
                self.assertEqual(len(logits), num_levels)
                self.assertEqual(len(bbox_reg), num_levels)
                for level_logits, level_bbox_reg in zip(logits, bbox_reg):
                    self.assertEqual(level_logits.shape, expected_logits_shape)
                    self.assertEqual(level_bbox_reg.shape, expected_bbox_shape)
if __name__ == "__main__":
    # Run the test cases defined in this module when executed as a script;
    # "__main__" is spelled out here but is also unittest.main's default.
    unittest.main(module="__main__")