Skip to content

Commit 9abbb09

Browse files
authored
adding basic bbox inference pipeline (airctic#860)
* adding basic bbox inference pipeline * quick fix * adding end2end_detect * adding e2e_detect to mmdet * adding tests * fixing bug
1 parent dbfc652 commit 9abbb09

File tree

12 files changed

+362
-45
lines changed

12 files changed

+362
-45
lines changed

icevision/models/inference.py

+131
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
__all__ = ["process_bbox_predictions", "_end2end_detect"]
2+
3+
from icevision.imports import *
4+
from icevision.core import *
5+
from icevision.data import *
6+
from icevision.tfms.albumentations.albumentations_helpers import (
7+
get_size_without_padding,
8+
)
9+
from icevision.tfms.albumentations import albumentations_adapter
10+
11+
12+
def _end2end_detect(
    img: Union[PIL.Image.Image, Path, str],
    transforms: albumentations_adapter.Adapter,
    model: torch.nn.Module,
    class_map: ClassMap,
    detection_threshold: float = 0.5,
    predict_fn: Callable = None,
):
    """
    Run Object Detection inference (only `bboxes`) on a single image.

    Parameters
    ----------
    img: image to run inference on. Can be a string, Path or PIL.Image
    transforms: icevision albumentations transforms
    model: model to run inference with
    class_map: ClassMap with the available categories
    detection_threshold: confidence threshold, forwarded to `predict_fn` as the
        `detection_threshold` keyword argument
    predict_fn: callable invoked as
        `predict_fn(model, dataset, detection_threshold=...)`; the first element
        of what it returns is used as the prediction. It is required even though
        it defaults to `None`.

    Returns
    -------
    List of dicts with keys "class", "score" and "bbox", where the bbox
    coordinates are adjusted to the original image size and aspect ratio

    Raises
    ------
    TypeError: if `predict_fn` is not provided
    """
    # Fail fast with an actionable message instead of the opaque
    # "'NoneType' object is not callable" raised after the image was loaded.
    if predict_fn is None:
        raise TypeError(
            "`predict_fn` must be provided: pass the model-specific predict function"
        )

    if isinstance(img, (str, Path)):
        img = PIL.Image.open(Path(img))

    # A one-image dataset lets the regular dataset-based predict function be reused.
    infer_ds = Dataset.from_images([np.array(img)], transforms, class_map=class_map)
    pred = predict_fn(model, infer_ds, detection_threshold=detection_threshold)[0]
    return process_bbox_predictions(pred, img, transforms.tfms_list)
42+
43+
44+
def process_bbox_predictions(
    pred: Prediction,
    img: PIL.Image.Image,
    transforms: List[Any],
) -> List[Dict[str, Any]]:
    """
    Turn the detections of a prediction into plain dictionaries.

    Parameters
    ----------
    pred: icevision prediction object
    img: original image, before any model-pre-processing done
    transforms: list of model-pre-processing transforms

    Returns
    -------
    List with one dict per detection, holding the keys "class", "score" and
    "bbox" (`[xmin, ymin, xmax, ymax]` as returned by `postprocess_bbox`)
    """
    record = pred.pred
    detection = record.detection

    def _as_dict(box, confidence, category):
        # Map the predicted box back onto the original image.
        x0, y0, x1, y1 = postprocess_bbox(
            img, box, transforms, record.height, record.width
        )
        return {"class": category, "score": confidence, "bbox": [x0, y0, x1, y1]}

    return [
        _as_dict(box, confidence, category)
        for box, confidence, category in zip(
            detection.bboxes, detection.scores, detection.labels
        )
    ]
78+
79+
80+
def postprocess_bbox(
    img: PIL.Image.Image, bbox: BBox, transforms: List[Any], h_after: int, w_after: int
) -> Tuple[int, int, int, int]:
    """
    Post-process predicted bbox to adjust coordinates to input image size.

    Parameters
    ----------
    img: original image, before any model-pre-processing done
    bbox: predicted bbox
    transforms: list of model-pre-processing transforms
    h_after: height of image after model-pre-processing transforms
    w_after: width of image after model-pre-processing transforms

    Returns
    -------
    Tuple with (xmin, ymin, xmax, ymax) rescaled and re-adjusted to match the
    original image size. Every coordinate is clamped to the un-padded image
    area before rescaling, so a box lying (partly) in the padding never comes
    out negative or inverted.
    """
    w_before, h_before = img.size
    h_after, w_after = get_size_without_padding(transforms, img, h_after, w_after)

    # NOTE(review): assumes the shorter side was padded equally on both ends up
    # to the length of the longer side -- confirm against the transforms used.
    pad = np.abs(h_after - w_after) // 2
    h_scale, w_scale = h_after / h_before, w_after / w_before

    xmin, xmax = int(bbox.xmin), int(bbox.xmax)
    ymin, ymax = int(bbox.ymin), int(bbox.ymax)

    # Remove the padding offset along the shorter dimension only.
    if h_after < w_after:
        ymin, ymax = ymin - pad, ymax - pad
    else:
        xmin, xmax = xmin - pad, xmax - pad

    # Clamp both ends of each axis to [0, size]; clamping only one side per
    # coordinate let boxes inside the padding escape the image bounds.
    xmin, xmax = (min(max(x, 0), w_after) for x in (xmin, xmax))
    ymin, ymax = (min(max(y, 0), h_after) for y in (ymin, ymax))

    # Undo the resize to get back to original-image pixel coordinates.
    return (
        int(xmin / w_scale),
        int(ymin / h_scale),
        int(xmax / w_scale),
        int(ymax / h_scale),
    )

icevision/models/mmdet/common/bbox/prediction.py

+5
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"predict_from_dl",
44
"convert_raw_prediction",
55
"convert_raw_predictions",
6+
"end2end_detect",
67
]
78

