forked from Zzh-tju/CIoU
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoutput_utils.py
188 lines (142 loc) · 6.77 KB
/
output_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
""" Contains functions used to sanitize and prepare the output of Yolact. """
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
from data import cfg, mask_type, MEANS, STD, activation_func
from utils.augmentations import Resize
from utils import timer
from .box_utils import crop, sanitize_coordinates
def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear',
                visualize_lincomb=False, crop_masks=True, score_threshold=0):
    """
    Postprocesses the output of Yolact on testing mode into a format that makes sense,
    accounting for all the possible configuration settings.

    The tensors stored in det_output are left untouched: thresholding builds a
    filtered copy and the boxes are cloned before being rescaled, so calling this
    function more than once on the same det_output gives the same result.

    Args:
        - det_output: The list of dicts that Detect outputs. Each dict is read
                      through its 'net' and 'detection' keys.
        - w: The real width of the image.
        - h: The real height of the image.
        - batch_idx: If you have multiple images for this batch, the image's index in the batch.
        - interpolation_mode: Can be 'nearest' | 'area' | 'bilinear' (see torch.nn.functional.interpolate)
        - visualize_lincomb: If True (and the lincomb mask branch is evaluated), show the
                             prototype / coefficient visualization via display_lincomb.
        - crop_masks: If True, crop the lincomb masks with their boxes before upsampling.
        - score_threshold: Detections whose score is not strictly greater than this are
                           dropped. Values <= 0 disable the filtering.

    Returns 4 torch Tensors (in the following order):
        - classes [num_det]: The class idx for each detection.
        - scores  [num_det]: The confidence score for each detection.
        - boxes   [num_det, 4]: The bounding box for each detection in absolute point form.
        - masks   [num_det, h, w]: Full image masks for each detection.

    If there is nothing to return (no detections, or none above the threshold) the
    result is a list of 4 empty tensors instead. When cfg.use_maskiou and
    cfg.rescore_mask are set but cfg.rescore_bbox is not, scores is a two-element
    list [scores, scores * maskiou_p] rather than a single tensor.
    """
    dets = det_output[batch_idx]
    net = dets['net']
    dets = dets['detection']

    if dets is None:
        # Four distinct empty tensors, so callers never share one aliased object.
        return [torch.Tensor() for _ in range(4)]

    if score_threshold > 0:
        keep = dets['score'] > score_threshold

        # Filter into a new dict instead of overwriting the caller's entries.
        # 'proto' is shared by all detections, so it is passed through as-is.
        dets = {k: (v if k == 'proto' else v[keep]) for k, v in dets.items()}

        if dets['score'].size(0) == 0:
            return [torch.Tensor() for _ in range(4)]

    # Actually extract everything from dets now
    classes = dets['class']
    # The boxes get rewritten column by column further down; work on a copy so
    # the tensor held by det_output keeps its original coordinates.
    boxes = dets['box'].clone()
    scores = dets['score']
    masks = dets['mask']

    # torch.nn.functional.interpolate only accepts align_corners for the
    # interpolating modes and raises ValueError for e.g. 'nearest' or 'area'.
    interp_kwargs = {'mode': interpolation_mode}
    if interpolation_mode in ('linear', 'bilinear', 'bicubic', 'trilinear'):
        interp_kwargs['align_corners'] = False

    if cfg.mask_type == mask_type.lincomb and cfg.eval_mask_branch:
        # At this point masks is only the coefficients
        proto_data = dets['proto']

        # Test flag, do not upvote
        if cfg.mask_proto_debug:
            np.save('scripts/proto.npy', proto_data.cpu().numpy())

        if visualize_lincomb:
            display_lincomb(proto_data, masks)

        # Linear combination of the prototypes, followed by the configured activation
        masks = proto_data @ masks.t()
        masks = cfg.mask_proto_mask_activation(masks)

        # Crop masks before upsampling (the boxes are still in their original form here)
        if crop_masks:
            masks = crop(masks, boxes)

        # Permute into the correct output shape [num_dets, proto_h, proto_w]
        masks = masks.permute(2, 0, 1).contiguous()

        if cfg.use_maskiou:
            with timer.env('maskiou_net'):
                with torch.no_grad():
                    maskiou_p = net.maskiou_net(masks.unsqueeze(1))
                    # Keep only the prediction that belongs to each detection's class
                    maskiou_p = torch.gather(maskiou_p, dim=1, index=classes.unsqueeze(1)).squeeze(1)
                    if cfg.rescore_mask:
                        if cfg.rescore_bbox:
                            scores = scores * maskiou_p
                        else:
                            scores = [scores, scores * maskiou_p]

        # Scale masks up to the full image
        masks = F.interpolate(masks.unsqueeze(0), (h, w), **interp_kwargs).squeeze(0)

        # Binarize the masks
        masks.gt_(0.5)

    boxes[:, 0], boxes[:, 2] = sanitize_coordinates(boxes[:, 0], boxes[:, 2], w, cast=False)
    boxes[:, 1], boxes[:, 3] = sanitize_coordinates(boxes[:, 1], boxes[:, 3], h, cast=False)
    boxes = boxes.long()

    if cfg.mask_type == mask_type.direct and cfg.eval_mask_branch:
        # Upscale masks
        # NOTE(review): full_masks is allocated with torch's default device and dtype,
        # which may differ from where masks lives -- confirm that this is intended.
        full_masks = torch.zeros(masks.size(0), h, w)

        for jdx in range(masks.size(0)):
            x1, y1, x2, y2 = boxes[jdx, :]

            mask_w = x2 - x1
            mask_h = y2 - y1

            # Just in case: skip degenerate boxes
            if mask_w * mask_h <= 0 or mask_w < 0:
                continue

            mask = masks[jdx, :].view(1, 1, cfg.mask_size, cfg.mask_size)
            mask = F.interpolate(mask, (mask_h, mask_w), **interp_kwargs)
            mask = mask.gt(0.5).float()
            full_masks[jdx, y1:y2, x1:x2] = mask

        masks = full_masks

    return classes, scores, boxes, masks
