Skip to content

Commit

Permalink
cc_fast-nms bug fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
anshkumar committed May 27, 2022
1 parent 79340e3 commit 8d638dc
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
15 changes: 7 additions & 8 deletions detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ def _sanitize(self, boxes, width, height, padding: int = 0):
- masks should be a size [h, w, n] tensor of masks
- boxes should be a size [n, 4] tensor of bbox coords in relative point form
"""
x1, x2 = utils.sanitize_coordinates(boxes[:, 1], boxes[:, 3], width, normalized=False)
y1, y2 = utils.sanitize_coordinates(boxes[:, 0], boxes[:, 2], height, normalized=False)
x1, x2 = utils.sanitize_coordinates(boxes[:, 1], boxes[:, 3], tf.cast(width, dtype=tf.float32), normalized=False)
y1, y2 = utils.sanitize_coordinates(boxes[:, 0], boxes[:, 2], tf.cast(height, dtype=tf.float32), normalized=False)

boxes = tf.stack((y1, x1, y2, x2), axis=1)

Expand Down Expand Up @@ -199,13 +199,12 @@ def _cc_fast_nms(self, boxes, masks, scores, iou_threshold:float=0.5, top_k:int=

# Now just filter out the ones greater than the threshold, i.e., only keep boxes that
# don't have a higher scoring box that would supress it in normal NMS.
idx_det = (iou_max <= iou_threshold)
idx_det = tf.where(idx_det == True)
idx_out = idx[iou_max <= iou_threshold]

classes = tf.gather_nd(classes, idx_det)
boxes = tf.gather_nd(boxes, idx_det)
masks = tf.gather_nd(masks, idx_det)
scores = tf.gather_nd(scores, idx_det)
classes = tf.cast(tf.gather_nd(classes, tf.expand_dims(idx_out, axis=-1)), dtype=tf.float32)
boxes = tf.gather_nd(boxes, tf.expand_dims(idx_out, axis=-1))
masks = tf.gather_nd(masks, tf.expand_dims(idx_out, axis=-1))
scores = tf.gather_nd(scores, tf.expand_dims(idx_out, axis=-1))

return boxes, masks, classes, scores

Expand Down
6 changes: 3 additions & 3 deletions utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def map_to_offset(x):
g_hat_h = tf.math.log(x[3, 0] / x[3, 1])
return tf.stack([g_hat_cx, g_hat_cy, g_hat_w, g_hat_h])

def sanitize_coordinates(_x1, _x2, img_size, normalized, padding = 0):
def sanitize_coordinates(_x1, _x2, img_size, normalized, padding = 0.0):
"""
Sanitizes the input coordinates so that x1 < x2, x1 != x2, x1 >= 0, and x2 <= image_size.
Also converts from relative to absolute coordinates and casts the results to long tensors.
Expand All @@ -98,8 +98,8 @@ def sanitize_coordinates(_x1, _x2, img_size, normalized, padding = 0):

x1 = tf.math.minimum(_x1, _x2)
x2 = tf.math.maximum(_x1, _x2)
x1 = tf.clip_by_value(x1 - padding, clip_value_min=0, clip_value_max=img_size)
x2 = tf.clip_by_value(x2 + padding, clip_value_min=0, clip_value_max=img_size)
x1 = tf.clip_by_value(x1 - padding, clip_value_min=0.0, clip_value_max=img_size)
x2 = tf.clip_by_value(x2 + padding, clip_value_min=0.0, clip_value_max=img_size)

return x1, x2

Expand Down

0 comments on commit 8d638dc

Please sign in to comment.