Skip to content

Commit

Permalink
Refactor anchor_generator and point_generator (open-mmlab#5349)
Browse files Browse the repository at this point in the history
* add sparse priors

* add mlvlpointsgenerator

* revert __init__ of core

* refactor reppoints

* delete label channal

* add docstr

* fix typo

* fix args

* fix typo

* fix doc

* fix stride_h

* add offset

* add offset

* fix docstr

* new interface of single_proir

* fix device

* add unitest

* add cuda unitest

* add more cuda unintest

* fix reppoints

* fix device

* add unintest for ssd and yolo and rename prior_idxs

* add docstr for MlvlPointGenerator

* add space

* add num_base_priors
  • Loading branch information
jshilong authored Jun 24, 2021
1 parent 269bb9e commit e91da70
Show file tree
Hide file tree
Showing 6 changed files with 626 additions and 49 deletions.
8 changes: 5 additions & 3 deletions mmdet/core/anchor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from .anchor_generator import (AnchorGenerator, LegacyAnchorGenerator,
YOLOAnchorGenerator)
from .builder import ANCHOR_GENERATORS, build_anchor_generator
from .point_generator import PointGenerator
from .builder import (ANCHOR_GENERATORS, PRIOR_GENERATORS,
build_anchor_generator, build_prior_generator)
from .point_generator import MlvlPointGenerator, PointGenerator
from .utils import anchor_inside_flags, calc_region, images_to_levels

__all__ = [
'AnchorGenerator', 'LegacyAnchorGenerator', 'anchor_inside_flags',
'PointGenerator', 'images_to_levels', 'calc_region',
'build_anchor_generator', 'ANCHOR_GENERATORS', 'YOLOAnchorGenerator'
'build_anchor_generator', 'ANCHOR_GENERATORS', 'YOLOAnchorGenerator',
'build_prior_generator', 'PRIOR_GENERATORS', 'MlvlPointGenerator'
]
131 changes: 121 additions & 10 deletions mmdet/core/anchor/anchor_generator.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import warnings

import mmcv
import numpy as np
import torch
from torch.nn.modules.utils import _pair

from .builder import ANCHOR_GENERATORS
from .builder import PRIOR_GENERATORS


@ANCHOR_GENERATORS.register_module()
@PRIOR_GENERATORS.register_module()
class AnchorGenerator:
"""Standard anchor generator for 2D anchor-based detectors.
Expand Down Expand Up @@ -68,7 +70,7 @@ def __init__(self,
# check center and center_offset
if center_offset != 0:
assert centers is None, 'center cannot be set when center_offset' \
f'!=0, {centers} is given.'
f'!=0, {centers} is given.'
if not (0 <= center_offset <= 1):
raise ValueError('center_offset should be in range [0, 1], '
f'{center_offset} is given.')
Expand All @@ -87,7 +89,7 @@ def __init__(self,

# calculate scales of anchors
assert ((octave_base_scale is not None
and scales_per_octave is not None) ^ (scales is not None)), \
and scales_per_octave is not None) ^ (scales is not None)), \
'scales and octave_base_scale with scales_per_octave cannot' \
' be set at the same time'
if scales is not None:
Expand All @@ -112,6 +114,12 @@ def __init__(self,
@property
def num_base_anchors(self):
"""list[int]: total number of base anchors in a feature grid"""
return self.num_base_priors

@property
def num_base_priors(self):
"""list[int]: The number of priors (anchors) at a point
on the feature grid"""
return [base_anchors.size(0) for base_anchors in self.base_anchors]

@property
Expand Down Expand Up @@ -204,6 +212,99 @@ def _meshgrid(self, x, y, row_major=True):
else:
return yy, xx

def grid_priors(self, featmap_sizes, device='cuda'):
"""Generate grid anchors in multiple feature levels.
Args:
featmap_sizes (list[tuple]): List of feature map sizes in
multiple feature levels.
device (str): The device where the anchors will be put on.
Return:
list[torch.Tensor]: Anchors in multiple feature levels. \
The sizes of each tensor should be [N, 4], where \
N = width * height * num_base_anchors, width and height \
are the sizes of the corresponding feature level, \
num_base_anchors is the number of anchors for that level.
"""
assert self.num_levels == len(featmap_sizes)
multi_level_anchors = []
for i in range(self.num_levels):
anchors = self.single_level_grid_priors(
featmap_sizes[i], level_idx=i, device=device)
multi_level_anchors.append(anchors)
return multi_level_anchors

def single_level_grid_priors(self, featmap_size, level_idx, device='cuda'):
"""Generate grid anchors of a single level.
Note:
This function is usually called by method ``self.grid_priors``.
Args:
featmap_size (tuple[int]): Size of the feature maps.
level_idx (int): The index of corresponding feature map level.
device (str, optional): The device the tensor will be put on.
Defaults to 'cuda'.
Returns:
torch.Tensor: Anchors in the overall feature maps.
"""

base_anchors = self.base_anchors[level_idx].to(device)
feat_h, feat_w = featmap_size
stride_w, stride_h = self.strides[level_idx]
shift_x = torch.arange(0, feat_w, device=device) * stride_w
shift_y = torch.arange(0, feat_h, device=device) * stride_h

shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
shifts = shifts.type_as(base_anchors)
# first feat_w elements correspond to the first row of shifts
# add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
# shifted anchors (K, A, 4), reshape to (K*A, 4)

all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
all_anchors = all_anchors.view(-1, 4)
# first A rows correspond to A anchors of (0, 0) in feature map,
# then (0, 1), (0, 2), ...
return all_anchors

def sparse_priors(self,
prior_idxs,
featmap_size,
level_idx,
dtype=torch.float32,
device='cuda'):
"""Generate sparse anchors according to the ``prior_idxs``.
Args:
prior_idxs (Tensor): The index of corresponding anchors
in the feature map.
featmap_size (tuple[int]): feature map size arrange as (h, w).
level_idx (int): The level index of corresponding feature
map.
dtype (obj:`torch.dtype`): Date type of points.Defaults to
``torch.float32``.
device (obj:`torch.device`): The device where the points is
located.
Returns:
Tensor: Anchor with shape (N, 4), N should be equal to
the length of ``prior_idxs``.
"""

height, width = featmap_size
num_base_anchors = self.num_base_anchors[level_idx]
base_anchor_id = prior_idxs % num_base_anchors
x = (prior_idxs //
num_base_anchors) % width * self.strides[level_idx][0]
y = (prior_idxs // width //
num_base_anchors) % height * self.strides[level_idx][1]
priors = torch.stack([x, y, x, y], 1).to(dtype).to(device) + \
self.base_anchors[level_idx][base_anchor_id, :].to(device)

return priors

def grid_anchors(self, featmap_sizes, device='cuda'):
"""Generate grid anchors in multiple feature levels.
Expand All @@ -219,6 +320,9 @@ def grid_anchors(self, featmap_sizes, device='cuda'):
are the sizes of the corresponding feature level, \
num_base_anchors is the number of anchors for that level.
"""
warnings.warn('``grid_anchors`` would be deprecated soon. '
'Please use ``grid_priors`` ')

assert self.num_levels == len(featmap_sizes)
multi_level_anchors = []
for i in range(self.num_levels):
Expand Down Expand Up @@ -251,7 +355,13 @@ def single_level_grid_anchors(self,
Returns:
torch.Tensor: Anchors in the overall feature maps.
"""
# keep as Tensor, so that we can covert to ONNX correctly

warnings.warn(
'``single_level_grid_anchors`` would be deprecated soon. '
'Please use ``single_level_grid_priors`` ')

# keep featmap_size as Tensor instead of int, so that we
# can covert to ONNX correctly
feat_h, feat_w = featmap_size
shift_x = torch.arange(0, feat_w, device=device) * stride[0]
shift_y = torch.arange(0, feat_h, device=device) * stride[1]
Expand Down Expand Up @@ -304,7 +414,8 @@ def single_level_valid_flags(self,
"""Generate the valid flags of anchor in a single feature map.
Args:
featmap_size (tuple[int]): The size of feature maps.
featmap_size (tuple[int]): The size of feature maps, arrange
as (h, w).
valid_size (tuple[int]): The valid size of the feature maps.
num_base_anchors (int): The number of base anchors.
device (str, optional): Device where the flags will be put on.
Expand Down Expand Up @@ -346,7 +457,7 @@ def __repr__(self):
return repr_str


@ANCHOR_GENERATORS.register_module()
@PRIOR_GENERATORS.register_module()
class SSDAnchorGenerator(AnchorGenerator):
"""Anchor generator for SSD.
Expand Down Expand Up @@ -470,7 +581,7 @@ def __repr__(self):
return repr_str


@ANCHOR_GENERATORS.register_module()
@PRIOR_GENERATORS.register_module()
class LegacyAnchorGenerator(AnchorGenerator):
"""Legacy anchor generator used in MMDetection V1.x.
Expand Down Expand Up @@ -569,7 +680,7 @@ def gen_single_level_base_anchors(self,
return base_anchors


@ANCHOR_GENERATORS.register_module()
@PRIOR_GENERATORS.register_module()
class LegacySSDAnchorGenerator(SSDAnchorGenerator, LegacyAnchorGenerator):
"""Legacy anchor generator used in MMDetection V1.x.
Expand All @@ -591,7 +702,7 @@ def __init__(self,
self.base_anchors = self.gen_base_anchors()


@ANCHOR_GENERATORS.register_module()
@PRIOR_GENERATORS.register_module()
class YOLOAnchorGenerator(AnchorGenerator):
"""Anchor generator for YOLO.
Expand Down
15 changes: 13 additions & 2 deletions mmdet/core/anchor/builder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import warnings

from mmcv.utils import Registry, build_from_cfg

ANCHOR_GENERATORS = Registry('Anchor generator')
PRIOR_GENERATORS = Registry('Generator for anchors and points')

ANCHOR_GENERATORS = PRIOR_GENERATORS


def build_prior_generator(cfg, default_args=None):
return build_from_cfg(cfg, PRIOR_GENERATORS, default_args)


def build_anchor_generator(cfg, default_args=None):
return build_from_cfg(cfg, ANCHOR_GENERATORS, default_args)
warnings.warn(
'``build_anchor_generator`` would be deprecated soon, please use '
'``build_prior_generator`` ')
return build_prior_generator(cfg, default_args=default_args)
Loading

0 comments on commit e91da70

Please sign in to comment.