[refactor] Training of dense heads (open-mmlab#6315)
* Refactor one-stage get_bboxes logic (open-mmlab#5317)

* revert batch to single

* update anchor_head

* replace preds with bboxes

* add point_bbox_coder

* FCOS add get_selected_priori

* unified anchor-free and anchor-based get_bbox_single

* update code

* update reppoints and sabl

* add sparse priors

* add MlvlPointGenerator

* revert __init__ of core

* refactor reppoints

* delete label channel

* add docstr

* fix typo

* fix args

* fix typo

* fix doc

* fix stride_h

* add offset

* Unified bbox coder

* add offset

* remove point_bbox_coder.py

* fix docstr

* new interface of single prior

* fix device

* add unit test

* add CUDA unit test

* add more CUDA unit tests

* fix reppoints

* fix device

* update all prior

* update vfnet

* add unit tests for SSD and YOLO and rename prior_idxs

* add docstr for MlvlPointGenerator

* update reppoints and rpnhead

* add space

* add num_base_priors

* update some model

* update docstr

* fix AugFPN test and lint

* Fix autoassign

* add docs

* Unified fcos decoding

* update docstr

* fix train error

* Fix Vfnet

* Fix some

* update centernet

* revert

* add warnings

* fix unittest error

* delete duplicated

* fix comment

* fix docs

* fix type

Co-authored-by: zhangshilong <[email protected]>

* support onnx export for fcos

* support onnx export for fcos fsaf retina and ssd

* resolve comments

* resolve comments

* add with nms

* support cornernet

* resolve comments

* add default with nms

* fix TRT arange should be int

* refactor anchor head anchor free head

* add dtype to single_level_grid_priors

* atss fcos autoassign

* fovea

* fsaf free anchor

* support more

* support more

* support all

* resolve conversation

* fix point generator

* fix device

* change to distancecoder

* resolve conversation

* fix grid prior

* fix typos in autoassign

* fix typos

* fix doc

Co-authored-by: Haian Huang(深度眸) <[email protected]>
2 people authored and ZwwWayne committed Oct 28, 2021
1 parent a715478 commit 82c4e77
Showing 27 changed files with 436 additions and 232 deletions.
4 changes: 2 additions & 2 deletions mmdet/core/anchor/anchor_generator.py
@@ -42,14 +42,14 @@ class AnchorGenerator:
Examples:
>>> from mmdet.core import AnchorGenerator
>>> self = AnchorGenerator([16], [1.], [1.], [9])
>>> all_anchors = self.grid_anchors([(2, 2)], device='cpu')
>>> all_anchors = self.grid_priors([(2, 2)], device='cpu')
>>> print(all_anchors)
[tensor([[-4.5000, -4.5000, 4.5000, 4.5000],
[11.5000, -4.5000, 20.5000, 4.5000],
[-4.5000, 11.5000, 4.5000, 20.5000],
[11.5000, 11.5000, 20.5000, 20.5000]])]
>>> self = AnchorGenerator([16, 32], [1.], [1.], [9, 18])
>>> all_anchors = self.grid_anchors([(2, 2), (1, 1)], device='cpu')
>>> all_anchors = self.grid_priors([(2, 2), (1, 1)], device='cpu')
>>> print(all_anchors)
[tensor([[-4.5000, -4.5000, 4.5000, 4.5000],
[11.5000, -4.5000, 20.5000, 4.5000],
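The docstring example above reflects the rename from `grid_anchors` to `grid_priors`; the call itself is unchanged. A minimal sketch, assuming an mmdet build that contains this refactor:

from mmdet.core import AnchorGenerator

# Same configuration as the docstring example: stride 16, one ratio, one
# scale, base size 9.
generator = AnchorGenerator([16], [1.], [1.], [9])
# `grid_priors` is the renamed `grid_anchors`; the arguments stay the same.
all_anchors = generator.grid_priors([(2, 2)], device='cpu')
print(all_anchors)
# Expected: one level with four 9x9 anchors, matching the tensor shown above.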
1 change: 1 addition & 0 deletions mmdet/core/anchor/point_generator.py
@@ -164,6 +164,7 @@ def single_level_grid_priors(self,
if not with_stride:
shifts = torch.stack([shift_xx, shift_yy], dim=-1)
else:
# use `shape[0]` instead of `len(shift_xx)` for ONNX export
stride_w = shift_xx.new_full((shift_xx.shape[0], ),
stride_w).to(dtype)
stride_h = shift_xx.new_full((shift_yy.shape[0], ),
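The new comment explains the choice of `shape[0]`: the author notes it behaves better than `len(shift_xx)` when the graph is traced for ONNX export. A standalone sketch of the same pattern with made-up shift tensors (not a claim about any particular exporter version):

import torch

shift_xx = torch.arange(0., 32., 8.)  # hypothetical x-shifts of one level
shift_yy = torch.arange(0., 32., 8.)  # hypothetical y-shifts of one level
stride_w, stride_h = 8, 8

# Same pattern as the diff: size the stride tensors from `shape[0]` rather
# than `len(...)`.
stride_w = shift_xx.new_full((shift_xx.shape[0], ), stride_w)
stride_h = shift_xx.new_full((shift_yy.shape[0], ), stride_h)
shifts = torch.stack([shift_xx, shift_yy, stride_w, stride_h], dim=-1)
print(shifts.shape)  # torch.Size([4, 4]): x, y, stride_w, stride_h per point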
25 changes: 24 additions & 1 deletion mmdet/models/dense_heads/anchor_free_head.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from abc import abstractmethod

import torch
@@ -88,7 +89,13 @@ def __init__(self,
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
self.bbox_coder = build_bbox_coder(bbox_coder)

self.prior_generator = MlvlPointGenerator(strides)

# In order to keep a more general interface and be consistent with
# anchor_head, we can think of each point as a single anchor
self.num_base_priors = self.prior_generator.num_base_priors[0]

self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.conv_cfg = conv_cfg
@@ -278,7 +285,17 @@ def _get_points_single(self,
dtype,
device,
flatten=False):
"""Get points of a single scale level."""
"""Get points of a single scale level.
This function will be deprecated soon.
"""

warnings.warn(
'`_get_points_single` in `AnchorFreeHead` will be '
'deprecated soon, we support a multi level point generator now, '
'you can get points of a single level feature map '
'with `self.prior_generator.single_level_grid_priors` ')

h, w = featmap_size
# First create the range with the default dtype, then convert to
# the target `dtype` for ONNX exporting.
@@ -301,6 +318,12 @@ def get_points(self, featmap_sizes, dtype, device, flatten=False):
Returns:
tuple: points of each image.
"""
warnings.warn(
'`get_points` in `AnchorFreeHead` will be '
'deprecated soon, we support a multi level point generator now, '
'you can get points of all levels '
'with `self.prior_generator.grid_priors` ')

mlvl_points = []
for i in range(len(featmap_sizes)):
mlvl_points.append(
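Both helpers above now only warn and delegate; new code is expected to query the prior generator directly. A small migration sketch, where `head` and `cls_scores` are hypothetical stand-ins for a dense head built after this refactor and its per-level feature maps:

def get_mlvl_points(head, cls_scores):
    # Per-level spatial sizes of the feature maps.
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    # Deprecated: head.get_points(featmap_sizes, dtype, device)
    # Preferred after this refactor:
    return head.prior_generator.grid_priors(
        featmap_sizes,
        dtype=cls_scores[0].dtype,
        device=cls_scores[0].device)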
33 changes: 22 additions & 11 deletions mmdet/models/dense_heads/anchor_head.py
@@ -104,11 +104,20 @@ def __init__(self,
self.fp16_enabled = False

self.prior_generator = build_prior_generator(anchor_generator)
# usually the numbers of anchors for each level are the same
# except SSD detectors
self.num_anchors = self.prior_generator.num_base_priors[0]

# Usually the numbers of anchors for each level are the same
# except SSD detectors. So it is an int in most dense
# heads but a list of ints in SSDHead
self.num_base_priors = self.prior_generator.num_base_priors[0]
self._init_layers()

@property
def num_anchors(self):
warnings.warn('DeprecationWarning: `num_anchors` is deprecated, '
'for consistency, please use '
'`num_base_priors` instead')
return self.prior_generator.num_base_priors[0]

@property
def anchor_generator(self):
warnings.warn('DeprecationWarning: anchor_generator is deprecated, '
@@ -118,8 +127,10 @@ def anchor_generator(self):
def _init_layers(self):
"""Initialize layers of the head."""
self.conv_cls = nn.Conv2d(self.in_channels,
self.num_anchors * self.cls_out_channels, 1)
self.conv_reg = nn.Conv2d(self.in_channels, self.num_anchors * 4, 1)
self.num_base_priors * self.cls_out_channels,
1)
self.conv_reg = nn.Conv2d(self.in_channels, self.num_base_priors * 4,
1)

def forward_single(self, x):
"""Forward feature of a single scale level.
@@ -130,9 +141,9 @@ def forward_single(self, x):
Returns:
tuple:
cls_score (Tensor): Cls scores for a single scale level \
the channels number is num_anchors * num_classes.
the channels number is num_base_priors * num_classes.
bbox_pred (Tensor): Box energies / deltas for a single scale \
level, the channels number is num_anchors * 4.
level, the channels number is num_base_priors * 4.
"""
cls_score = self.conv_cls(x)
bbox_pred = self.conv_reg(x)
@@ -150,10 +161,10 @@ def forward(self, feats):
- cls_scores (list[Tensor]): Classification scores for all \
scale levels, each is a 4D-tensor, the channels number \
is num_anchors * num_classes.
is num_base_priors * num_classes.
- bbox_preds (list[Tensor]): Box energies / deltas for all \
scale levels, each is a 4D-tensor, the channels number \
is num_anchors * 4.
is num_base_priors * 4.
"""
return multi_apply(self.forward_single, feats)

@@ -174,8 +185,8 @@ def get_anchors(self, featmap_sizes, img_metas, device='cuda'):

# since feature map sizes of all images are the same, we only compute
# anchors once
multi_level_anchors = self.prior_generator.grid_anchors(
featmap_sizes, device)
multi_level_anchors = self.prior_generator.grid_priors(
featmap_sizes, device=device)
anchor_list = [multi_level_anchors for _ in range(num_imgs)]

# for each image, we compute valid flags of multi level anchors
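The user-visible change in this head is the rename from `num_anchors` to `num_base_priors` (the old name is kept as a deprecated property) and from `grid_anchors` to `grid_priors`. A rough usage sketch, assuming an mmdet build that contains this commit; the config values are illustrative only:

from mmdet.models.dense_heads import AnchorHead

head = AnchorHead(
    num_classes=80,
    in_channels=256,
    anchor_generator=dict(
        type='AnchorGenerator',
        scales=[8],
        ratios=[0.5, 1.0, 2.0],
        strides=[4, 8, 16, 32, 64]))
print(head.num_base_priors)  # 3: one scale times three ratios per location
print(head.num_anchors)      # same value, but emits a DeprecationWarning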
4 changes: 2 additions & 2 deletions mmdet/models/dense_heads/atss_head.py
@@ -91,9 +91,9 @@ def _init_layers(self):
3,
padding=1)
self.atss_reg = nn.Conv2d(
self.feat_channels, self.num_anchors * 4, 3, padding=1)
self.feat_channels, self.num_base_priors * 4, 3, padding=1)
self.atss_centerness = nn.Conv2d(
self.feat_channels, self.num_anchors * 1, 3, padding=1)
self.feat_channels, self.num_base_priors * 1, 3, padding=1)
self.scales = nn.ModuleList(
[Scale(1.0) for _ in self.prior_generator.strides])

63 changes: 37 additions & 26 deletions mmdet/models/dense_heads/autoassign_head.py
@@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import bias_init_with_prob, normal_init
from mmcv.runner import force_fp32

from mmdet.core import distance2bbox, multi_apply
from mmdet.core import multi_apply
from mmdet.core.anchor.point_generator import MlvlPointGenerator
from mmdet.core.bbox import bbox_overlaps
from mmdet.models import HEADS
@@ -174,22 +176,6 @@ def init_weights(self):
normal_init(self.conv_cls, std=0.01, bias=bias_cls)
normal_init(self.conv_reg, std=0.01, bias=4.0)

def _get_points_single(self,
featmap_size,
stride,
dtype,
device,
flatten=False):
"""Almost the same as the implementation in fcos, we remove half stride
offset to align with the original implementation."""

y, x = super(FCOSHead,
self)._get_points_single(featmap_size, stride, dtype,
device)
points = torch.stack((x.reshape(-1) * stride, y.reshape(-1) * stride),
dim=-1)
return points

def forward_single(self, x, scale, stride):
"""Forward features of a single scale level.
@@ -349,8 +335,10 @@ def loss(self,
assert len(cls_scores) == len(bbox_preds) == len(objectnesses)
all_num_gt = sum([len(item) for item in gt_bboxes])
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
bbox_preds[0].device)
all_level_points = self.prior_generator.grid_priors(
featmap_sizes,
dtype=bbox_preds[0].dtype,
device=bbox_preds[0].device)
inside_gt_bbox_mask_list, bbox_targets_list = self.get_targets(
all_level_points, gt_bboxes)

@@ -364,7 +352,6 @@ def loss(self,
center_prior_weight_list.append(center_prior_weight)
temp_inside_gt_bbox_mask_list.append(inside_gt_bbox_mask)
inside_gt_bbox_mask_list = temp_inside_gt_bbox_mask_list

mlvl_points = torch.cat(all_level_points, dim=0)
bbox_preds = levels_to_images(bbox_preds)
cls_scores = levels_to_images(cls_scores)
@@ -374,17 +361,18 @@ def loss(self,
ious_list = []
num_points = len(mlvl_points)

for bbox_pred, gt_bboxe, inside_gt_bbox_mask in zip(
for bbox_pred, encoded_targets, inside_gt_bbox_mask in zip(
bbox_preds, bbox_targets_list, inside_gt_bbox_mask_list):
temp_num_gt = gt_bboxe.size(1)
temp_num_gt = encoded_targets.size(1)
expand_mlvl_points = mlvl_points[:, None, :].expand(
num_points, temp_num_gt, 2).reshape(-1, 2)
gt_bboxe = gt_bboxe.reshape(-1, 4)
encoded_targets = encoded_targets.reshape(-1, 4)
expand_bbox_pred = bbox_pred[:, None, :].expand(
num_points, temp_num_gt, 4).reshape(-1, 4)
decoded_bbox_preds = distance2bbox(expand_mlvl_points,
expand_bbox_pred)
decoded_target_preds = distance2bbox(expand_mlvl_points, gt_bboxe)
decoded_bbox_preds = self.bbox_coder.decode(
expand_mlvl_points, expand_bbox_pred)
decoded_target_preds = self.bbox_coder.decode(
expand_mlvl_points, encoded_targets)
with torch.no_grad():
ious = bbox_overlaps(
decoded_bbox_preds, decoded_target_preds, is_aligned=True)
@@ -511,3 +499,26 @@ def _get_target_single(self, gt_bboxes, points):
dtype=torch.bool)

return inside_gt_bbox_mask, bbox_targets

def _get_points_single(self,
featmap_size,
stride,
dtype,
device,
flatten=False):
"""Almost the same as the implementation in fcos, we remove half stride
offset to align with the original implementation.
This function will be deprecated soon.
"""
warnings.warn(
'`_get_points_single` in `AutoAssignHead` will be '
'deprecated soon, we support a multi level point generator now, '
'you can get points of a single level feature map '
'with `self.prior_generator.single_level_grid_priors` ')
y, x = super(FCOSHead,
self)._get_points_single(featmap_size, stride, dtype,
device)
points = torch.stack((x.reshape(-1) * stride, y.reshape(-1) * stride),
dim=-1)
return points
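In the loss above, decoding now goes through `self.bbox_coder` instead of the free `distance2bbox` function, matching the "change to distancecoder" entry in the commit log. A hedged equivalence check, assuming the head's coder is the `DistancePointBBoxCoder` used elsewhere in this refactor; the values are made up:

import torch
from mmdet.core import build_bbox_coder, distance2bbox

points = torch.tensor([[32., 32.], [64., 48.]])                 # (x, y) priors
distances = torch.tensor([[4., 4., 4., 4.], [8., 6., 2., 2.]])  # l, t, r, b

coder = build_bbox_coder(dict(type='DistancePointBBoxCoder'))
# For this coder, `decode` is expected to reproduce `distance2bbox`.
assert torch.allclose(coder.decode(points, distances),
                      distance2bbox(points, distances))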
28 changes: 17 additions & 11 deletions mmdet/models/dense_heads/base_dense_head.py
@@ -75,7 +75,9 @@ def get_bboxes(self,

featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes, device=cls_scores[0].device)
featmap_sizes,
dtype=cls_scores[0].dtype,
device=cls_scores[0].device)

result_list = []

@@ -118,7 +120,10 @@ def _get_bboxes_single(self,
levels of a single image, each item has shape
(num_priors * 1, H, W).
mlvl_priors (list[Tensor]): Each element in the list is
the priors of a single level in feature pyramid, has shape
the priors of a single level in feature pyramid. In all
anchor-based methods, it has shape (num_priors, 4). In
all anchor-free methods, it has shape (num_priors, 2)
when `with_stride=False`; with `with_stride=True` the
per-prior strides are appended, giving (num_priors, 4).
img_meta (dict): Image meta info.
cfg (mmcv.Config): Test / postprocessing configuration,
@@ -181,17 +186,17 @@ def _get_bboxes_single(self,
scores = cls_score.softmax(-1)[:, :-1]

# After https://github.com/open-mmlab/mmdetection/pull/6268/,
# this operation keeps fewer bboxes under the same `nms_pre`,
# there is no difference in performance for most models, if you
# find a slight drop in performance, You can set a larger
# this operation keeps fewer bboxes under the same `nms_pre`.
# There is no difference in performance for most models. If you
# find a slight drop in performance, you can set a larger
# `nms_pre` than before.
results = filter_scores_and_topk(
scores, cfg.score_thr, nms_pre,
dict(bbox_pred=bbox_pred, priors=priors))
scores, labels, keep_idxs, filter_results = results
scores, labels, keep_idxs, filtered_results = results

bbox_pred = filter_results['bbox_pred']
priors = filter_results['priors']
bbox_pred = filtered_results['bbox_pred']
priors = filtered_results['priors']

if with_score_factors:
score_factor = score_factor[keep_idxs]
@@ -380,9 +385,10 @@ def onnx_export(self,
num_levels = len(cls_scores)

featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
mlvl_priors = self.prior_generator.grid_priors(featmap_sizes,
bbox_preds[0].dtype,
bbox_preds[0].device)
mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes,
dtype=bbox_preds[0].dtype,
device=bbox_preds[0].device)

mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
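Apart from the keyword arguments now passed to `grid_priors`, the rename of `filter_results` to `filtered_results` is the main detail external subclasses may need to track. A self-contained sketch of that filtering step with random tensors, assuming `filter_scores_and_topk` is importable from `mmdet.core.utils` as in this file:

import torch
from mmdet.core.utils import filter_scores_and_topk

scores = torch.rand(100, 80)   # per-prior class scores, background removed
bbox_pred = torch.rand(100, 4)
priors = torch.rand(100, 4)

# Keep the highest-scoring (prior, class) pairs above the score threshold,
# up to `nms_pre` of them, and gather the matching rows of the extras.
scores, labels, keep_idxs, filtered_results = filter_scores_and_topk(
    scores, 0.05, 1000, dict(bbox_pred=bbox_pred, priors=priors))
bbox_pred = filtered_results['bbox_pred']
priors = filtered_results['priors']
print(scores.shape, labels.shape, bbox_pred.shape, priors.shape)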
41 changes: 27 additions & 14 deletions mmdet/models/dense_heads/fcos_head.py
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -193,8 +195,10 @@ def loss(self,
"""
assert len(cls_scores) == len(bbox_preds) == len(centernesses)
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
bbox_preds[0].device)
all_level_points = self.prior_generator.grid_priors(
featmap_sizes,
dtype=bbox_preds[0].dtype,
device=bbox_preds[0].device)
labels, bbox_targets = self.get_targets(all_level_points, gt_bboxes,
gt_labels)

@@ -261,18 +265,6 @@ def loss(self,
loss_bbox=loss_bbox,
loss_centerness=loss_centerness)

def _get_points_single(self,
featmap_size,
stride,
dtype,
device,
flatten=False):
"""Get points according to feature map sizes."""
y, x = super()._get_points_single(featmap_size, stride, dtype, device)
points = torch.stack((x.reshape(-1) * stride, y.reshape(-1) * stride),
dim=-1) + stride // 2
return points

def get_targets(self, points, gt_bboxes_list, gt_labels_list):
"""Compute regression, classification and centerness targets for points
in multiple images.
@@ -438,3 +430,24 @@ def centerness_target(self, pos_bbox_targets):
left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * (
top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
return torch.sqrt(centerness_targets)

def _get_points_single(self,
featmap_size,
stride,
dtype,
device,
flatten=False):
"""Get points according to feature map size.
This function will be deprecated soon.
"""
warnings.warn(
'`_get_points_single` in `FCOSHead` will be '
'deprecated soon, we support a multi level point generator now, '
'you can get points of a single level feature map '
'with `self.prior_generator.single_level_grid_priors` ')

y, x = super()._get_points_single(featmap_size, stride, dtype, device)
points = torch.stack((x.reshape(-1) * stride, y.reshape(-1) * stride),
dim=-1) + stride // 2
return points
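The removed helper shifted each grid point by half a stride (`+ stride // 2`); the prior generator is expected to reproduce this through its default `offset=0.5`. A small sketch with an arbitrary stride and feature map size:

from mmdet.core.anchor.point_generator import MlvlPointGenerator

generator = MlvlPointGenerator(strides=[8], offset=0.5)
points = generator.single_level_grid_priors((2, 2), level_idx=0, device='cpu')
print(points)
# Expected: tensor([[ 4.,  4.], [12.,  4.], [ 4., 12.], [12., 12.]]),
# i.e. the old `x * stride + stride // 2` centres.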