Skip to content

Commit

Permalink
Refactor one-stage get_bboxes logic (open-mmlab#5317)
Browse files Browse the repository at this point in the history
* revert batch to single

* update anchor_head

* replace preds with bboxes

* add point_bbox_coder

* FCOS add get_selected_priori

* unified anchor-free and anchor-based get_bbox_single

* update code

* update reppoints and sabl

* add sparse priors

* add mlvlpointsgenerator

* revert __init__ of core

* refactor reppoints

* delete label channal

* add docstr

* fix typo

* fix args

* fix typo

* fix doc

* fix stride_h

* add offset

* Unified bbox coder

* add offset

* remove point_bbox_coder.py

* fix docstr

* new interface of single_proir

* fix device

* add unitest

* add cuda unitest

* add more cuda unintest

* fix reppoints

* fix device

* update all prior

* update vfnet

* add unintest for ssd and yolo and rename prior_idxs

* add docstr for MlvlPointGenerator

* update reppoints and rpnhead

* add space

* add num_base_priors

* update some model

* update docstr

* fixAugFPN test and lint.

* Fix autoassign

* add docs

* Unified fcos decoding

* update docstr

* fix train error

* Fix Vfnet

* Fix some

* update centernet

* revert

* add warnings

* fix unittest error

* delete duplicated

* fix comment

* fix docs

* fix type

Co-authored-by: zhangshilong <[email protected]>
  • Loading branch information
2 people authored and ZwwWayne committed Oct 28, 2021
1 parent bd460fe commit 75238f3
Show file tree
Hide file tree
Showing 31 changed files with 1,028 additions and 1,427 deletions.
10 changes: 5 additions & 5 deletions mmdet/core/bbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from .assigners import (AssignResult, BaseAssigner, CenterRegionAssigner,
MaxIoUAssigner, RegionAssigner)
from .builder import build_assigner, build_bbox_coder, build_sampler
from .coder import (BaseBBoxCoder, DeltaXYWHBBoxCoder, PseudoBBoxCoder,
TBLRBBoxCoder)
from .coder import (BaseBBoxCoder, DeltaXYWHBBoxCoder, DistancePointBBoxCoder,
PseudoBBoxCoder, TBLRBBoxCoder)
from .iou_calculators import BboxOverlaps2D, bbox_overlaps
from .samplers import (BaseSampler, CombinedSampler,
InstanceBalancedPosSampler, IoUBalancedNegSampler,
Expand All @@ -22,7 +22,7 @@
'build_sampler', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back',
'bbox2roi', 'roi2bbox', 'bbox2result', 'distance2bbox', 'bbox2distance',
'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder',
'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'CenterRegionAssigner',
'bbox_rescale', 'bbox_cxcywh_to_xyxy', 'bbox_xyxy_to_cxcywh',
'RegionAssigner'
'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'DistancePointBBoxCoder',
'CenterRegionAssigner', 'bbox_rescale', 'bbox_cxcywh_to_xyxy',
'bbox_xyxy_to_cxcywh', 'RegionAssigner'
]
3 changes: 2 additions & 1 deletion mmdet/core/bbox/coder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .base_bbox_coder import BaseBBoxCoder
from .bucketing_bbox_coder import BucketingBBoxCoder
from .delta_xywh_bbox_coder import DeltaXYWHBBoxCoder
from .distance_point_bbox_coder import DistancePointBBoxCoder
from .legacy_delta_xywh_bbox_coder import LegacyDeltaXYWHBBoxCoder
from .pseudo_bbox_coder import PseudoBBoxCoder
from .tblr_bbox_coder import TBLRBBoxCoder
Expand All @@ -10,5 +11,5 @@
__all__ = [
'BaseBBoxCoder', 'PseudoBBoxCoder', 'DeltaXYWHBBoxCoder',
'LegacyDeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'YOLOBBoxCoder',
'BucketingBBoxCoder'
'BucketingBBoxCoder', 'DistancePointBBoxCoder'
]
62 changes: 62 additions & 0 deletions mmdet/core/bbox/coder/distance_point_bbox_coder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from ..builder import BBOX_CODERS
from ..transforms import bbox2distance, distance2bbox
from .base_bbox_coder import BaseBBoxCoder


@BBOX_CODERS.register_module()
class DistancePointBBoxCoder(BaseBBoxCoder):
"""Distance Point BBox coder.
This coder encodes gt bboxes (x1, y1, x2, y2) into (top, bottom, left,
right) and decode it back to the original.
Args:
clip_border (bool, optional): Whether clip the objects outside the
border of the image. Defaults to True.
"""

def __init__(self, clip_border=True):
super(BaseBBoxCoder, self).__init__()
self.clip_border = clip_border

def encode(self, points, gt_bboxes, max_dis=None, eps=0.1):
"""Encode bounding box to distances.
Args:
points (Tensor): Shape (N, 2), The format is [x, y].
gt_bboxes (Tensor): Shape (N, 4), The format is "xyxy"
max_dis (float): Upper bound of the distance. Default None.
eps (float): a small value to ensure target < max_dis, instead <=.
Default 0.1.
Returns:
Tensor: Box transformation deltas. The shape is (N, 4).
"""
assert points.size(0) == gt_bboxes.size(0)
assert points.size(-1) == 2
assert gt_bboxes.size(-1) == 4
return bbox2distance(points, gt_bboxes, max_dis, eps)

def decode(self, points, pred_bboxes, max_shape=None):
"""Decode distance prediction to bounding box.
Args:
points (Tensor): Shape (B, N, 2) or (N, 2).
pred_bboxes (Tensor): Distance from the given point to 4
boundaries (left, top, right, bottom). Shape (B, N, 4)
or (N, 4)
max_shape (Sequence[int] or torch.Tensor or Sequence[
Sequence[int]],optional): Maximum bounds for boxes, specifies
(H, W, C) or (H, W). If priors shape is (B, N, 4), then
the max_shape should be a Sequence[Sequence[int]],
and the length of max_shape should also be B.
Default None.
Returns:
Tensor: Boxes with shape (N, 4) or (B, N, 4)
"""
assert points.size(0) == pred_bboxes.size(0)
assert points.size(-1) == 2
assert pred_bboxes.size(-1) == 4
if self.clip_border is False:
max_shape = None
return distance2bbox(points, pred_bboxes, max_shape)
4 changes: 2 additions & 2 deletions mmdet/core/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from .dist_utils import (DistOptimizerHook, all_reduce_dict, allreduce_grads,
reduce_mean)
from .misc import (center_of_mass, flip_tensor, generate_coordinate,
mask2ndarray, multi_apply, unmap)
mask2ndarray, multi_apply, select_single_mlvl, unmap)

