Skip to content

Commit

Permalink
fast nms added
Browse files Browse the repository at this point in the history
  • Loading branch information
anshkumar committed May 27, 2022
1 parent fa04a07 commit 1b7d28b
Show file tree
Hide file tree
Showing 5 changed files with 286 additions and 267 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ protoc protos/*.proto --python_out=.

## Compile tensorflow addon for DCNv2 support (YOLACT++)
1. Git clone https://github.com/tensorflow/addons.
2. Apply the patch named `deformable_conv2d.patch`.
2. Apply the patch named `deformable_conv2d.patch` (`git am -3 < deformable_conv2d.patch`).
3. Compile tensorflow addon. For example for cuda 10.1
```
# Only CUDA 10.1 Update 1
Expand Down
126 changes: 3 additions & 123 deletions data/anchor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from itertools import product
from math import sqrt

from utils import utils
import tensorflow as tf

# Can generate one instance only when creating the model
Expand Down Expand Up @@ -53,126 +53,6 @@ def _generate_anchors(self, feature_map_size, aspect_ratio, scale):
output = tf.cast(output, tf.float32)
return num_anchors, output

def _encode(self, map_loc, anchors, include_variances=True):
# For variance in priorbox layer:
# https://github.com/weiliu89/caffe/issues/155

# center_gt = tf.map_fn(lambda x: map_to_center_form(x), map_loc)
# center_anchors in [ymin, xmin, ymax, xmax ]
# map_loc in [ymin, xmin, ymax, xmax ]
gh = map_loc[:, 2] - map_loc[:, 0]
gw = map_loc[:, 3] - map_loc[:, 1]
center_gt = tf.cast(tf.stack(
[map_loc[:, 1] + (gw / 2),
map_loc[:, 0] + (gh / 2), gw, gh],
axis=-1), tf.float32)

ph = anchors[:, 2] - anchors[:, 0]
pw = anchors[:, 3] - anchors[:, 1]
center_anchors = tf.cast(tf.stack(
[anchors[:, 1] + (pw / 2),
anchors[:, 0] + (ph / 2), pw, ph],
axis=-1), tf.float32)
variances = [0.1, 0.2]

# calculate offset
if include_variances:
g_hat_cx = (center_gt[:, 0] - center_anchors[:, 0]
) / center_anchors[:, 2] / variances[0]
g_hat_cy = (center_gt[:, 1] - center_anchors[:, 1]
) / center_anchors[:, 3] / variances[0]
else:
g_hat_cx = (center_gt[:, 0] - center_anchors[:, 0]
) / center_anchors[:, 2]
g_hat_cy = (center_gt[:, 1] - center_anchors[:, 1]
) / center_anchors[:, 3]
tf.debugging.assert_non_negative(center_anchors[:, 2] / center_gt[:, 2])
tf.debugging.assert_non_negative(center_anchors[:, 3] / center_gt[:, 3])
if include_variances:
g_hat_w = tf.math.log(center_gt[:, 2] / center_anchors[:, 2]
) / variances[1]
g_hat_h = tf.math.log(center_gt[:, 3] / center_anchors[:, 3]
) / variances[1]
else:
g_hat_w = tf.math.log(center_gt[:, 2] / center_anchors[:, 2])
g_hat_h = tf.math.log(center_gt[:, 3] / center_anchors[:, 3])
tf.debugging.assert_all_finite(g_hat_cx,
"Ground truth box x encoding NaN/Inf")
tf.debugging.assert_all_finite(g_hat_cy,
"Ground truth box y encoding NaN/Inf")
tf.debugging.assert_all_finite(g_hat_w,
"Ground truth box width encoding NaN/Inf")
tf.debugging.assert_all_finite(g_hat_h,
"Ground truth box height encoding NaN/Inf")
offsets = tf.stack([g_hat_cx, g_hat_cy, g_hat_w, g_hat_h], axis=-1)

return offsets

def _area(self, boxlist, scope=None):
# https://github.com/tensorflow/models/blob/831281cedfc8a4a0ad7c0c37173
# 963fafb99da37/official/vision/detection/utils/object_detection/
# box_list_ops.py#L48

"""Computes area of boxes.
Args:
boxlist: BoxList holding N boxes
scope: name scope.
Returns:
a tensor with shape [N] representing box areas.
"""
y_min, x_min, y_max, x_max = tf.split(
value=boxlist, num_or_size_splits=4, axis=1)
return tf.squeeze((y_max - y_min) * (x_max - x_min), [1])

def _intersection(self, boxlist1, boxlist2, scope=None):
# https://github.com/tensorflow/models/blob/831281cedfc8a4a0ad7c0c37173
# 963fafb99da37/official/vision/detection/utils/object_detection/
# box_list_ops.py#L209

"""Compute pairwise intersection areas between boxes.
Args:
boxlist1: BoxList holding N boxes
boxlist2: BoxList holding M boxes
scope: name scope.
Returns:
a tensor with shape [N, M] representing pairwise intersections
"""
y_min1, x_min1, y_max1, x_max1 = tf.split(
value=boxlist1, num_or_size_splits=4, axis=1)
y_min2, x_min2, y_max2, x_max2 = tf.split(
value=boxlist2, num_or_size_splits=4, axis=1)
all_pairs_min_ymax = tf.minimum(y_max1, tf.transpose(y_max2))
all_pairs_max_ymin = tf.maximum(y_min1, tf.transpose(y_min2))
intersect_heights = tf.maximum(0.0,
all_pairs_min_ymax - all_pairs_max_ymin)
all_pairs_min_xmax = tf.minimum(x_max1, tf.transpose(x_max2))
all_pairs_max_xmin = tf.maximum(x_min1, tf.transpose(x_min2))
intersect_widths = tf.maximum(0.0,
all_pairs_min_xmax - all_pairs_max_xmin)
return intersect_heights * intersect_widths

def _iou(self, boxlist1, boxlist2, scope=None):
# https://github.com/tensorflow/models/blob/831281cedfc8a4a0ad7c0c37173
# 963fafb99da37/official/vision/detection/utils/object_detection/
# box_list_ops.py#L259

"""Computes pairwise intersection-over-union between box collections.
Args:
boxlist1: BoxList holding N boxes
boxlist2: BoxList holding M boxes
scope: name scope.
Returns:
a tensor with shape [N, M] representing pairwise iou scores.
"""
intersections = self._intersection(boxlist1, boxlist2)
areas1 = self._area(boxlist1)
areas2 = self._area(boxlist2)
unions = (tf.expand_dims(areas1, 1) + tf.expand_dims(
areas2, 0) - intersections)
return tf.where(
tf.equal(intersections, 0.0),
tf.zeros_like(intersections), tf.truediv(intersections, unions))

def get_anchors(self):
# Convert anchors from [cx, cy, w, h] to [ymin, xmin, ymax, xmax ]
# for IOU calculations
Expand All @@ -192,7 +72,7 @@ def matching(self, pos_thresh, neg_thresh, gt_bbox, gt_labels):
# ground_truth clong the columns

# anchors and gt_bbox in [y1, x1, y2, x2]
pairwise_iou = self._iou(self.anchors, gt_bbox)
pairwise_iou = utils._iou(self.anchors, gt_bbox)

# size [num_priors]; iou with ground truth with the anchors
each_prior_max = tf.reduce_max(pairwise_iou, axis=-1)
Expand Down Expand Up @@ -243,6 +123,6 @@ def matching(self, pos_thresh, neg_thresh, gt_bbox, gt_labels):
tf.zeros(tf.size(background_label_index), dtype=tf.int64))

# anchors and each_prior_box in [y1, x1, y2, x2]
offsets = self._encode(each_prior_box, self.anchors)
offsets = utils._encode(each_prior_box, self.anchors)

return offsets, conf, each_prior_box, each_prior_index
1 change: 0 additions & 1 deletion data/yolact_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from data import tfrecord_decoder
from utils import augmentation
from utils.utils import normalize_image
from functools import partial

class Parser(object):
Expand Down
190 changes: 92 additions & 98 deletions detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, num_classes, max_output_size, per_class_max_output_size, conf
self.max_output_size = 300
self.per_class_max_output_size = 100

def __call__(self, net_outs, img_shape, trad_nms=True, use_cropped_mask=True):
def __call__(self, net_outs, img_shape, trad_nms=False, use_cropped_mask=True):
"""
Args:
pred_offset: (tensor) Loc preds from loc layers
Expand Down Expand Up @@ -76,13 +76,13 @@ def __call__(self, net_outs, img_shape, trad_nms=True, use_cropped_mask=True):
raw_anchors = tf.boolean_mask(anchors, class_p_max[b] > self.conf_thresh)

# decode only selected boxes
boxes = self._decode(raw_boxes, raw_anchors) # [27429, 4]
boxes = utils._decode(raw_boxes, raw_anchors) # [27429, 4]

if tf.size(class_thre) != 0:
if tf.size(class_thre) > 0:
if not trad_nms:
boxes, coef_thre, class_ids, class_thre = _fast_nms(boxes, coef_thre, class_thre)
boxes, coef_thre, class_ids, class_thre = self._cc_fast_nms(boxes, coef_thre, class_thre)
else:
boxes, coef_thre, class_ids, class_thre = self._traditional_nms(boxes, coef_thre, class_thre, score_threshold=self.conf_thresh, iou_threshold=self.nms_thresh, max_class_output_size=self.per_class_max_output_size)
boxes, coef_thre, class_ids, class_thre = self._traditional_nms_v2(boxes, coef_thre, class_thre, score_threshold=self.conf_thresh, iou_threshold=self.nms_thresh)

num_detection = [tf.shape(boxes)[0]]

Expand Down Expand Up @@ -116,105 +116,15 @@ def __call__(self, net_outs, img_shape, trad_nms=True, use_cropped_mask=True):
result = {'detection_boxes': detection_boxes,'detection_classes': detection_classes, 'detection_scores': detection_scores, 'detection_masks': detection_masks, 'num_detections': num_detections}
return result

def _batch_decode(self, box_p, priors, include_variances=True):
# https://github.com/feiyuhuahuo/Yolact_minimal/blob/9299a0cf346e455d672fadd796ac748871ba85e4/utils/box_utils.py#L151
"""
Decode predicted bbox coordinates using the scheme
employed at https://lilianweng.github.io/lil-log/2017/12/31/object-recognition-for-dummies-part-3.html
b_x = prior_w*loc_x + prior_x
b_y = prior_h*loc_y + prior_y
b_w = prior_w * exp(loc_w)
b_h = prior_h * exp(loc_h)
Note that loc is inputed as [c_x, x_y, w, h]
while priors are inputed as [c_x, c_y, w, h] where each coordinate
is relative to size of the image.
Also note that prior_x and prior_y are center coordinates.
"""
variances = [0.1, 0.2]
box_p = tf.cast(box_p, tf.float32)
priors = tf.cast(priors, tf.float32)
if include_variances:
b_x_y = priors[:, :2] + box_p[:, :, :2] * priors[:, 2:]* variances[0]
b_w_h = priors[:, 2:] * tf.math.exp(box_p[:, :, 2:]* variances[1])
else:
b_x_y = priors[:, :2] + box_p[:, :, :2] * priors[:, 2:]
b_w_h = priors[:, 2:] * tf.math.exp(box_p[:, :, 2:])

boxes = tf.concat([b_x_y, b_w_h], axis=-1)

# [x_min, y_min, x_max, y_max]
boxes = tf.concat([boxes[:, :, :2] - boxes[:, :, 2:] / 2, boxes[:, :, 2:] / 2 + boxes[:, :, :2]], axis=-1)

# [y_min, x_min, y_max, x_max]
return tf.stack([boxes[:, :, 1], boxes[:, :, 0],boxes[:, :, 3], boxes[:, :, 2]], axis=-1)

def _decode(self, box_p, priors, include_variances=True):
# https://github.com/feiyuhuahuo/Yolact_minimal/blob/9299a0cf346e455d672fadd796ac748871ba85e4/utils/box_utils.py#L151
"""
Decode predicted bbox coordinates using the scheme
employed at https://lilianweng.github.io/lil-log/2017/12/31/object-recognition-for-dummies-part-3.html
b_x = prior_w*loc_x + prior_x
b_y = prior_h*loc_y + prior_y
b_w = prior_w * exp(loc_w)
b_h = prior_h * exp(loc_h)
Note that loc is inputed as [c_x, x_y, w, h]
while priors are inputed as [c_x, c_y, w, h] where each coordinate
is relative to size of the image.
Also note that prior_x and prior_y are center coordinates.
"""
variances = [0.1, 0.2]
box_p = tf.cast(box_p, tf.float32)
priors = tf.cast(priors, tf.float32)

ph = priors[:, 2] - priors[:, 0]
pw = priors[:, 3] - priors[:, 1]
priors = tf.cast(tf.stack(
[priors[:, 1] + (pw / 2),
priors[:, 0] + (ph / 2), pw, ph],
axis=-1), tf.float32)

if include_variances:
b_x_y = priors[:, :2] + box_p[:, :2] * priors[:, 2:]* variances[0]
b_w_h = priors[:, 2:] * tf.math.exp(box_p[:, 2:]* variances[1])
else:
b_x_y = priors[:, :2] + box_p[:, :2] * priors[:, 2:]
b_w_h = priors[:, 2:] * tf.math.exp(box_p[:, 2:])

boxes = tf.concat([b_x_y, b_w_h], axis=-1)

# [x_min, y_min, x_max, y_max]
boxes = tf.concat([boxes[:, :2] - boxes[:, 2:] / 2, boxes[:, 2:] / 2 + boxes[:, :2]], axis=-1)

# [y_min, x_min, y_max, x_max]
return tf.stack([boxes[:, 1], boxes[:, 0],boxes[:, 3], boxes[:, 2]], axis=-1)

def _sanitize_coordinates(self, _x1, _x2, size, padding: int = 0):
"""
Sanitizes the input coordinates so that x1 < x2, x1 != x2, x1 >= 0, and x2 <= image_size.
Also converts from relative to absolute coordinates and casts the results to long tensors.
Warning: this does things in-place behind the scenes so copy if necessary.
"""
x1 = tf.math.minimum(_x1, _x2)
x2 = tf.math.maximum(_x1, _x2)
x1 = tf.clip_by_value(x1 - padding, clip_value_min=0.0, clip_value_max=tf.cast(size,tf.float32))
x2 = tf.clip_by_value(x2 + padding, clip_value_min=0.0, clip_value_max=tf.cast(size,tf.float32))

# Normalize the coordinates
return x1, x2

def _sanitize(self, boxes, width, height, padding: int = 0, crop_size=(30,30)):
def _sanitize(self, boxes, width, height, padding: int = 0):
"""
"Crop" predicted masks by zeroing out everything not in the predicted bbox.
Args:
- masks should be a size [h, w, n] tensor of masks
- boxes should be a size [n, 4] tensor of bbox coords in relative point form
"""
x1, x2 = self._sanitize_coordinates(boxes[:, 1], boxes[:, 3], width, padding)
y1, y2 = self._sanitize_coordinates(boxes[:, 0], boxes[:, 2], height, padding)
x1, x2 = utils._sanitize_coordinates(boxes[:, 1], boxes[:, 3], width, normalized=False)
y1, y2 = utils._sanitize_coordinates(boxes[:, 0], boxes[:, 2], height, normalized=False)

boxes = tf.stack((y1, x1, y2, x2), axis=1)

Expand Down Expand Up @@ -252,6 +162,90 @@ def _traditional_nms(self, boxes, mask_coef, scores, iou_threshold=0.5, score_th
boxes = tf.gather(_boxes, _ids)[:max_output_size]
mask_coef = tf.gather(_coefs, _ids)[:max_output_size]
classes = tf.gather(_classes, _ids)[:max_output_size]
return boxes, mask_coef, classes, scores

def _traditional_nms_v2(self, boxes, mask_coef, scores, iou_threshold=0.5, score_threshold=0.05, max_output_size=300):
selected_indices = tf.image.non_max_suppression(boxes,
tf.reduce_max(scores, axis=-1),
max_output_size=max_output_size,
iou_threshold=iou_threshold,
score_threshold=score_threshold)

classes = tf.argmax(scores, axis=-1)+1
boxes = tf.gather(boxes, selected_indices)
scores = tf.gather(tf.reduce_max(scores, axis=-1), selected_indices)
mask_coef = tf.gather(mask_coef, selected_indices)
classes = tf.cast(tf.gather(classes, selected_indices), dtype=tf.float32)
return boxes, mask_coef, classes, scores

def _cc_fast_nms(self, boxes, masks, scores, iou_threshold:float=0.5, top_k:int=15):
# Cross Class NMS
# Collapse all the classes into 1
classes = tf.argmax(scores, axis=-1)+1
scores = tf.reduce_max(scores, axis=-1)
_, idx = tf.math.top_k(scores, k=tf.math.minimum(top_k, tf.shape(scores)[0]))
boxes_idx = tf.gather(boxes, idx, axis=0)

# Compute the pairwise IoU between the boxes
iou = utils._iou(boxes_idx, boxes_idx)

# Zero out the lower triangle of the cosine similarity matrix and diagonal
iou = tf.linalg.band_part(iou, 0, -1) - tf.linalg.band_part(iou, 0, 0)

# Now that everything in the diagonal and below is zeroed out, if we take the max
# of the IoU matrix along the columns, each column will represent the maximum IoU
# between this element and every element with a higher score than this element.
iou_max = tf.reduce_max(iou, axis=0)

# Now just filter out the ones greater than the threshold, i.e., only keep boxes that
# don't have a higher scoring box that would supress it in normal NMS.
idx_det = (iou_max <= iou_threshold)
idx_det = tf.where(idx_det == True)

classes = tf.gather_nd(classes, idx_det)
boxes = tf.gather_nd(boxes, idx_det)
masks = tf.gather_nd(masks, idx_det)
scores = tf.gather_nd(scores, idx_det)

return boxes, masks, classes, scores

def _fast_nms(self, boxes, masks, scores, iou_threshold=0.5, top_k=100):
if tf.rank(scores) == 1:
scores = tf.expand_dims(scores, axis=-1)
boxes = tf.expand_dims(boxes, axis=0)
masks = tf.expand_dims(masks, axis=0)

scores, idx = tf.math.top_k(scores, k=top_k)
num_classes, num_dets = tf.shape(idx)[0], tf.shape(idx)[1]
boxes = tf.gather(boxes, idx, axis=0)
masks = tf.gather(masks, idx, axis=0)
iou = utils._iou(boxes, boxes)
# upper trangular matrix - diagnoal
upper_triangular = tf.linalg.band_part(iou, 0, -1)
diag = tf.linalg.band_part(iou, 0, 0)
iou = upper_triangular - diag

# fitler out the unwanted ROI
iou_max = tf.reduce_max(iou, axis=1)
idx_det = (iou_max <= iou_threshold)

# second threshold
# second_threshold = (iou_max <= self.conf_threshold)
second_threshold = (scores > self.conf_threshold)
idx_det = tf.where(tf.logical_and(idx_det, second_threshold) == True)
classes = tf.broadcast_to(tf.expand_dims(tf.range(num_classes), axis=-1), tf.shape(iou_max))
classes = tf.gather_nd(classes, idx_det)
boxes = tf.gather_nd(boxes, idx_det)
masks = tf.gather_nd(masks, idx_det)
scores = tf.gather_nd(scores, idx_det)

# number of max detection = 100 (u can choose whatever u want)
max_num_detection = tf.math.minimum(self.max_num_detection, tf.size(scores))
scores, idx = tf.math.top_k(scores, k=max_num_detection)

# second threshold
classes = tf.gather(classes, idx)
boxes = tf.gather(boxes, idx)
masks = tf.gather(masks, idx)

return boxes, masks, classes, scores
Loading

0 comments on commit 1b7d28b

Please sign in to comment.