Skip to content

Commit

Permalink
[Refactor]Refactor exporting One-Stage model to ONNX (open-mmlab#6003)
Browse files Browse the repository at this point in the history
* Refactor one-stage get_bboxes logic (open-mmlab#5317)

* revert batch to single

* update anchor_head

* replace preds with bboxes

* add point_bbox_coder

* FCOS add get_selected_priori

* unified anchor-free and anchor-based get_bbox_single

* update code

* update reppoints and sabl

* add sparse priors

* add mlvlpointsgenerator

* revert __init__ of core

* refactor reppoints

* delete label channal

* add docstr

* fix typo

* fix args

* fix typo

* fix doc

* fix stride_h

* add offset

* Unified bbox coder

* add offset

* remove point_bbox_coder.py

* fix docstr

* new interface of single_proir

* fix device

* add unitest

* add cuda unitest

* add more cuda unintest

* fix reppoints

* fix device

* update all prior

* update vfnet

* add unintest for ssd and yolo and rename prior_idxs

* add docstr for MlvlPointGenerator

* update reppoints and rpnhead

* add space

* add num_base_priors

* update some model

* update docstr

* fixAugFPN test and lint.

* Fix autoassign

* add docs

* Unified fcos decoding

* update docstr

* fix train error

* Fix Vfnet

* Fix some

* update centernet

* revert

* add warnings

* fix unittest error

* delete duplicated

* fix comment

* fix docs

* fix type

Co-authored-by: zhangshilong <[email protected]>

* support onnx export for fcos

* support onnx export for fcos fsaf retina and ssd

* resolve comments

* resolve comments

* add with nms

* support cornernet

* resolve comments

* add default with nms

* fix trt arrange should be int

Co-authored-by: Haian Huang(深度眸) <[email protected]>
  • Loading branch information
2 people authored and ZwwWayne committed Oct 28, 2021
1 parent 75238f3 commit 09e71ff
Show file tree
Hide file tree
Showing 7 changed files with 275 additions and 30 deletions.
23 changes: 16 additions & 7 deletions mmdet/core/anchor/anchor_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,14 @@ def _meshgrid(self, x, y, row_major=True):
else:
return yy, xx

def grid_priors(self, featmap_sizes, device='cuda'):
def grid_priors(self, featmap_sizes, dtype=torch.float32, device='cuda'):
"""Generate grid anchors in multiple feature levels.
Args:
featmap_sizes (list[tuple]): List of feature map sizes in
multiple feature levels.
dtype (:obj:`torch.dtype`): Dtype of priors.
Default: torch.float32.
device (str): The device where the anchors will be put on.
Return:
Expand All @@ -232,11 +234,15 @@ def grid_priors(self, featmap_sizes, device='cuda'):
multi_level_anchors = []
for i in range(self.num_levels):
anchors = self.single_level_grid_priors(
featmap_sizes[i], level_idx=i, device=device)
featmap_sizes[i], level_idx=i, dtype=dtype, device=device)
multi_level_anchors.append(anchors)
return multi_level_anchors

def single_level_grid_priors(self, featmap_size, level_idx, device='cuda'):
def single_level_grid_priors(self,
featmap_size,
level_idx,
dtype=torch.float32,
device='cuda'):
"""Generate grid anchors of a single level.
Note:
Expand All @@ -245,22 +251,25 @@ def single_level_grid_priors(self, featmap_size, level_idx, device='cuda'):
Args:
featmap_size (tuple[int]): Size of the feature maps.
level_idx (int): The index of corresponding feature map level.
dtype (obj:`torch.dtype`): Date type of points.Defaults to
``torch.float32``.
device (str, optional): The device the tensor will be put on.
Defaults to 'cuda'.
Returns:
torch.Tensor: Anchors in the overall feature maps.
"""

base_anchors = self.base_anchors[level_idx].to(device)
base_anchors = self.base_anchors[level_idx].to(device).to(dtype)
feat_h, feat_w = featmap_size
stride_w, stride_h = self.strides[level_idx]
shift_x = torch.arange(0, feat_w, device=device) * stride_w
shift_y = torch.arange(0, feat_h, device=device) * stride_h
# First create Range with the default dtype, than convert to
# target `dtype` for onnx exporting.
shift_x = torch.arange(0, feat_w, device=device).to(dtype) * stride_w
shift_y = torch.arange(0, feat_h, device=device).to(dtype) * stride_h

shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
shifts = shifts.type_as(base_anchors)
# first feat_w elements correspond to the first row of shifts
# add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
# shifted anchors (K, A, 4), reshape to (K*A, 4)
Expand Down
36 changes: 27 additions & 9 deletions mmdet/core/anchor/point_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,27 @@ def num_base_priors(self):
return [1 for _ in range(len(self.strides))]

def _meshgrid(self, x, y, row_major=True):
xx = x.repeat(len(y))
yy = y.view(-1, 1).repeat(1, len(x)).view(-1)
yy, xx = torch.meshgrid(y, x)
if row_major:
return xx, yy
# warning .flatten() would cause error in ONNX exporting
# have to use reshape here
return xx.reshape(-1), yy.reshape(-1)

else:
return yy, xx
return yy.reshape(-1), xx.reshape(-1)

def grid_priors(self, featmap_sizes, device='cuda', with_stride=False):
def grid_priors(self,
featmap_sizes,
dtype=torch.float32,
device='cuda',
with_stride=False):
"""Generate grid points of multiple feature levels.
Args:
featmap_sizes (list[tuple]): List of feature map sizes in
multiple feature levels, each size arrange as
as (h, w).
dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32.
device (str): The device where the anchors will be put on.
with_stride (bool): Whether to concatenate the stride to
the last dimension of points.
Expand All @@ -96,12 +103,14 @@ def grid_priors(self, featmap_sizes, device='cuda', with_stride=False):
and the last dimension 4 represent
(coord_x, coord_y, stride_w, stride_h).
"""

assert self.num_levels == len(featmap_sizes)
multi_level_priors = []
for i in range(self.num_levels):
priors = self.single_level_grid_priors(
featmap_sizes[i],
level_idx=i,
dtype=dtype,
device=device,
with_stride=with_stride)
multi_level_priors.append(priors)
Expand All @@ -110,6 +119,7 @@ def grid_priors(self, featmap_sizes, device='cuda', with_stride=False):
def single_level_grid_priors(self,
featmap_size,
level_idx,
dtype=torch.float32,
device='cuda',
with_stride=False):
"""Generate grid Points of a single level.
Expand All @@ -121,6 +131,7 @@ def single_level_grid_priors(self,
featmap_size (tuple[int]): Size of the feature maps, arrange as
(h, w).
level_idx (int): The index of corresponding feature map level.
dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32.
device (str, optional): The device the tensor will be put on.
Defaults to 'cuda'.
with_stride (bool): Concatenate the stride to the last dimension
Expand All @@ -138,16 +149,23 @@ def single_level_grid_priors(self,
"""
feat_h, feat_w = featmap_size
stride_w, stride_h = self.strides[level_idx]
shift_x = (torch.arange(0., feat_w, device=device) +
shift_x = (torch.arange(0, feat_w, device=device) +
self.offset) * stride_w
shift_y = (torch.arange(0., feat_h, device=device) +
# keep featmap_size as Tensor instead of int, so that we
# can covert to ONNX correctly
shift_x = shift_x.to(dtype)

shift_y = (torch.arange(0, feat_h, device=device) +
self.offset) * stride_h
# keep featmap_size as Tensor instead of int, so that we
# can covert to ONNX correctly
shift_y = shift_y.to(dtype)
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
if not with_stride:
shifts = torch.stack([shift_xx, shift_yy], dim=-1)
else:
stride_w = shift_xx.new_full((len(shift_xx), ), stride_w)
stride_h = shift_xx.new_full((len(shift_yy), ), stride_h)
stride_w = shift_xx.new_full(shift_xx.shape[0], stride_w).to(dtype)
stride_h = shift_xx.new_full(shift_yy.shape[0], stride_h).to(dtype)
shifts = torch.stack([shift_xx, shift_yy, stride_w, stride_h],
dim=-1)
all_points = shifts.to(device)
Expand Down
1 change: 1 addition & 0 deletions mmdet/core/bbox/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def distance2bbox(points, distance, max_shape=None):
Returns:
Tensor: Boxes with shape (N, 4) or (B, N, 4)
"""

x1 = points[..., 0] - distance[..., 0]
y1 = points[..., 1] - distance[..., 1]
x2 = points[..., 0] + distance[..., 2]
Expand Down
4 changes: 2 additions & 2 deletions mmdet/core/export/onnx_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ def add_dummy_nms_for_onnx(boxes,
num_classed. Defaults to None.
Returns:
tuple[Tensor, Tensor]: dets of shape [N, num_det, 5] and class labels
of shape [N, num_det].
tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
and class labels of shape [N, num_det].
"""
max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class])
iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
Expand Down
167 changes: 166 additions & 1 deletion mmdet/models/dense_heads/base_dense_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def _get_bboxes_single(self,
score_factor = score_factor[topk_inds]
else:
priors = self.prior_generator.single_level_grid_priors(
featmap_size_hw, level_idx, scores.device)
featmap_size_hw, level_idx, scores.dtype, scores.device)

bboxes = self.bbox_coder.decode(
priors, bbox_pred, max_shape=img_shape)
Expand Down Expand Up @@ -337,3 +337,168 @@ def simple_test(self, feats, img_metas, rescale=False):
with shape (n, ).
"""
return self.simple_test_bboxes(feats, img_metas, rescale=rescale)

@force_fp32(apply_to=('cls_scores', 'bbox_preds'))
def onnx_export(self,
cls_scores,
bbox_preds,
score_factors=None,
img_metas=None,
with_nms=True):
"""Transform network output for a batch into bbox predictions.
Args:
cls_scores (list[Tensor]): Box scores for each scale level
with shape (N, num_points * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_points * 4, H, W).
score_factors (list[Tensor]): score_factors for each s
cale level with shape (N, num_points * 1, H, W).
Default: None.
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc. Default: None.
with_nms (bool): Whether apply nms to the bboxes. Default: True.
Returns:
tuple[Tensor, Tensor] | list[tuple]: When `with_nms` is True,
it is tuple[Tensor, Tensor], first tensor bboxes with shape
[N, num_det, 5], 5 arrange as (x1, y1, x2, y2, score)
and second element is class labels of shape [N, num_det].
When `with_nms` is False, first tensor is bboxes with
shape [N, num_det, 4], second tensor is raw score has
shape [N, num_det, num_classes].
"""
assert len(cls_scores) == len(bbox_preds)

num_levels = len(cls_scores)

featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
mlvl_priors = self.prior_generator.grid_priors(featmap_sizes,
bbox_preds[0].dtype,
bbox_preds[0].device)

mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]

assert len(
img_metas
) == 1, 'Only support one input image while in exporting to ONNX'
img_shape = img_metas[0]['img_shape_for_onnx']

cfg = self.test_cfg
assert len(cls_scores) == len(bbox_preds) == len(mlvl_priors)
device = cls_scores[0].device
batch_size = cls_scores[0].shape[0]
# convert to tensor to keep tracing
nms_pre_tensor = torch.tensor(
cfg.get('nms_pre', -1), device=device, dtype=torch.long)

# e.g. Retina, FreeAnchor, etc.
if score_factors is None:
with_score_factors = False
mlvl_score_factor = [None for _ in range(num_levels)]
else:
# e.g. FCOS, PAA, ATSS, etc.
with_score_factors = True
mlvl_score_factor = [
score_factors[i].detach() for i in range(num_levels)
]
mlvl_score_factors = []

mlvl_batch_bboxes = []
mlvl_scores = []

for cls_score, bbox_pred, score_factors, priors in zip(
mlvl_cls_scores, mlvl_bbox_preds, mlvl_score_factor,
mlvl_priors):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]

scores = cls_score.permute(0, 2, 3,
1).reshape(batch_size, -1,
self.cls_out_channels)
if self.use_sigmoid_cls:
scores = scores.sigmoid()
nms_pre_score = scores
else:
scores = scores.softmax(-1)
nms_pre_score = scores

if with_score_factors:
score_factors = score_factors.permute(0, 2, 3, 1).reshape(
batch_size, -1).sigmoid()
bbox_pred = bbox_pred.permute(0, 2, 3,
1).reshape(batch_size, -1, 4)
priors = priors.expand(batch_size, -1, priors.size(-1))
# Get top-k predictions
from mmdet.core.export import get_k_for_topk
nms_pre = get_k_for_topk(nms_pre_tensor, bbox_pred.shape[1])
if nms_pre > 0:

if with_score_factors:
nms_pre_score = (nms_pre_score * score_factors[..., None])
else:
nms_pre_score = nms_pre_score

# Get maximum scores for foreground classes.
if self.use_sigmoid_cls:
max_scores, _ = nms_pre_score.max(-1)
else:
# remind that we set FG labels to [0, num_class-1]
# since mmdet v2.0
# BG cat_id: num_class
max_scores, _ = nms_pre_score[..., :-1].max(-1)
_, topk_inds = max_scores.topk(nms_pre)

batch_inds = torch.arange(
batch_size, device=bbox_pred.device).view(
-1, 1).expand_as(topk_inds).long()
# Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
transformed_inds = bbox_pred.shape[1] * batch_inds + topk_inds
priors = priors.reshape(
-1, priors.size(-1))[transformed_inds, :].reshape(
batch_size, -1, priors.size(-1))
bbox_pred = bbox_pred.reshape(-1,
4)[transformed_inds, :].reshape(
batch_size, -1, 4)
scores = scores.reshape(
-1, self.cls_out_channels)[transformed_inds, :].reshape(
batch_size, -1, self.cls_out_channels)
if with_score_factors:
score_factors = score_factors.reshape(
-1, 1)[transformed_inds].reshape(batch_size, -1)

bboxes = self.bbox_coder.decode(
priors, bbox_pred, max_shape=img_shape)

mlvl_batch_bboxes.append(bboxes)
mlvl_scores.append(scores)
if with_score_factors:
mlvl_score_factors.append(score_factors)

batch_bboxes = torch.cat(mlvl_batch_bboxes, dim=1)
batch_scores = torch.cat(mlvl_scores, dim=1)
if with_score_factors:
batch_score_factors = torch.cat(mlvl_score_factors, dim=1)

# Replace multiclass_nms with ONNX::NonMaxSuppression in deployment

from mmdet.core.export import add_dummy_nms_for_onnx

if not self.use_sigmoid_cls:
batch_scores = batch_scores[..., :self.num_classes]

if with_score_factors:
batch_scores = batch_scores * (batch_score_factors.unsqueeze(2))

if with_nms:
max_output_boxes_per_class = cfg.nms.get(
'max_output_boxes_per_class', 200)
iou_threshold = cfg.nms.get('iou_threshold', 0.5)
score_threshold = cfg.score_thr
nms_pre = cfg.get('deploy_nms_pre', -1)
return add_dummy_nms_for_onnx(batch_bboxes, batch_scores,
max_output_boxes_per_class,
iou_threshold, score_threshold,
nms_pre, cfg.max_per_img)
else:
return batch_bboxes, batch_scores
Loading

0 comments on commit 09e71ff

Please sign in to comment.