-
Notifications
You must be signed in to change notification settings - Fork 0
/
detector_utils.py
196 lines (157 loc) · 5.79 KB
/
detector_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
from typing import List, Optional, Sequence, Union
import torch
import cv2
import numpy as np
from mmcv.ops import RoIPool
from mmengine.dataset import Compose, pseudo_collate
from mmengine.device import get_device
from mmengine.registry import init_default_scope
from mmdet.apis import inference_detector, init_detector
from mmdet.structures import DetDataSample, SampleList
from mmdet.utils import get_test_pipeline_cfg
ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
def nms(dets: np.ndarray, thr: float):
    """Greedy non-maximum suppression over scored boxes.

    Repeatedly takes the highest-scoring remaining box and discards every
    other box whose IoU with it exceeds ``thr`` (boxes with IoU <= ``thr``
    survive to later rounds). Uses the legacy ``+1`` pixel convention for
    box width/height.

    Args:
        dets (np.ndarray): Rows of ``[x1, y1, x2, y2, score]``.
        thr (float): IoU threshold above which a box is suppressed.

    Returns:
        list: Indexes (into ``dets``) of the boxes to keep, in descending
        score order.
    """
    if len(dets) == 0:
        return []

    x1, y1, x2, y2 = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3]
    scores = dets[:, 4]
    box_area = (x2 - x1 + 1) * (y2 - y1 + 1)

    # Candidates sorted best-first; front of `remaining` is always the pick.
    remaining = scores.argsort()[::-1]
    selected = []
    while remaining.size > 0:
        best = remaining[0]
        selected.append(best)
        rest = remaining[1:]

        # Intersection of the picked box with every other candidate.
        ix1 = np.maximum(x1[best], x1[rest])
        iy1 = np.maximum(y1[best], y1[rest])
        ix2 = np.minimum(x2[best], x2[rest])
        iy2 = np.minimum(y2[best], y2[rest])
        inter = np.maximum(0.0, ix2 - ix1 + 1) * np.maximum(0.0, iy2 - iy1 + 1)

        iou = inter / (box_area[best] + box_area[rest] - inter)
        # Keep only candidates that do not overlap the pick too much.
        remaining = rest[iou <= thr]
    return selected
def adapt_mmdet_pipeline(cfg):
    """Prefix MMDetection transform names in the test pipeline with 'mmdet.'.

    Mutates ``cfg.test_dataloader.dataset.pipeline`` in place so each
    transform type that MMDetection defines is namespaced (avoids registry
    lookup clashes when calling across libraries).

    Args:
        cfg (ConfigDict): Configuration dictionary for MMDetection.

    Returns:
        ConfigDict: The same config, with updated pipeline types.
    """
    # use lazy import to avoid hard dependence on mmdet
    from mmdet.datasets import transforms

    if 'test_dataloader' not in cfg:
        return cfg

    known_transforms = dir(transforms)
    for step in cfg.test_dataloader.dataset.pipeline:
        if step['type'] in known_transforms:
            step['type'] = 'mmdet.' + step['type']
    return cfg
