-
Notifications
You must be signed in to change notification settings - Fork 122
/
multi_stream_detector.py
59 lines (45 loc) · 2.18 KB
/
multi_stream_detector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from typing import Dict
from mmdet.models import BaseDetector, TwoStageDetector
class MultiSteamDetector(BaseDetector):
    """Hold several named detectors and forward test-time calls to one of them.

    Each entry of ``model`` is attached to this object as an attribute named
    after its key (e.g. ``{"teacher": t, "student": s}`` exposes
    ``self.teacher`` and ``self.student``). Test-time methods dispatch to the
    submodule named by a ``submodule=...`` keyword argument, or to
    ``self.inference_on`` when that keyword is absent.

    NOTE: the class name keeps its historical spelling ("Steam") because
    configs and callers may refer to it by that exact name.

    Args:
        model: Mapping from submodule name to detector. The first key is the
            default inference target unless ``test_cfg`` overrides it.
        train_cfg: Training config, stored unchanged as ``self.train_cfg``.
        test_cfg: Test config, stored unchanged as ``self.test_cfg``. If it
            has an ``"inference_on"`` entry, that names the default submodule.
    """

    def __init__(
        self, model: Dict[str, TwoStageDetector], train_cfg=None, test_cfg=None
    ):
        super().__init__()
        self.submodules = list(model.keys())
        for name, detector in model.items():
            # Expose each detector as an attribute so getattr(self, name)
            # can look it up later.
            setattr(self, name, detector)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        # test_cfg defaults to None; guard before calling .get() so that
        # constructing without a test config does not raise AttributeError.
        cfg = test_cfg if test_cfg is not None else {}
        self.inference_on = cfg.get("inference_on", self.submodules[0])

    def model(self, **kwargs) -> TwoStageDetector:
        """Return the detector selected by ``kwargs["submodule"]``.

        Falls back to ``self.inference_on`` when no ``submodule`` key is
        present. All other keyword arguments are ignored, so callers may pass
        their full kwargs through.

        Raises:
            AssertionError: if the requested submodule name is unknown.
        """
        if "submodule" in kwargs:
            assert (
                kwargs["submodule"] in self.submodules
            ), "Detector does not contain submodule {}".format(kwargs["submodule"])
            return getattr(self, kwargs["submodule"])
        return getattr(self, self.inference_on)

    def _select_submodule(self, kwargs):
        """Pick the target detector and strip the ``submodule`` routing key.

        Returns ``(detector, forwarded_kwargs)``. The ``submodule`` key only
        selects which detector to call; it is removed so it is not passed on
        as an argument to that detector's own methods.
        """
        detector = self.model(**kwargs)
        forwarded = {k: v for k, v in kwargs.items() if k != "submodule"}
        return detector, forwarded

    def freeze(self, model_ref: str):
        """Put submodule ``model_ref`` in eval mode and disable its gradients."""
        assert model_ref in self.submodules
        model = getattr(self, model_ref)
        model.eval()
        for param in model.parameters():
            param.requires_grad = False

    def forward_test(self, imgs, img_metas, **kwargs):
        """Forward ``forward_test`` to the selected submodule."""
        detector, kwargs = self._select_submodule(kwargs)
        return detector.forward_test(imgs, img_metas, **kwargs)

    async def aforward_test(self, *, img, img_metas, **kwargs):
        """Await the selected submodule's ``aforward_test`` and return its result.

        ``img`` and ``img_metas`` are passed by keyword to mirror this
        method's own keyword-only signature.
        """
        detector, kwargs = self._select_submodule(kwargs)
        return await detector.aforward_test(img=img, img_metas=img_metas, **kwargs)

    def extract_feat(self, imgs):
        """Extract features with the default inference submodule."""
        return self.model().extract_feat(imgs)

    def aug_test(self, imgs, img_metas, **kwargs):
        """Forward ``aug_test`` to the selected submodule."""
        detector, kwargs = self._select_submodule(kwargs)
        return detector.aug_test(imgs, img_metas, **kwargs)

    def simple_test(self, img, img_metas, **kwargs):
        """Forward ``simple_test`` to the selected submodule."""
        detector, kwargs = self._select_submodule(kwargs)
        return detector.simple_test(img, img_metas, **kwargs)

    async def async_simple_test(self, img, img_metas, **kwargs):
        """Await the selected submodule's ``async_simple_test`` and return its result."""
        detector, kwargs = self._select_submodule(kwargs)
        return await detector.async_simple_test(img, img_metas, **kwargs)

    def show_result(self, *args, **kwargs):
        """Copy ``self.CLASSES`` onto the default submodule, then delegate drawing."""
        detector = self.model()
        detector.CLASSES = self.CLASSES
        return detector.show_result(*args, **kwargs)