def undo_image_transformation(img, w, h):
    """
    Reverts the backbone's input transformation on an image tensor and returns the
    result as a numpy ndarray resized with cv2.resize to (w, h).

    Args:
        - img: The transformed image tensor. It is permuted with (1, 2, 0) before
               being converted to numpy (presumably channels-first on input -- TODO confirm).
        - w: The width to resize the result to.
        - h: The height to resize the result to.

    Which arithmetic is undone depends on cfg.backbone.transform: with normalize
    the image is multiplied by STD and shifted by MEANS, otherwise with
    subtract_means only MEANS is added back. Both paths rescale by 1/255, and the
    final values are clipped to [0, 1].
    """
    reversed_channels = (2, 1, 0)
    transform = cfg.backbone.transform

    # Move the channel axis last, then flip the channel order so that it lines up
    # with MEANS / STD (the original code labels this order as BGR).
    restored = img.permute(1, 2, 0).cpu().numpy()[:, :, reversed_channels]

    if transform.normalize:
        restored = (restored * np.array(STD) + np.array(MEANS)) / 255.0
    elif transform.subtract_means:
        restored = (restored / 255.0 + np.array(MEANS) / 255.0).astype(np.float32)

    # Flip the channels back to the order they came in, then clamp to the valid range.
    restored = np.clip(restored[:, :, reversed_channels], 0, 1)

    return cv2.resize(restored, (w, h))
def display_lincomb(proto_data, masks):
    """
    Debug visualization of how the prototypes combine into the mask of the first detection.

    Shows two matplotlib figures (each plt.show() call blocks until its window is closed):
        1. A 4 x 8 grid of prototypes, ordered by decreasing absolute coefficient.
           Each tile is the prototype divided by its own maximum and multiplied by
           its coefficient.
        2. The raw linear combination (no activation applied) for the first detection.

    Args:
        - proto_data: The prototype tensor; its size is unpacked as (proto_h, proto_w, num_protos).
        - masks: The mask coefficients, one row per detection.

    Nothing is shown if masks has no rows. If there are fewer prototypes than grid
    cells, the remaining tiles are left at zero.
    """
    # Imported lazily so that matplotlib is only needed when visualizing.
    import matplotlib.pyplot as plt

    if masks.size(0) == 0:
        return

    out_masks = torch.matmul(proto_data, masks.t())
    # out_masks = cfg.mask_proto_mask_activation(out_masks)

    # Only the first detection is visualized.
    jdx = 0

    coeffs = masks[jdx, :].cpu().numpy()
    # Strongest contributors (by absolute coefficient) first
    idx = np.argsort(-np.abs(coeffs))
    coeffs_sort = coeffs[idx]

    arr_h, arr_w = (4, 8)
    proto_h, proto_w, _ = proto_data.size()
    arr_img = np.zeros([proto_h * arr_h, proto_w * arr_w])
    # Thresholded running sum of the combination; it is only filled in here and
    # can be displayed through the commented-out lines further down.
    arr_run = np.zeros([proto_h * arr_h, proto_w * arr_w])

    # Never index past the prototypes that actually exist.
    num_tiles = min(arr_h * arr_w, idx.shape[0])

    for i in range(num_tiles):
        y, x = divmod(i, arr_w)
        rows = slice(y * proto_h, (y + 1) * proto_h)
        cols = slice(x * proto_w, (x + 1) * proto_w)

        proto = proto_data[:, :, idx[i]]
        weighted = proto.cpu().numpy() * coeffs_sort[i]

        if i == 0:
            running_total = weighted
        else:
            running_total += weighted

        running_total_nonlin = running_total
        if cfg.mask_proto_mask_activation == activation_func.sigmoid:
            running_total_nonlin = (1 / (1 + np.exp(-running_total_nonlin)))

        arr_img[rows, cols] = (proto / torch.max(proto)).cpu().numpy() * coeffs_sort[i]
        # np.float was removed in NumPy 1.24; np.float64 is what it used to alias.
        arr_run[rows, cols] = (running_total_nonlin > 0.5).astype(np.float64)

    plt.imshow(arr_img)
    plt.show()
    # plt.imshow(arr_run)
    # plt.show()
    plt.imshow(out_masks[:, :, jdx].cpu().numpy())
    plt.show()