Skip to content

Commit

Permalink
[Enhance] Unify the interface of stuff head and panoptic head (open-m…
Browse files Browse the repository at this point in the history
  • Loading branch information
AronLin authored Oct 20, 2021
1 parent 4e24a86 commit 9874180
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 21 deletions.
3 changes: 2 additions & 1 deletion configs/panoptic_fpn/panoptic_fpn_r50_fpn_1x_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
type='PanopticFPN',
semantic_head=dict(
type='PanopticFPNHead',
num_classes=54,
num_things_classes=80,
num_stuff_classes=53,
in_channels=256,
inner_channels=128,
start_level=0,
Expand Down
77 changes: 57 additions & 20 deletions mmdet/models/seg_heads/panoptic_fpn_head.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
from mmcv.runner import ModuleList
Expand All @@ -12,18 +14,33 @@
class PanopticFPNHead(BaseSemanticHead):
"""PanopticFPNHead used in Panoptic FPN.
In this head, the number of output channels is ``num_stuff_classes
+ 1``, covering all stuff classes plus one merged thing class. Stuff
labels are remapped to ``0`` through ``num_stuff_classes - 1``, and all
thing classes are merged into the ``num_stuff_classes``-th channel.
Args:
num_things_classes (int): Number of thing classes. Default: 80.
num_stuff_classes (int): Number of stuff classes. Default: 53.
num_classes (int): Number of classes, including all stuff
classes and one thing class.
classes and one thing class. This argument is deprecated,
please use ``num_things_classes`` and ``num_stuff_classes``.
The module will automatically infer the num_classes by
``num_stuff_classes + 1``.
in_channels (int): Number of channels in the input feature
map.
inner_channels (int): Number of channels in inner features.
start_level (int): The start level of the input features
used in PanopticFPN.
end_level (int): The end level of the used features, the
`end_level`-th layer will not be used.
fg_range (tuple): Range of the foreground classes.
bg_range (tuple): Range of the background classes.
``end_level``-th layer will not be used.
fg_range (tuple): Range of the foreground classes. It spans
from ``0`` to ``num_things_classes - 1``. Deprecated, please use
``num_things_classes`` directly.
bg_range (tuple): Range of the background classes. It spans
from ``num_things_classes`` to ``num_things_classes +
num_stuff_classes - 1``. Deprecated, please use
``num_stuff_classes`` and ``num_things_classes`` directly.
conv_cfg (dict): Dictionary to construct and config
conv layer. Default: None.
norm_cfg (dict): Dictionary to construct and config norm layer.
Expand All @@ -33,24 +50,42 @@ class PanopticFPNHead(BaseSemanticHead):
"""

def __init__(self,
num_classes,
num_things_classes=80,
num_stuff_classes=53,
num_classes=None,
in_channels=256,
inner_channels=128,
start_level=0,
end_level=4,
fg_range=(0, 79),
bg_range=(80, 132),
fg_range=None,
bg_range=None,
conv_cfg=None,
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
init_cfg=None,
loss_seg=dict(
type='CrossEntropyLoss', ignore_index=-1,
loss_weight=1.0)):
super(PanopticFPNHead, self).__init__(num_classes, init_cfg, loss_seg)
self.fg_range = fg_range
self.bg_range = bg_range
self.fg_nums = self.fg_range[1] - self.fg_range[0] + 1
self.bg_nums = self.bg_range[1] - self.bg_range[0] + 1
if num_classes is not None:
warnings.warn(
'`num_classes` is deprecated now, please set '
'`num_stuff_classes` directly, the `num_classes` will be '
'set to `num_stuff_classes + 1`')
# num_classes = num_stuff_classes + 1 for PanopticFPN.
assert num_classes == num_stuff_classes + 1
super(PanopticFPNHead, self).__init__(num_stuff_classes + 1, init_cfg,
loss_seg)
self.num_things_classes = num_things_classes
self.num_stuff_classes = num_stuff_classes
if fg_range is not None and bg_range is not None:
self.fg_range = fg_range
self.bg_range = bg_range
self.num_things_classes = fg_range[1] - fg_range[0] + 1
self.num_stuff_classes = bg_range[1] - bg_range[0] + 1
warnings.warn(
'`fg_range` and `bg_range` are deprecated now, '
f'please use `num_things_classes`={self.num_things_classes} '
f'and `num_stuff_classes`={self.num_stuff_classes} instead.')

# Used feature layers are [start_level, end_level)
self.start_level = start_level
self.end_level = end_level
Expand All @@ -68,25 +103,27 @@ def __init__(self,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
))
self.conv_logits = nn.Conv2d(inner_channels, num_classes, 1)
self.conv_logits = nn.Conv2d(inner_channels, self.num_classes, 1)

def _set_things_to_void(self, gt_semantic_seg):
    """Merge all thing classes into a single class.

    In PanopticFPN, stuff labels are shifted down to the range
    ``0`` to ``self.num_stuff_classes - 1`` and every thing label is
    merged into the single label ``self.num_stuff_classes``. Labels
    outside ``[0, num_things_classes + num_stuff_classes)`` are returned
    unchanged.

    Args:
        gt_semantic_seg (Tensor): Ground-truth semantic labels, where
            thing classes occupy ``[0, num_things_classes)`` and stuff
            classes the ``num_stuff_classes`` labels that follow.

    Returns:
        Tensor: Integer (``int32``) tensor with the remapped labels.
    """
    gt_semantic_seg = gt_semantic_seg.int()
    num_things = self.num_things_classes
    num_stuff = self.num_stuff_classes

    # Keep the lower bound: without it, negative labels (e.g. an
    # ignore index of -1) would be folded into the merged thing class.
    fg_mask = (gt_semantic_seg >= 0) & (gt_semantic_seg < num_things)
    bg_mask = (gt_semantic_seg >= num_things) & (
        gt_semantic_seg < num_things + num_stuff)

    new_gt_seg = torch.clone(gt_semantic_seg)
    # Stuff labels: shift so that they start from 0.
    new_gt_seg = torch.where(bg_mask, gt_semantic_seg - num_things,
                             new_gt_seg)
    # Thing labels: collapse all of them onto the last channel.
    new_gt_seg = torch.where(fg_mask, fg_mask.int() * num_stuff,
                             new_gt_seg)
    return new_gt_seg

def loss(self, seg_preds, gt_semantic_seg):
Expand Down

0 comments on commit 9874180

Please sign in to comment.