|
| 1 | +import cv2 |
| 2 | +import numpy as np |
| 3 | + |
| 4 | + |
class PixelLinkDecoder():
    """Turn PixelLink-style network outputs into oriented bounding boxes.

    The decoder thresholds per-pixel scores and per-pixel link scores, merges
    neighbouring positive pixels that are connected by a positive link into
    groups (union-find), and finally fits a minimum-area rectangle around
    every group.

    After :meth:`decode` has run, the result is available in ``self.bboxes``
    as a list of ``(4, 2)`` integer arrays ordered TL, TR, BR, BL.
    """

    def __init__(self, four_neighbours=False, pixel_conf_threshold=0.8,
                 link_conf_threshold=0.8):
        """Configure the decoder.

        Args:
            four_neighbours: when true only the 4 edge-adjacent neighbours are
                considered while linking pixels; otherwise all 8 are used.
            pixel_conf_threshold: minimum softmax score for a pixel to be kept.
            link_conf_threshold: minimum softmax score for a link to be kept.
        """
        if four_neighbours:
            self._get_neighbours = self._get_neighbours_4
        else:
            self._get_neighbours = self._get_neighbours_8
        self.pixel_conf_threshold = pixel_conf_threshold
        self.link_conf_threshold = link_conf_threshold

    def decode(self, height, width, detections: dict):
        """Decode one frame and store the boxes in ``self.bboxes``.

        Args:
            height: height of the frame the boxes are reported for.
            width: width of the frame the boxes are reported for.
            detections: mapping that holds the raw pixel logits under
                ``'model/segm_logits/add'`` and the raw link logits under
                ``'model/link_logits_/add'``.
        """
        self.image_height = height
        self.image_width = width
        self.pixel_scores = self._set_pixel_scores(detections['model/segm_logits/add'])
        self.link_scores = self._set_link_scores(detections['model/link_logits_/add'])

        self.pixel_mask = self.pixel_scores >= self.pixel_conf_threshold
        self.link_mask = self.link_scores >= self.link_conf_threshold
        # (row, col) coordinates of every pixel that passed the threshold.
        self.points = list(zip(*np.where(self.pixel_mask)))
        self.h, self.w = np.shape(self.pixel_mask)
        # Union-find parent table; -1 marks a node that is its own root.
        self.group_mask = dict.fromkeys(self.points, -1)
        self.bboxes = None
        self.root_map = None
        self.mask = None

        self._decode()

    def _softmax(self, x, axis=None):
        """Return the softmax of ``x`` along ``axis`` (computed via log-sum-exp)."""
        return np.exp(x - self._logsumexp(x, axis=axis, keepdims=True))

    # pylint: disable=no-self-use
    def _logsumexp(self, a, axis=None, b=None, keepdims=False, return_sign=False):
        """Compute ``log(sum(b * exp(a)))`` along ``axis``.

        Args:
            a: input array.
            axis: axis or axes to reduce over; ``None`` reduces everything.
            b: optional scaling factors, broadcast against ``a``.
            keepdims: keep the reduced axes with length one.
            return_sign: additionally return the sign of the sum.

        Returns:
            The log of the (scaled) sum, or ``(log, sign)`` when
            ``return_sign`` is true.
        """
        if b is not None:
            a, b = np.broadcast_arrays(a, b)
            if np.any(b == 0):
                a = a + 0.  # promote to at least float
                a[b == 0] = -np.inf

        # Subtracting the maximum keeps exp() from overflowing.
        a_max = np.amax(a, axis=axis, keepdims=True)

        if a_max.ndim > 0:
            a_max[~np.isfinite(a_max)] = 0
        elif not np.isfinite(a_max):
            a_max = 0

        if b is not None:
            b = np.asarray(b)
            tmp = b * np.exp(a - a_max)
        else:
            tmp = np.exp(a - a_max)

        # suppress warnings about log of zero
        with np.errstate(divide='ignore'):
            s = np.sum(tmp, axis=axis, keepdims=keepdims)
            if return_sign:
                sgn = np.sign(s)
                s *= sgn  # /= makes more sense but we need zero -> zero
            out = np.log(s)

        if not keepdims:
            a_max = np.squeeze(a_max, axis=axis)
        out += a_max

        if return_sign:
            return out, sgn
        return out

    def _set_pixel_scores(self, pixel_scores):
        """Return softmaxed pixel scores as a 2-D map.

        The 4-D input is moved to channels-last, softmaxed over the channel
        axis, and channel 1 of the first batch item is returned.
        """
        tmp = np.transpose(pixel_scores, (0, 2, 3, 1))
        return self._softmax(tmp, axis=-1)[0, :, :, 1]

    def _set_link_scores(self, link_scores):
        """Return softmaxed link scores with one value per neighbour.

        The 4-D input is moved to channels-last and its channel axis is split
        into 8 pairs; each pair is softmaxed and element 1 of every pair of
        the first batch item is returned (result has a trailing axis of 8).
        """
        tmp = np.transpose(link_scores, (0, 2, 3, 1))
        tmp_reshaped = tmp.reshape(tmp.shape[:-1] + (8, 2))
        return self._softmax(tmp_reshaped, axis=-1)[0, :, :, :, 1]

    def _find_root(self, point):
        """Return the union-find root of ``point``, compressing the path."""
        root = point
        parent = self.group_mask[root]
        # Compare by value: identity checks against an int literal are
        # implementation-dependent.
        while parent != -1:
            root = parent
            parent = self.group_mask[root]

        # Re-point every node on the walked path straight at the root.
        node = point
        while node != root:
            next_node = self.group_mask[node]
            self.group_mask[node] = root
            node = next_node
        return root

    def _join(self, p1, p2):
        """Merge the groups of ``p1`` and ``p2``; the root of ``p1`` wins."""
        root1 = self._find_root(p1)
        root2 = self._find_root(p2)
        if root1 != root2:
            self.group_mask[root2] = root1

    def _get_index(self, root):
        """Return the 1-based label for ``root``, assigning one on first use."""
        if root not in self.root_map:
            self.root_map[root] = len(self.root_map) + 1
        return self.root_map[root]

    def _get_all(self):
        """Build ``self.mask``: every kept pixel gets the label of its group.

        Pixels that did not pass the pixel threshold stay 0.
        """
        self.root_map = {}
        self.mask = np.zeros_like(self.pixel_mask, dtype=np.int32)

        for point in self.points:
            point_root = self._find_root(point)
            bbox_idx = self._get_index(point_root)
            self.mask[point] = bbox_idx

    def _get_neighbours_8(self, x, y):
        """Return in-bounds 8-neighbours of ``(x, y)`` as ``(link_idx, nx, ny)``."""
        w, h = self.w, self.h
        tmp = [(0, x - 1, y - 1), (1, x, y - 1),
               (2, x + 1, y - 1), (3, x - 1, y),
               (4, x + 1, y), (5, x - 1, y + 1),
               (6, x, y + 1), (7, x + 1, y + 1)]

        return [i for i in tmp if 0 <= i[1] < w and 0 <= i[2] < h]

    def _get_neighbours_4(self, x, y):
        """Return in-bounds 4-neighbours of ``(x, y)`` as ``(link_idx, nx, ny)``.

        The link indices are the same ones the 8-neighbour variant uses.
        """
        w, h = self.w, self.h
        tmp = [(1, x, y - 1),
               (3, x - 1, y),
               (4, x + 1, y),
               (6, x, y + 1)]

        return [i for i in tmp if 0 <= i[1] < w and 0 <= i[2] < h]

    def _mask_to_bboxes(self, min_area=300, min_height=10):
        """Fit a rotated rectangle to every labelled group in ``self.mask``.

        The label mask is first resized to the frame size. Rectangles whose
        shorter side is below ``min_height`` or whose area is below
        ``min_area`` are dropped. Results are stored in ``self.bboxes``.
        """
        self.bboxes = []
        max_bbox_idx = int(self.mask.max()) if self.mask.size else 0
        if max_bbox_idx == 0:
            # No labelled pixels: nothing to resize or trace.
            return
        mask_tmp = cv2.resize(self.mask, (self.image_width, self.image_height),
                              interpolation=cv2.INTER_NEAREST)

        for bbox_idx in range(1, max_bbox_idx + 1):
            bbox_mask = mask_tmp == bbox_idx
            # findContours returns (contours, hierarchy) in OpenCV 2.x/4.x but
            # (image, contours, hierarchy) in 3.x; [-2] is the contour list in
            # all of them.
            cnts = cv2.findContours(bbox_mask.astype(np.uint8), cv2.RETR_TREE,
                                    cv2.CHAIN_APPROX_SIMPLE)[-2]
            if len(cnts) == 0:
                continue
            cnt = cnts[0]
            rect, w, h = self._min_area_rect(cnt)
            if min(w, h) < min_height:
                continue
            if w * h < min_area:
                continue
            self.bboxes.append(self._order_points(rect))

    # pylint: disable=no-self-use
    def _min_area_rect(self, cnt):
        """Return ``(corners, width, height)`` of the min-area rect of ``cnt``.

        ``corners`` is the integer-truncated ``(4, 2)`` corner array.
        """
        rect = cv2.minAreaRect(cnt)
        w, h = rect[1]
        # np.int0 was removed in NumPy 2.0; np.intp is the type it aliased.
        box = cv2.boxPoints(rect).astype(np.intp)
        return box, w, h

    # pylint: disable=no-self-use
    def _order_points(self, rect):
        """ (x, y)
            Order: TL, TR, BR, BL

        TL/BR are the corners with the smallest/largest ``x + y``; TR/BL are
        the corners with the smallest/largest ``y - x``.
        """
        tmp = np.zeros_like(rect)
        sums = rect.sum(axis=1)
        tmp[0] = rect[np.argmin(sums)]
        tmp[2] = rect[np.argmax(sums)]
        diff = np.diff(rect, axis=1)
        tmp[1] = rect[np.argmin(diff)]
        tmp[3] = rect[np.argmax(diff)]
        return tmp

    def _decode(self):
        """Link neighbouring pixels into groups, then label and box them."""
        for point in self.points:
            y, x = point
            neighbours = self._get_neighbours(x, y)
            for n_idx, nx, ny in neighbours:
                # A neighbour is merged only if the link towards it is
                # positive and the neighbour itself is a kept pixel.
                link_value = self.link_mask[y, x, n_idx]
                pixel_cls = self.pixel_mask[ny, nx]
                if link_value and pixel_cls:
                    self._join(point, (ny, nx))

        self._get_all()
        self._mask_to_bboxes()
| 186 | + |
| 187 | + |
# Decode every frame and report each detected box as a polygon.
# `detections` and `results` are not defined in this file; they are expected
# to be supplied by whatever executes this script.
label = 1
pcd = PixelLinkDecoder()
for detection in detections:
    frame = detection['frame_id']
    frame_height = detection['frame_height']
    frame_width = detection['frame_width']
    pcd.decode(frame_height, frame_width, detection['detections'])
    for box in pcd.bboxes:
        # Convert the corner coordinates to plain Python ints.
        polygon = [[int(corner[0]), int(corner[1])] for corner in box]
        results.add_polygon(polygon, label, frame)
0 commit comments