Skip to content

Commit

Permalink
Refactor dense_head and speedup (open-mmlab#6268)
Browse files Browse the repository at this point in the history
* update get_bboxes

* fix decode error

* speedup delta2bbox

* update others model

* fix ms_test

* fix unit test

* fix ms test

* fix yolox error and paa error

* fix lint

* speedup distance2bbox

* fix ci error of dtype

* fix docsstr

* fix type

* fix aug type

* rename variable

* fix comments

* fix type

* replace spare with priors

* fix rpn

* fix comment

* fix comment

* fix comment
  • Loading branch information
hhaAndroid authored and ZwwWayne committed Oct 28, 2021
1 parent 09e71ff commit a715478
Show file tree
Hide file tree
Showing 20 changed files with 629 additions and 237 deletions.
6 changes: 4 additions & 2 deletions mmdet/core/anchor/point_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,10 @@ def single_level_grid_priors(self,
if not with_stride:
shifts = torch.stack([shift_xx, shift_yy], dim=-1)
else:
stride_w = shift_xx.new_full(shift_xx.shape[0], stride_w).to(dtype)
stride_h = shift_xx.new_full(shift_yy.shape[0], stride_h).to(dtype)
stride_w = shift_xx.new_full((shift_xx.shape[0], ),
stride_w).to(dtype)
stride_h = shift_xx.new_full((shift_yy.shape[0], ),
stride_h).to(dtype)
shifts = torch.stack([shift_xx, shift_yy, stride_w, stride_h],
dim=-1)
all_points = shifts.to(device)
Expand Down
134 changes: 127 additions & 7 deletions mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv
import numpy as np
import torch
Expand Down Expand Up @@ -88,9 +90,26 @@ def decode(self,
assert pred_bboxes.size(0) == bboxes.size(0)
if pred_bboxes.ndim == 3:
assert pred_bboxes.size(1) == bboxes.size(1)
decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means, self.stds,
max_shape, wh_ratio_clip, self.clip_border,
self.add_ctr_clamp, self.ctr_clamp)

if pred_bboxes.ndim == 2 and not torch.onnx.is_in_onnx_export():
# single image decode
decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means,
self.stds, max_shape, wh_ratio_clip,
self.clip_border, self.add_ctr_clamp,
self.ctr_clamp)
else:
if pred_bboxes.ndim == 3 and not torch.onnx.is_in_onnx_export():
warnings.warn(
'DeprecationWarning: onnx_delta2bbox is deprecated '
'in the case of batch decoding and non-ONNX, '
'please use “delta2bbox” instead. In order to improve '
'the decoding speed, the batch function will no '
'longer be supported. ')
decoded_bboxes = onnx_delta2bbox(bboxes, pred_bboxes, self.means,
self.stds, max_shape,
wh_ratio_clip, self.clip_border,
self.add_ctr_clamp,
self.ctr_clamp)

return decoded_bboxes

Expand Down Expand Up @@ -157,23 +176,124 @@ def delta2bbox(rois,
network outputs used to shift/scale those boxes.
This is the inverse function of :func:`bbox2delta`.
Args:
rois (Tensor): Boxes to be transformed. Has shape (N, 4).
deltas (Tensor): Encoded offsets relative to each roi.
Has shape (N, num_classes * 4) or (N, 4). Note
N = num_base_anchors * W * H, when rois is a grid of
anchors. Offset encoding follows [1]_.
means (Sequence[float]): Denormalizing means for delta coordinates.
Default (0., 0., 0., 0.).
stds (Sequence[float]): Denormalizing standard deviation for delta
coordinates. Default (1., 1., 1., 1.).
max_shape (tuple[int, int]): Maximum bounds for boxes, specifies
(H, W). Default None.
wh_ratio_clip (float): Maximum aspect ratio for boxes. Default
16 / 1000.
clip_border (bool, optional): Whether clip the objects outside the
border of the image. Default True.
add_ctr_clamp (bool): Whether to add center clamp. When set to True,
the center of the prediction bounding box will be clamped to
avoid being too far away from the center of the anchor.
Only used by YOLOF. Default False.
ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF.
Default 32.
Returns:
Tensor: Boxes with shape (N, num_classes * 4) or (N, 4), where 4
represent tl_x, tl_y, br_x, br_y.
References:
.. [1] https://arxiv.org/abs/1311.2524
Example:
>>> rois = torch.Tensor([[ 0., 0., 1., 1.],
>>> [ 0., 0., 1., 1.],
>>> [ 0., 0., 1., 1.],
>>> [ 5., 5., 5., 5.]])
>>> deltas = torch.Tensor([[ 0., 0., 0., 0.],
>>> [ 1., 1., 1., 1.],
>>> [ 0., 0., 2., -1.],
>>> [ 0.7, -1.9, -0.5, 0.3]])
>>> delta2bbox(rois, deltas, max_shape=(32, 32, 3))
tensor([[0.0000, 0.0000, 1.0000, 1.0000],
[0.1409, 0.1409, 2.8591, 2.8591],
[0.0000, 0.3161, 4.1945, 0.6839],
[5.0000, 5.0000, 5.0000, 5.0000]])
"""
num_bboxes, num_classes = deltas.size(0), deltas.size(1) // 4
if num_bboxes == 0:
return deltas

deltas = deltas.reshape(-1, 4)

means = deltas.new_tensor(means).view(1, -1)
stds = deltas.new_tensor(stds).view(1, -1)
denorm_deltas = deltas * stds + means

dxy = denorm_deltas[:, :2]
dwh = denorm_deltas[:, 2:]

# Compute width/height of each roi
rois_ = rois.repeat(1, num_classes).reshape(-1, 4)
pxy = ((rois_[:, :2] + rois_[:, 2:]) * 0.5)
pwh = (rois_[:, 2:] - rois_[:, :2])

dxy_wh = pwh * dxy

max_ratio = np.abs(np.log(wh_ratio_clip))
if add_ctr_clamp:
dxy_wh = torch.clamp(dxy_wh, max=ctr_clamp, min=-ctr_clamp)
dwh = torch.clamp(dwh, max=max_ratio)
else:
dwh = dwh.clamp(min=-max_ratio, max=max_ratio)

gxy = pxy + dxy_wh
gwh = pwh * dwh.exp()
x1y1 = gxy - (gwh * 0.5)
x2y2 = gxy + (gwh * 0.5)
bboxes = torch.cat([x1y1, x2y2], dim=-1)
if clip_border and max_shape is not None:
bboxes[..., 0::2].clamp_(min=0, max=max_shape[1])
bboxes[..., 1::2].clamp_(min=0, max=max_shape[0])
bboxes = bboxes.reshape(num_bboxes, -1)
return bboxes


def onnx_delta2bbox(rois,
deltas,
means=(0., 0., 0., 0.),
stds=(1., 1., 1., 1.),
max_shape=None,
wh_ratio_clip=16 / 1000,
clip_border=True,
add_ctr_clamp=False,
ctr_clamp=32):
"""Apply deltas to shift/scale base boxes.
Typically the rois are anchor or proposed bounding boxes and the deltas are
network outputs used to shift/scale those boxes.
This is the inverse function of :func:`bbox2delta`.
Args:
rois (Tensor): Boxes to be transformed. Has shape (N, 4) or (B, N, 4)
deltas (Tensor): Encoded offsets with respect to each roi.
Has shape (B, N, num_classes * 4) or (B, N, 4) or
(N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
when rois is a grid of anchors.Offset encoding follows [1]_.
means (Sequence[float]): Denormalizing means for delta coordinates
means (Sequence[float]): Denormalizing means for delta coordinates.
Default (0., 0., 0., 0.).
stds (Sequence[float]): Denormalizing standard deviation for delta
coordinates
coordinates. Default (1., 1., 1., 1.).
max_shape (Sequence[int] or torch.Tensor or Sequence[
Sequence[int]],optional): Maximum bounds for boxes, specifies
(H, W, C) or (H, W). If rois shape is (B, N, 4), then
the max_shape should be a Sequence[Sequence[int]]
and the length of max_shape should also be B.
and the length of max_shape should also be B. Default None.
wh_ratio_clip (float): Maximum aspect ratio for boxes.
Default 16 / 1000.
clip_border (bool, optional): Whether clip the objects outside the
border of the image. Defaults to True.
border of the image. Default True.
add_ctr_clamp (bool): Whether to add center clamp, when added, the
predicted box is clamped is its center is too far away from
the original anchor's center. Only used by YOLOF. Default False.
Expand Down
6 changes: 6 additions & 0 deletions mmdet/core/bbox/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,12 @@ def distance2bbox(points, distance, max_shape=None):
bboxes = torch.stack([x1, y1, x2, y2], -1)

if max_shape is not None:
if points.dim() == 2 and not torch.onnx.is_in_onnx_export():
# speed up
bboxes[:, 0::2].clamp_(min=0, max=max_shape[1])
bboxes[:, 1::2].clamp_(min=0, max=max_shape[0])
return bboxes

# clip bboxes with dynamic `min` and `max` for onnx
if torch.onnx.is_in_onnx_export():
from mmdet.core.export import dynamic_clip_for_onnx
Expand Down
8 changes: 5 additions & 3 deletions mmdet/core/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dist_utils import (DistOptimizerHook, all_reduce_dict, allreduce_grads,
reduce_mean)
from .misc import (center_of_mass, flip_tensor, generate_coordinate,
mask2ndarray, multi_apply, select_single_mlvl, unmap)
from .misc import (center_of_mass, filter_scores_and_topk, flip_tensor,
generate_coordinate, mask2ndarray, multi_apply,
select_single_mlvl, unmap)

__all__ = [
'allreduce_grads', 'DistOptimizerHook', 'reduce_mean', 'multi_apply',
'unmap', 'mask2ndarray', 'flip_tensor', 'all_reduce_dict',
'center_of_mass', 'generate_coordinate', 'select_single_mlvl'
'center_of_mass', 'generate_coordinate', 'select_single_mlvl',
'filter_scores_and_topk'
]
55 changes: 52 additions & 3 deletions mmdet/core/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@ def select_single_mlvl(mlvl_tensors, batch_id, detach=True):
Cascade Mask R-CNN.
Args:
mlvl_tensors (list[Tensor]):Batch tensor for all scale levels,
mlvl_tensors (list[Tensor]): Batch tensor for all scale levels,
each is a 4D-tensor.
batch_id (int): batch index.
batch_id (int): Batch index.
detach (bool): Whether detach gradient. Default True.
Returns:
list[Tensor]: multi-scale single image tensor.
list[Tensor]: Multi-scale single image tensor.
"""
assert isinstance(mlvl_tensors, (list, tuple))
num_levels = len(mlvl_tensors)
Expand All @@ -116,6 +116,55 @@ def select_single_mlvl(mlvl_tensors, batch_id, detach=True):
return mlvl_tensor_list


def filter_scores_and_topk(scores, score_thr, topk, results=None):
"""Filter results using score threshold and topk candidates.
Args:
scores (Tensor): The scores, shape (num_bboxes, K).
score_thr (float): The score filter threshold.
topk (int): The number of topk candidates.
results (dict or list or Tensor, Optional): The results to
which the filtering rule is to be applied. The shape
of each item is (num_bboxes, N).
Returns:
tuple: Filtered results
- scores (Tensor): The scores after being filtered, \
shape (num_bboxes_filtered, ).
- labels (Tensor): The class labels, shape \
(num_bboxes_filtered, ).
- anchor_idxs (Tensor): The anchor indexes, shape \
(num_bboxes_filtered, ).
- filtered_results (dict or list or Tensor, Optional): \
The filtered results. The shape of each item is \
(num_bboxes_filtered, N).
"""
valid_mask = scores > score_thr
scores = scores[valid_mask]
valid_idxs = torch.nonzero(valid_mask)

num_topk = min(topk, valid_idxs.size(0))
# torch.sort is actually faster than .topk (at least on GPUs)
scores, idxs = scores.sort(descending=True)
scores = scores[:num_topk]
topk_idxs = valid_idxs[idxs[:num_topk]]
keep_idxs, labels = topk_idxs.unbind(dim=1)

filtered_results = None
if results is not None:
if isinstance(results, dict):
filtered_results = {k: v[keep_idxs] for k, v in results.items()}
elif isinstance(results, list):
filtered_results = [result[keep_idxs] for result in results]
elif isinstance(results, torch.Tensor):
filtered_results = results[keep_idxs]
else:
raise NotImplementedError(f'Only supports dict or list or Tensor, '
f'but get {type(results)}.')
return scores, labels, keep_idxs, filtered_results


def center_of_mass(mask, esp=1e-6):
"""Calculate the centroid coordinates of the mask.
Expand Down
Loading

0 comments on commit a715478

Please sign in to comment.