89
from icevision.imports import *
@@ -13,6 +14,7 @@
1314
from icevision.models.mmdet.common.utils import *
1415
from icevision.models.mmdet.common.bbox.dataloaders import build_infer_batch
1516
from icevision.models.mmdet.common.utils import convert_background_from_last_to_zero
17+
from icevision.models.inference import *
1618

1719

1820
@torch.no_grad()
@@ -56,6 +58,9 @@ def predict(
5658
)
5759

5860

61+
end2end_detect = partial(_end2end_detect, predict_fn=predict)
62+
63+
5964
def predict_from_dl(
6065
model: nn.Module,
6166
infer_dl: DataLoader,

icevision/models/ross/efficientdet/prediction.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__all__ = ["predict", "predict_from_dl", "convert_raw_predictions"]
1+
__all__ = ["predict", "predict_from_dl", "convert_raw_predictions", "end2end_detect"]
22

33
from icevision.imports import *
44
from icevision.utils import *
@@ -7,6 +7,7 @@
77
from icevision.models.utils import _predict_from_dl
88
from icevision.models.ross.efficientdet.dataloaders import *
99
from effdet import DetBenchTrain, DetBenchPredict, unwrap_bench
10+
from icevision.models.inference import *
1011

1112

1213
@torch.no_grad()
@@ -111,3 +112,6 @@ def convert_raw_predictions(
111112
preds.append(Prediction(pred=pred, ground_truth=record))
112113

113114
return preds
115+
116+
117+
end2end_detect = partial(_end2end_detect, predict_fn=predict)

icevision/models/torchvision/faster_rcnn/prediction.py

+5
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"predict_from_dl",
44
"convert_raw_prediction",
55
"convert_raw_predictions",
6+
"end2end_detect",
67
]
78

89
from icevision.imports import *
@@ -11,6 +12,7 @@
1112
from icevision.models.utils import _predict_from_dl
1213
from icevision.data import *
1314
from icevision.models.torchvision.faster_rcnn.dataloaders import *
15+
from icevision.models.inference import *
1416

1517

1618
@torch.no_grad()
@@ -133,3 +135,6 @@ def convert_raw_prediction(
133135
record.set_img(tensor_to_image(tensor_image))
134136

135137
return Prediction(pred=pred, ground_truth=record)
138+
139+
140+
end2end_detect = partial(_end2end_detect, predict_fn=predict)

icevision/models/torchvision/retinanet/prediction.py

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"predict_from_dl",
44
"convert_raw_prediction",
55
"convert_raw_predictions",
6+
"end2end_detect",
67
]
78

89
from icevision.models.torchvision.faster_rcnn.prediction import *

icevision/models/ultralytics/yolov5/prediction.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__all__ = ["predict", "predict_from_dl", "convert_raw_predictions"]
1+
__all__ = ["predict", "predict_from_dl", "convert_raw_predictions", "end2end_detect"]
22

33
from icevision.imports import *
44
from icevision.utils import *
@@ -7,6 +7,7 @@
77
from icevision.models.utils import _predict_from_dl
88
from icevision.models.ultralytics.yolov5.dataloaders import *
99
from yolov5.utils.general import non_max_suppression
10+
from icevision.models.inference import *
1011

1112

1213
@torch.no_grad()
@@ -117,3 +118,6 @@ def convert_raw_predictions(
117118
preds.append(Prediction(pred=pred, ground_truth=record))
118119

119120
return preds
121+
122+
123+
end2end_detect = partial(_end2end_detect, predict_fn=predict)

icevision/tfms/albumentations/albumentations_adapter.py

+9-42
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
from icevision.utils import *
1818
from icevision.core import *
1919
from icevision.tfms.transform import *
20+
from icevision.tfms.albumentations.albumentations_helpers import (
21+
get_size_without_padding,
22+
get_transform,
23+
)
2024

2125

2226
@dataclass
@@ -269,7 +273,11 @@ def apply(self, record):
269273
self._albu_out = tfms(**self._albu_in)
270274

271275
# store additional info (might be used by components on `collect`)
272-
self._size_no_padding = self._get_size_without_padding(record)
276+
height, width, _ = self._albu_out["image"].shape
277+
height, width = get_size_without_padding(
278+
self.tfms_list, record.img, height, width
279+
)
280+
self._size_no_padding = ImgSize(width=width, height=height)
273281

274282
# collect results
275283
for collect_op in sorted(self._collect_ops, key=lambda x: x.order):
@@ -295,24 +303,6 @@ def _filter_attribute(self, v: list):
295303
assert len(v) == len(self._keep_mask)
296304
return [o for o, keep in zip(v, self._keep_mask) if keep]
297305

298-
def _get_size_without_padding(self, record) -> ImgSize:
299-
height, width, _ = self._albu_out["image"].shape
300-
301-
if get_transform(self.tfms_list, "Pad") is not None:
302-
after_pad_h, after_pad_w, _ = np.array(record.img).shape
303-
304-
t = get_transform(self.tfms_list, "SmallestMaxSize")
305-
if t is not None:
306-
presize = t.max_size
307-
height, width = _func_max_size(after_pad_h, after_pad_w, presize, min)
308-
309-
t = get_transform(self.tfms_list, "LongestMaxSize")
310-
if t is not None:
311-
size = t.max_size
312-
height, width = _func_max_size(after_pad_h, after_pad_w, size, max)
313-
314-
return ImgSize(width=width, height=height)
315-
316306

317307
def _flatten_tfms(t):
318308
flat = []
@@ -330,26 +320,3 @@ def _is_iter(o):
330320
return True
331321
except:
332322
return False
333-
334-
335-
def get_transform(tfms_list, t):
    """Return the first element of `tfms_list` whose type string contains `t`.

    The match is a substring test against `str(type(element))`. `None` is
    returned when no element matches.
    """
    matches = (tfm for tfm in tfms_list if t in str(type(tfm)))
    return next(matches, None)
340-
341-
342-
def py3round(number):
    """Unified rounding in all python versions."""
    nearest = round(number)
    if abs(nearest - number) != 0.5:
        return int(nearest)

    # Exact .5 tie: go through half the value so the result is an even integer.
    return int(2.0 * round(number / 2.0))
348-
349-
350-
def _func_max_size(height, width, max_size, func):
    """Scale `(height, width)` so that `func(width, height)` becomes `max_size`.

    `func` picks the reference side (e.g. `min` or `max`). Both dimensions are
    multiplied by the same factor and rounded with `py3round`; the input is
    returned untouched when the factor is exactly 1.
    """
    scale = max_size / float(func(width, height))
    if scale == 1.0:
        return height, width

    return tuple(py3round(side * scale) for side in (height, width))

0 commit comments

Comments
 (0)