__all__ = [
'allreduce_grads', 'DistOptimizerHook', 'reduce_mean', 'multi_apply',
'unmap', 'mask2ndarray', 'flip_tensor', 'all_reduce_dict',
'center_of_mass', 'generate_coordinate'
'center_of_mass', 'generate_coordinate', 'select_single_mlvl'
]
31 changes: 31 additions & 0 deletions mmdet/core/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,37 @@ def flip_tensor(src_tensor, flip_direction):
return out_tensor


def select_single_mlvl(mlvl_tensors, batch_id, detach=True):
"""Extract a multi-scale single image tensor from a multi-scale batch
tensor based on batch index.
Note: The default value of detach is True, because the proposal gradient
needs to be detached during the training of the two-stage model. E.g
Cascade Mask R-CNN.
Args:
mlvl_tensors (list[Tensor]):Batch tensor for all scale levels,
each is a 4D-tensor.
batch_id (int): batch index.
detach (bool): Whether detach gradient. Default True.
Returns:
list[Tensor]: multi-scale single image tensor.
"""
assert isinstance(mlvl_tensors, (list, tuple))
num_levels = len(mlvl_tensors)

if detach:
mlvl_tensor_list = [
mlvl_tensors[i][batch_id].detach() for i in range(num_levels)
]
else:
mlvl_tensor_list = [
mlvl_tensors[i][batch_id] for i in range(num_levels)
]
return mlvl_tensor_list


def center_of_mass(mask, esp=1e-6):
"""Calculate the centroid coordinates of the mask.
Expand Down
38 changes: 12 additions & 26 deletions mmdet/models/dense_heads/anchor_free_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from mmcv.cnn import ConvModule
from mmcv.runner import force_fp32

from mmdet.core import multi_apply
from mmdet.core import build_bbox_coder, multi_apply
from mmdet.core.anchor.point_generator import MlvlPointGenerator
from ..builder import HEADS, build_loss
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin
Expand All @@ -30,6 +31,8 @@ class AnchorFreeHead(BaseDenseHead, BBoxTestMixin):
None, otherwise False. Default: "auto".
loss_cls (dict): Config of classification loss.
loss_bbox (dict): Config of localization loss.
bbox_coder (dict): Config of bbox coder. Defaults
'DistancePointBBoxCoder'.
conv_cfg (dict): Config dict for convolution layer. Default: None.
norm_cfg (dict): Config dict for normalization layer. Default: None.
train_cfg (dict): Training config of anchor head.
Expand All @@ -54,6 +57,7 @@ def __init__(self,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='IoULoss', loss_weight=1.0),
bbox_coder=dict(type='DistancePointBBoxCoder'),
conv_cfg=None,
norm_cfg=None,
train_cfg=None,
Expand All @@ -69,7 +73,11 @@ def __init__(self,
bias_prob=0.01))):
super(AnchorFreeHead, self).__init__(init_cfg)
self.num_classes = num_classes
self.cls_out_channels = num_classes
self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
if self.use_sigmoid_cls:
self.cls_out_channels = num_classes
else:
self.cls_out_channels = num_classes + 1
self.in_channels = in_channels
self.feat_channels = feat_channels
self.stacked_convs = stacked_convs
Expand All @@ -79,6 +87,8 @@ def __init__(self,
self.conv_bias = conv_bias
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
self.bbox_coder = build_bbox_coder(bbox_coder)
self.prior_generator = MlvlPointGenerator(strides)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.conv_cfg = conv_cfg
Expand Down Expand Up @@ -247,30 +257,6 @@ def loss(self,

raise NotImplementedError

@abstractmethod
@force_fp32(apply_to=('cls_scores', 'bbox_preds'))
def get_bboxes(self,
cls_scores,
bbox_preds,
img_metas,
cfg=None,
rescale=None):
"""Transform network output for a batch into bbox predictions.
Args:
cls_scores (list[Tensor]): Box scores for each scale level
Has shape (N, num_points * num_classes, H, W)
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_points * 4, H, W)
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
cfg (mmcv.Config): Test / postprocessing configuration,
if None, test_cfg would be used
rescale (bool): If True, return boxes in original image space
"""

raise NotImplementedError

@abstractmethod
def get_targets(self, points, gt_bboxes_list, gt_labels_list):
"""Compute regression, classification and centerness targets for points
Expand Down
Loading

0 comments on commit 75238f3

Please sign in to comment.