def inference_detector(
    model: torch.nn.Module,
    imgs: ImagesType,
    test_pipeline: Optional[Compose] = None,
    text_prompt: Optional[str] = None,
    custom_entities: bool = False,
) -> Union[DetDataSample, SampleList]:
    """Inference image(s) with the detector.

    NOTE(review): this intentionally shadows the ``inference_detector``
    imported from ``mmdet.apis`` at the top of this file — confirm the
    shadowing is deliberate.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str, ndarray, Sequence[str/ndarray]):
            Either image files or loaded images. A ``torch.Tensor`` is
            also accepted and converted to a uint8 ndarray (see below).
        test_pipeline (:obj:`Compose`): Test pipeline. Built from
            ``model.cfg`` when not supplied.
        text_prompt (str, optional): If truthy, forwarded to the pipeline
            as the ``text`` field (for text-prompted detectors).
        custom_entities (bool): Forwarded alongside ``text_prompt``.

    Returns:
        :obj:`DetDataSample` or list[:obj:`DetDataSample`]:
        If imgs is a list or tuple, the same length list type results
        will be returned, otherwise return the detection results directly.
    """
    if isinstance(imgs, torch.Tensor):
        if imgs.is_cuda:
            imgs = imgs.cpu()
        # Remove batch dimension and transpose to HWC.
        # assumes a single-image (1, C, H, W) float tensor scaled to [0, 1]
        # — TODO confirm; a batch dim > 1 would not be squeezed away here.
        imgs = imgs.squeeze(0).permute(1, 2, 0).numpy()
        # Ensure the data type is appropriate (uint8 for most image processing functions)
        imgs = (imgs * 255).astype(np.uint8)

    # A 4-D ndarray is treated as a batch stacked on axis 0; anything else
    # is wrapped in a single-element list so the loop below is uniform.
    if isinstance(imgs, (list, tuple)) or (isinstance(imgs, np.ndarray) and len(imgs.shape) == 4):
        is_batch = True
    else:
        imgs = [imgs]
        is_batch = False

    cfg = model.cfg

    if test_pipeline is None:
        # Copy so the pipeline edit below does not mutate the model's config.
        cfg = cfg.copy()
        test_pipeline = get_test_pipeline_cfg(cfg)
        if isinstance(imgs[0], np.ndarray):
            # Calling this method across libraries will result
            # in module unregistered error if not prefixed with mmdet.
            test_pipeline[0].type = "mmdet.LoadImageFromNDArray"
        test_pipeline = Compose(test_pipeline)

    if model.data_preprocessor.device.type == "cpu":
        # RoIPool has no CPU kernel in mmcv; fail loudly rather than crash later.
        for m in model.modules():
            assert not isinstance(
                m, RoIPool
            ), "CPU inference with RoIPool is not supported currently."

    result_list = []
    # NOTE(review): loop index `i` is unused.
    for i, img in enumerate(imgs):
        # prepare data: ndarray goes in as `img`, anything else as a path
        if isinstance(img, np.ndarray):
            # TODO: remove img_id.
            data_ = dict(img=img, img_id=0)
        else:
            # TODO: remove img_id.
            data_ = dict(img_path=img, img_id=0)

        if text_prompt:
            data_["text"] = text_prompt
            data_["custom_entities"] = custom_entities

        # build the data pipeline
        data_ = test_pipeline(data_)

        # Wrap into batch-of-one lists, as `test_step` expects collated input.
        data_["inputs"] = [data_["inputs"]]
        data_["data_samples"] = [data_["data_samples"]]

        # forward the model under bfloat16 autocast for the current device
        with torch.no_grad(), torch.autocast(device_type=get_device(), dtype=torch.bfloat16):
            results = model.test_step(data_)[0]

        result_list.append(results)

    if not is_batch:
        return result_list[0]
    else:
        return result_list
def process_one_image_bbox(pred_instance, det_cat_id, bbox_thr, nms_thr):
    """Filter one image's detections to confident boxes of a single class.

    Appends each score to its box, keeps rows whose label equals
    ``det_cat_id`` and whose score exceeds ``bbox_thr``, then runs
    :func:`nms` with threshold ``nms_thr``.

    Args:
        pred_instance: Detection results exposing ``bboxes``, ``scores``
            and ``labels`` as numpy arrays.
        det_cat_id: Category id to keep.
        bbox_thr (float): Minimum score (exclusive).
        nms_thr (float): IoU threshold for NMS.

    Returns:
        np.ndarray: Surviving boxes as ``(x1, y1, x2, y2)`` rows
        (the score column is dropped).
    """
    scored_boxes = np.concatenate(
        (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1
    )
    wanted = np.logical_and(
        pred_instance.labels == det_cat_id,
        pred_instance.scores > bbox_thr,
    )
    scored_boxes = scored_boxes[wanted]
    keep = nms(scored_boxes, nms_thr)
    return scored_boxes[keep, :4]
def process_images_detector(imgs, detector, det_cat_id=0, bbox_thr=0.3, nms_thr=0.3):
    """Detect bounding boxes of one category for a batch of images.

    Runs the detector on ``imgs`` and reduces each image's predictions to
    NMS-filtered boxes of a single category via
    :func:`process_one_image_bbox`.

    Args:
        imgs: Image batch accepted by :func:`inference_detector`
            (paths, ndarrays, or a sequence of either).
        detector: Loaded detection model.
        det_cat_id (int): Category id to keep. Defaults to 0
            (previously hard-coded).
        bbox_thr (float): Minimum detection score. Defaults to 0.3
            (previously hard-coded).
        nms_thr (float): IoU threshold for NMS. Defaults to 0.3
            (previously hard-coded).

    Returns:
        list[np.ndarray]: One ``(x1, y1, x2, y2)`` box array per image.
    """
    # predict bbox
    det_results = inference_detector(detector, imgs)
    pred_instances = [res.pred_instances.numpy() for res in det_results]
    return [
        process_one_image_bbox(pred_instance, det_cat_id, bbox_thr, nms_thr)
        for pred_instance in pred_instances
    ]