|
| 1 | +__all__ = ["process_bbox_predictions", "_end2end_detect"] |
| 2 | + |
| 3 | +from icevision.imports import * |
| 4 | +from icevision.core import * |
| 5 | +from icevision.data import * |
| 6 | +from icevision.tfms.albumentations.albumentations_helpers import ( |
| 7 | + get_size_without_padding, |
| 8 | +) |
| 9 | +from icevision.tfms.albumentations import albumentations_adapter |
| 10 | + |
| 11 | + |
| 12 | +def _end2end_detect( |
| 13 | + img: Union[PIL.Image.Image, Path, str], |
| 14 | + transforms: albumentations_adapter.Adapter, |
| 15 | + model: torch.nn.Module, |
| 16 | + class_map: ClassMap, |
| 17 | + detection_threshold: float = 0.5, |
| 18 | + predict_fn: Callable = None, |
| 19 | +): |
| 20 | + """ |
| 21 | + Run Object Detection inference (only `bboxes`) on a single image. |
| 22 | +
|
| 23 | + Parameters |
| 24 | + ---------- |
| 25 | + img: image to run inference on. Can be a string, Path or PIL.Image |
| 26 | + transforms: icevision albumentations transforms |
| 27 | + model: model to run inference with |
| 28 | + class_map: ClassMap with the available categories |
| 29 | + detection_threshold: confidence threshold below which boxes are discarded |
| 30 | +
|
| 31 | + Returns |
| 32 | + ------- |
| 33 | + List of dicts with category, score and bbox coordinates adjusted to original image size and aspect ratio |
| 34 | + """ |
| 35 | + if isinstance(img, (str, Path)): |
| 36 | + img = PIL.Image.open(Path(img)) |
| 37 | + |
| 38 | + infer_ds = Dataset.from_images([np.array(img)], transforms, class_map=class_map) |
| 39 | + pred = predict_fn(model, infer_ds, detection_threshold=detection_threshold)[0] |
| 40 | + bboxes = process_bbox_predictions(pred, img, transforms.tfms_list) |
| 41 | + return bboxes |
| 42 | + |
| 43 | + |
def process_bbox_predictions(
    pred: Prediction,
    img: PIL.Image.Image,
    transforms: List[Any],
) -> List[Dict[str, Any]]:
    """
    Postprocess prediction.

    Parameters
    ----------
    pred: icevision prediction object
    img: original image, before any model-pre-processing done
    transforms: list of model-pre-processing transforms

    Returns
    -------
    List of dicts with class, score and bbox coordinates
    """
    detection = pred.pred.detection
    h, w = pred.pred.height, pred.pred.width

    results = []
    for box, confidence, category in zip(
        detection.bboxes, detection.scores, detection.labels
    ):
        # Rescale/shift each box back to the original image coordinates.
        xmin, ymin, xmax, ymax = postprocess_bbox(img, box, transforms, h, w)
        results.append(
            {
                "class": category,
                "score": confidence,
                "bbox": [xmin, ymin, xmax, ymax],
            }
        )
    return results
| 78 | + |
| 79 | + |
def postprocess_bbox(
    img: PIL.Image.Image, bbox: BBox, transforms: List[Any], h_after: int, w_after: int
) -> Tuple[int, int, int, int]:
    """
    Post-process predicted bbox to adjust coordinates to input image size.

    Parameters
    ----------
    img: original image, before any model-pre-processing done
    bbox: predicted bbox
    transforms: list of model-pre-processing transforms
    h_after: height of image after model-pre-processing transforms
    w_after: width of image after model-pre-processing transforms

    Returns
    -------
    Tuple with (xmin, ymin, xmax, ymax) rescaled and re-adjusted to match the original image size
    """
    w_before, h_before = img.size
    # Strip any padding the transforms added, so scales are computed on content only.
    h_after, w_after = get_size_without_padding(transforms, img, h_after, w_after)
    # Symmetric padding that was applied along the shorter axis.
    pad = np.abs(h_after - w_after) // 2

    h_scale, w_scale = h_after / h_before, w_after / w_before

    xmin, xmax = int(bbox.xmin), int(bbox.xmax)
    ymin, ymax = int(bbox.ymin), int(bbox.ymax)

    # Undo the padding shift on whichever axis was padded.
    if h_after < w_after:
        ymin -= pad
        ymax -= pad
    else:
        xmin -= pad
        xmax -= pad

    # Clip to the unpadded image bounds, then rescale to the original resolution.
    xmin = int(max(xmin, 0) / w_scale)
    xmax = int(min(xmax, w_after) / w_scale)
    ymin = int(max(ymin, 0) / h_scale)
    ymax = int(min(ymax, h_after) / h_scale)

    return xmin, ymin, xmax, ymax
0 commit comments