Skip to content

Commit 98e851a

Browse files
benhoffnmanovic
authored and committed
added in new interp files for pixel link v0004 (cvat-ai#852)
1 parent 3b6961f commit 98e851a

File tree

3 files changed

+200
-0
lines changed

3 files changed

+200
-0
lines changed
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,5 @@
1+
{
2+
"label_map": {
3+
"1": "text"
4+
}
5+
}
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,195 @@
1+
import cv2
2+
import numpy as np
3+
4+
5+
class PixelLinkDecoder():
    """Turn raw pixel/link logits into a list of 4-point boxes.

    Pixels whose score passes ``pixel_conf_threshold`` are grouped with a
    union-find structure: two positive pixels are merged when the link from
    the first towards the second passes ``link_conf_threshold``.  Each group
    is then resized to the image size and reduced to a rotated rectangle.

    After :meth:`decode` the result is available in ``self.bboxes``.
    """

    def __init__(self, four_neighbours=False, pixel_conf_threshold=0.8,
                 link_conf_threshold=0.8):
        """Configure the decoder.

        Args:
            four_neighbours: use the 4-connected neighbourhood instead of the
                8-connected one when following links.
            pixel_conf_threshold: minimum softmax score for a positive pixel.
            link_conf_threshold: minimum softmax score for a positive link.
        """
        if four_neighbours:
            self._get_neighbours = self._get_neighbours_4
        else:
            self._get_neighbours = self._get_neighbours_8
        self.pixel_conf_threshold = pixel_conf_threshold
        self.link_conf_threshold = link_conf_threshold

    def decode(self, height, width, detections: dict):
        """Decode one frame and store the boxes in ``self.bboxes``.

        Args:
            height: height the score map is resized to before box fitting.
            width: width the score map is resized to before box fitting.
            detections: mapping that must contain the keys
                ``'model/segm_logits/add'`` and ``'model/link_logits_/add'``.
                NOTE(review): both values are treated as 4-axis arrays with
                the class/link axis at position 1 -- confirm against the
                model output.
        """
        self.image_height = height
        self.image_width = width
        self.pixel_scores = self._set_pixel_scores(
            detections['model/segm_logits/add'])
        self.link_scores = self._set_link_scores(
            detections['model/link_logits_/add'])

        self.pixel_mask = self.pixel_scores >= self.pixel_conf_threshold
        self.link_mask = self.link_scores >= self.link_conf_threshold
        # (row, col) coordinates of every positive pixel.
        self.points = list(zip(*np.where(self.pixel_mask)))
        self.h, self.w = np.shape(self.pixel_mask)
        # Union-find parent table; -1 means "this point is a root".
        self.group_mask = dict.fromkeys(self.points, -1)
        self.bboxes = None
        self.root_map = None
        self.mask = None

        self._decode()

    def _softmax(self, x, axis=None):
        """Return the softmax of ``x`` along ``axis`` (computed via logsumexp)."""
        return np.exp(x - self._logsumexp(x, axis=axis, keepdims=True))

    # pylint: disable=no-self-use
    def _logsumexp(self, a, axis=None, b=None, keepdims=False, return_sign=False):
        """Compute ``log(sum(b * exp(a)))`` along ``axis``.

        The maximum of ``a`` is subtracted before exponentiation and added
        back afterwards.

        Args:
            a: input array.
            axis: axis (or axes) to reduce over; ``None`` reduces everything.
            b: optional scaling factors, broadcast against ``a``.
            keepdims: keep the reduced axes with length one.
            return_sign: also return the sign of the sum; the log is then
                taken of its absolute value.

        Returns:
            The result array, or ``(result, sign)`` when ``return_sign`` is set.
        """
        if b is not None:
            a, b = np.broadcast_arrays(a, b)
            if np.any(b == 0):
                a = a + 0.  # promote to at least float
                a[b == 0] = -np.inf

        a_max = np.amax(a, axis=axis, keepdims=True)

        # A non-finite maximum would poison the subtraction below.
        if a_max.ndim > 0:
            a_max[~np.isfinite(a_max)] = 0
        elif not np.isfinite(a_max):
            a_max = 0

        if b is not None:
            b = np.asarray(b)
            tmp = b * np.exp(a - a_max)
        else:
            tmp = np.exp(a - a_max)

        # suppress warnings about log of zero
        with np.errstate(divide='ignore'):
            s = np.sum(tmp, axis=axis, keepdims=keepdims)
            if return_sign:
                sgn = np.sign(s)
                s *= sgn  # /= makes more sense but we need zero -> zero
            out = np.log(s)

        if not keepdims:
            a_max = np.squeeze(a_max, axis=axis)
        out += a_max

        if return_sign:
            return out, sgn
        return out

    def _set_pixel_scores(self, pixel_scores):
        """Return softmaxed pixel scores as a 2-D map.

        Axis 1 of the input is moved last, softmaxed, and entry 1 of it is
        kept for the first batch element.
        """
        tmp = np.transpose(pixel_scores, (0, 2, 3, 1))
        return self._softmax(tmp, axis=-1)[0, :, :, 1]

    def _set_link_scores(self, link_scores):
        """Return softmaxed link scores as a ``(rows, cols, 8)`` array.

        Axis 1 of the input is moved last and split into 8 pairs; each pair is
        softmaxed and its entry 1 is kept for the first batch element.
        """
        tmp = np.transpose(link_scores, (0, 2, 3, 1))
        tmp_reshaped = tmp.reshape(tmp.shape[:-1] + (8, 2))
        return self._softmax(tmp_reshaped, axis=-1)[0, :, :, :, 1]

    def _find_root(self, point):
        """Return the union-find root of ``point``.

        When the walk took at least one step, ``point`` is re-linked directly
        to the root to shorten later lookups.
        """
        root = point
        update_parent = False
        parent = self.group_mask[root]
        # Compare by value: identity checks against an int literal only work
        # by accident of the interpreter's small-integer cache.
        while parent != -1:
            root = parent
            parent = self.group_mask[root]
            update_parent = True
        if update_parent:
            self.group_mask[point] = root
        return root

    def _join(self, p1, p2):
        """Merge the groups containing ``p1`` and ``p2``."""
        root1 = self._find_root(p1)
        root2 = self._find_root(p2)
        if root1 != root2:
            self.group_mask[root2] = root1

    def _get_index(self, root):
        """Return the 1-based group index for ``root``, allocating it on first use."""
        if root not in self.root_map:
            self.root_map[root] = len(self.root_map) + 1
        return self.root_map[root]

    def _get_all(self):
        """Build ``self.mask``: an int32 map holding each pixel's group index (0 = none)."""
        self.root_map = {}
        self.mask = np.zeros_like(self.pixel_mask, dtype=np.int32)

        for point in self.points:
            point_root = self._find_root(point)
            bbox_idx = self._get_index(point_root)
            self.mask[point] = bbox_idx

    def _get_neighbours_8(self, x, y):
        """Return in-bounds 8-neighbours of ``(x, y)`` as ``(link_idx, nx, ny)``.

        Link indices run row by row from the top-left neighbour (0) to the
        bottom-right neighbour (7), skipping the centre.
        """
        w, h = self.w, self.h
        tmp = [(0, x - 1, y - 1), (1, x, y - 1),
               (2, x + 1, y - 1), (3, x - 1, y),
               (4, x + 1, y), (5, x - 1, y + 1),
               (6, x, y + 1), (7, x + 1, y + 1)]

        return [i for i in tmp if i[1] >= 0 and i[1] < w and i[2] >= 0 and i[2] < h]

    def _get_neighbours_4(self, x, y):
        """Return in-bounds 4-neighbours of ``(x, y)`` as ``(link_idx, nx, ny)``.

        The link indices (1, 3, 4, 6) are the same ones the 8-neighbour
        variant uses for these positions.
        """
        w, h = self.w, self.h
        tmp = [(1, x, y - 1),
               (3, x - 1, y),
               (4, x + 1, y),
               (6, x, y + 1)]

        return [i for i in tmp if i[1] >= 0 and i[1] < w and i[2] >= 0 and i[2] < h]

    def _mask_to_bboxes(self, min_area=300, min_height=10):
        """Fit a rotated rectangle to every group and store them in ``self.bboxes``.

        Args:
            min_area: rectangles with ``w * h`` below this are dropped.
            min_height: rectangles whose shorter side is below this are dropped.
        """
        self.bboxes = []
        max_bbox_idx = self.mask.max()
        # Nearest-neighbour keeps the group indices intact while resizing.
        mask_tmp = cv2.resize(self.mask, (self.image_width, self.image_height),
                              interpolation=cv2.INTER_NEAREST)

        for bbox_idx in range(1, max_bbox_idx + 1):
            bbox_mask = mask_tmp == bbox_idx
            # findContours returns (image, contours, hierarchy) on OpenCV 3.x
            # and (contours, hierarchy) otherwise; [-2] is the contour list
            # in both layouts.
            cnts = cv2.findContours(bbox_mask.astype(np.uint8), cv2.RETR_TREE,
                                    cv2.CHAIN_APPROX_SIMPLE)[-2]
            if len(cnts) == 0:
                continue
            cnt = cnts[0]
            rect, w, h = self._min_area_rect(cnt)
            if min(w, h) < min_height:
                continue
            if w * h < min_area:
                continue
            self.bboxes.append(self._order_points(rect))

    # pylint: disable=no-self-use
    def _min_area_rect(self, cnt):
        """Return ``(corners, w, h)`` of the minimum-area rectangle around ``cnt``.

        ``corners`` is the integer-cast output of ``cv2.boxPoints``.
        """
        rect = cv2.minAreaRect(cnt)
        w, h = rect[1]
        box = cv2.boxPoints(rect)
        # np.int0 was an alias of np.intp and no longer exists in NumPy 2.
        box = box.astype(np.intp)
        return box, w, h

    # pylint: disable=no-self-use
    def _order_points(self, rect):
        """ (x, y)
            Order: TL, TR, BR, BL

        TL/BR are the corners with the smallest/largest ``x + y``;
        TR/BL are the corners with the smallest/largest ``y - x``.
        """
        tmp = np.zeros_like(rect)
        sums = rect.sum(axis=1)
        tmp[0] = rect[np.argmin(sums)]
        tmp[2] = rect[np.argmax(sums)]
        diff = np.diff(rect, axis=1)
        tmp[1] = rect[np.argmin(diff)]
        tmp[3] = rect[np.argmax(diff)]
        return tmp

    def _decode(self):
        """Link positive pixels into groups, then derive the mask and the boxes."""
        for point in self.points:
            y, x = point
            neighbours = self._get_neighbours(x, y)
            for n_idx, nx, ny in neighbours:
                link_value = self.link_mask[y, x, n_idx]
                pixel_cls = self.pixel_mask[ny, nx]
                if link_value and pixel_cls:
                    self._join(point, (ny, nx))

        self._get_all()
        self._mask_to_bboxes()
186+
187+
188+
# Driver: decode every frame and report each box as a polygon.
# NOTE(review): `detections` and `results` are not defined in this file;
# presumably the hosting runtime injects them -- confirm.
label = 1
pcd = PixelLinkDecoder()
for item in detections:
    frame_number = item['frame_id']
    frame_h = item['frame_height']
    frame_w = item['frame_width']
    pcd.decode(frame_h, frame_w, item['detections'])
    for quad in pcd.bboxes:
        # Cast the corner coordinates to plain ints before handing them over.
        polygon = []
        for corner in quad:
            polygon.append([int(corner[0]), int(corner[1])])
        results.add_polygon(polygon, label, frame_number)

0 commit comments

Comments
 (0)