Skip to content

Commit

Permalink
Fix the handling of images with no GT
Browse files Browse the repository at this point in the history
Summary: Make PointRend mask loss to work with images without GT instances. The logic closely follows mask loss for Mask R-CNN now.

Reviewed By: ppwwyyxx

Differential Revision: D21711878

fbshipit-source-id: d79b6e0a2d00e06c83796aa7257ab00faac06acf
  • Loading branch information
Alexander Kirillov authored and facebook-github-bot committed May 25, 2020
1 parent de09842 commit 426d239
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions projects/PointRend/point_rend/point_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@ def roi_mask_point_loss(mask_logits, instances, points_coord):
Returns:
point_loss (Tensor): A scalar tensor containing the loss.
"""
assert len(instances) == 0 or isinstance(
instances[0].gt_masks, BitMasks
), "Point head works with GT in 'bitmask' format only. Set INPUT.MASK_FORMAT to 'bitmask'."
with torch.no_grad():
cls_agnostic_mask = mask_logits.size(1) == 1
total_num_masks = mask_logits.size(0)
Expand All @@ -50,6 +47,12 @@ def roi_mask_point_loss(mask_logits, instances, points_coord):
gt_mask_logits = []
idx = 0
for instances_per_image in instances:
if len(instances_per_image) == 0:
continue
assert isinstance(
instances_per_image.gt_masks, BitMasks
), "Point head works with GT in 'bitmask' format. Set INPUT.MASK_FORMAT to 'bitmask'."

if not cls_agnostic_mask:
gt_classes_per_image = instances_per_image.gt_classes.to(dtype=torch.int64)
gt_classes.append(gt_classes_per_image)
Expand All @@ -68,13 +71,13 @@ def roi_mask_point_loss(mask_logits, instances, points_coord):
align_corners=False,
).squeeze(1)
)
gt_mask_logits = cat(gt_mask_logits)

# torch.mean (in binary_cross_entropy_with_logits) doesn't
# accept empty tensors, so handle it separately
if gt_mask_logits.numel() == 0:
if len(gt_mask_logits) == 0:
return mask_logits.sum() * 0

gt_mask_logits = cat(gt_mask_logits)
assert gt_mask_logits.numel() > 0, gt_mask_logits.shape

if cls_agnostic_mask:
mask_logits = mask_logits[:, 0]
else:
Expand Down

0 comments on commit 426d239

Please sign in to comment.