[Feature]SOLO: Segmenting Objects by Locations (open-mmlab#5832)
* add SOLO

* add decoupled SOLO

* update decoupled SOLO

* fix linting errors

* format config filename, config content, loss names, norm_cfg

* fix linting errors

* fix matrix_nms and configs

* Add unit tests for SOLO head

* add diceloss

* support mmdet-v2+

* add decoupled head

* clean Chinese comments

* update SOLO

* fix

* delete debug files

* update solo config

* fix bug

* [Fix]: fix some params cannot get grad

* [fix] make sure params can get grad

* init commit for results

* add results and instance results

* add docstr

* add more unit tests

* add more unit tests

* add more unit tests

* add more unit tests

* add unit test for instance results

* add example

* add meta_info_keys results_keys

* add modified from

* fix unit tests

* fix typo

* add instance seg related base

* forward train for solo

* fix simpletest

* add docstr

* convert to tensor at begin

* refactor yolact training

* refactor yolact test

* fix test of yolact

* fix empty det of yolact

* fix return tuple

* add format_results

* add testfor formatr

* solo

* add unitest for format_results

* add unitest

* solo

* remove yolact related modification

* fix zero bbox

* fix score size

* fix desolo head

* update solo head

* fix error

* rename some attribute

* rename some attribute

* rename decouple

* add doc

* format loss

* recover decouple

* add doc

* fix test

* fix test

* fix doc

* remove points nms

* refactor the post process

* refactor post process of decouple

* refactor base

* refactor get_target single

* refactor the training of decouple

* refactor test of decouple

* refactor dice loss

* refactor dice

* change to format a dict

* support detection results in test.py

* add base one-stage segmentor

* fix doc

* add onnx export

* add solo config

* add dice loss test unit

* add solo_head test unit

* add more detailed comments

* resolve comments

* add test unit

* update docstrings and move center of mass to core.utils

* add center of mass test unit

* resolve comments

* resolve comments

* fix rle encode

* fix results

* fix results

* abstract dice loss

* update docstring

* add EPS

* add center of mass test unit

* add eps parameter

* add vis

* add nms test unit

* configs/ add configs

* add desolo light config file

* support desolo light head

* add desolo light config

* add matrix_nms test unit

* fix matrix_nms test unit

* update matrix doc string

* fix error

* fix logic error

* fix logic error

* add comment in test unit

* move has_acted to initialization

* update solo readme

* rename

* revert test

* fix import in example

* fix unit test

* add more unit tests

* add more unit tests

* add more unit tests

* rename meta to meta_info

* fix docstr

* fix doc

* fix doc

* add format_results

* fix format results

* fix some default value and function name

* fix desolo light head error

* fix doc and move InstanceData to a new file

* fix typo

* fix unit test in torch 1.3

* update matrix nms docstring

* fix hard code

* add vis

* add vis

* fix lint

* fix doc

* fix doc

* fix vis

* fix vis

* fix vis

* fix forward_dummy doc

* fix doc

* fix comment

* fix doc

* fix order of argument

* add base one-stage segmentor

* fix config files

* fix doc

* fix doc

* support solo

* fix error

* support solo

* rename cls_score

* support solo

* update model zoo

* update docstring

* update docstring

Co-authored-by: WXinlong <[email protected]>
Co-authored-by: zhangshilong <[email protected]>
3 people authored Sep 28, 2021
1 parent b0cd401 commit 2294bad
Showing 23 changed files with 2,299 additions and 8 deletions.
42 changes: 42 additions & 0 deletions configs/solo/README.md
@@ -0,0 +1,42 @@
# SOLO: Segmenting Objects by Locations

## Introduction

```
@inproceedings{wang2020solo,
title = {{SOLO}: Segmenting Objects by Locations},
author = {Wang, Xinlong and Kong, Tao and Shen, Chunhua and Jiang, Yuning and Li, Lei},
booktitle = {Proc. Eur. Conf. Computer Vision (ECCV)},
year = {2020}
}
```

## Results and Models

### SOLO

| Backbone | Style | MS train | Lr schd | Mem (GB) | Inf time (fps) | mask AP | Download |
|:---------:|:-------:|:--------:|:-------:|:--------:|:--------------:|:------:|:--------:|
| R-50 | pytorch | N | 1x | 8.0 | 14.0 | 33.1 | [model](https://download.openmmlab.com/mmdetection/v2.0/solo/solo_r50_fpn_1x_coco/solo_r50_fpn_1x_coco_20210821_035055-2290a6b8.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/solo/solo_r50_fpn_1x_coco/solo_r50_fpn_1x_coco_20210821_035055.log.json) |
| R-50 | pytorch | Y | 3x | 7.4 | 14.0 | 35.9 | [model](https://download.openmmlab.com/mmdetection/v2.0/solo/solo_r50_fpn_3x_coco/solo_r50_fpn_3x_coco_20210901_012353-11d224d7.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/solo/solo_r50_fpn_3x_coco/solo_r50_fpn_3x_coco_20210901_012353.log.json) |

### Decoupled SOLO

| Backbone | Style | MS train | Lr schd | Mem (GB) | Inf time (fps) | mask AP | Download |
|:---------:|:-------:|:--------:|:-------:|:--------:|:--------------:|:-------:|:--------:|
| R-50 | pytorch | N | 1x | 7.8 | 12.5 | 33.9 | [model](https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_r50_fpn_1x_coco/decoupled_solo_r50_fpn_1x_coco_20210820_233348-6337c589.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_r50_fpn_1x_coco/decoupled_solo_r50_fpn_1x_coco_20210820_233348.log.json) |
| R-50 | pytorch | Y | 3x | 7.9 | 12.5 | 36.7 | [model](https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_r50_fpn_3x_coco/decoupled_solo_r50_fpn_3x_coco_20210821_042504-7b3301ec.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_r50_fpn_3x_coco/decoupled_solo_r50_fpn_3x_coco_20210821_042504.log.json) |

- Decoupled SOLO uses a decoupled head that differs from the SOLO head.
It is an efficient variant of SOLO with equivalent accuracy.
Please refer to the corresponding config files for details.
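
As a rough illustration of what "decoupled" means here, the sketch below follows the description in the SOLO paper rather than the code in this commit: instead of predicting one mask channel per grid cell (S² channels), the head predicts S column maps and S row maps, and the mask for cell (row, col) is the element-wise product of the two activated maps. The function and variable names are hypothetical, and plain Python lists stand in for tensors.

```python
import math


def sigmoid(value):
    """Logistic function used to turn logits into probabilities."""
    return 1.0 / (1.0 + math.exp(-value))


def decoupled_mask(x_logits, y_logits, row, col):
    """Soft mask for grid cell (row, col).

    x_logits: S maps of shape H x W (one per grid column).
    y_logits: S maps of shape H x W (one per grid row).
    """
    x_map = x_logits[col]
    y_map = y_logits[row]
    # Element-wise product of the two activated maps.
    return [[sigmoid(xv) * sigmoid(yv) for xv, yv in zip(x_row, y_row)]
            for x_row, y_row in zip(x_map, y_map)]


# Toy example: S = 2 grid, 1 x 2 feature maps filled with zero logits.
S, H, W = 2, 1, 2
x_logits = [[[0.0] * W for _ in range(H)] for _ in range(S)]
y_logits = [[[0.0] * W for _ in range(H)] for _ in range(S)]
mask = decoupled_mask(x_logits, y_logits, row=0, col=1)
print(mask)
```

With this layout the head outputs 2S maps per level instead of S², which is where the memory saving is expected to come from.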

### Decoupled Light SOLO

| Backbone | Style | MS train | Lr schd | Mem (GB) | Inf time (fps) | mask AP | Download |
|:---------:|:-------:|:--------:|:-------:|:--------:|:--------------:|:------:|:--------:|
| R-50 | pytorch | Y | 3x | 2.2 | 31.2 | 32.9 | [model](https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_light_r50_fpn_3x_coco/decoupled_solo_light_r50_fpn_3x_coco_20210906_142703-e70e226f.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_light_r50_fpn_3x_coco/decoupled_solo_light_r50_fpn_3x_coco_20210906_142703.log.json) |

- Decoupled Light SOLO uses a decoupled structure similar to the Decoupled
SOLO head, with a lightweight head and a smaller input size. Please refer
to the corresponding config files for details.
63 changes: 63 additions & 0 deletions configs/solo/decoupled_solo_light_r50_fpn_3x_coco.py
@@ -0,0 +1,63 @@
_base_ = './decoupled_solo_r50_fpn_3x_coco.py'

# model settings
model = dict(
    mask_head=dict(
        type='DecoupledSOLOLightHead',
        num_classes=80,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 64), (32, 128), (64, 256), (128, 512), (256, 2048)),
        pos_scale=0.2,
        num_grids=[40, 36, 24, 16, 12],
        cls_down_index=0,
        loss_mask=dict(
            type='DiceLoss', use_sigmoid=True, activate=False,
            loss_weight=3.0),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)))

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(
        type='Resize',
        img_scale=[(852, 512), (852, 480), (852, 448), (852, 416), (852, 384),
                   (852, 352)],
        multiscale_mode='value',
        keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(852, 512),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]

data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))
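
To make the `Resize` settings above more concrete, here is a small standalone sketch of how a `(long_edge, short_edge)` scale with `keep_ratio=True` is commonly turned into an output size: the image is scaled by the largest factor that keeps both edges within their limits. The helper name `rescale_size` and the rounding rule are assumptions for illustration; this is not the library code itself.

```python
def rescale_size(width, height, scale):
    """Return (new_width, new_height) that fits `scale` while keeping the ratio.

    `scale` is a (long_edge, short_edge) pair such as (852, 512).
    """
    max_long, max_short = max(scale), min(scale)
    # Largest factor that respects both the long-edge and short-edge limits.
    factor = min(max_long / max(width, height), max_short / min(width, height))
    return int(width * factor + 0.5), int(height * factor + 0.5)


# With multiscale_mode='value', one of these scales is picked per training image.
train_scales = [(852, 512), (852, 480), (852, 448), (852, 416), (852, 384),
                (852, 352)]
for scale in train_scales:
    print(scale, rescale_size(640, 480, scale))
```

For a 640 x 480 input the short edge is the binding constraint at every listed scale, so only the target height changes between scales.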
28 changes: 28 additions & 0 deletions configs/solo/decoupled_solo_r50_fpn_1x_coco.py
@@ -0,0 +1,28 @@
_base_ = [
    './solo_r50_fpn_1x_coco.py',
]
# model settings
model = dict(
    mask_head=dict(
        type='DecoupledSOLOHead',
        num_classes=80,
        in_channels=256,
        stacked_convs=7,
        feat_channels=256,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
        pos_scale=0.2,
        num_grids=[40, 36, 24, 16, 12],
        cls_down_index=0,
        loss_mask=dict(
            type='DiceLoss', use_sigmoid=True, activate=False,
            loss_weight=3.0),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)))

optimizer = dict(type='SGD', lr=0.01)
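
The `scale_ranges` and `num_grids` settings can be read as follows: an object is assigned to every FPN level whose range contains its scale, and on that level it is represented by the grid cell its centre falls into. The sketch below is a simplified illustration under stated assumptions: it takes the scale to be `sqrt(w * h)` of the ground-truth box and uses the box centre, whereas the commit message mentions a centre-of-mass utility. The function names are hypothetical.

```python
import math

SCALE_RANGES = ((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048))
NUM_GRIDS = [40, 36, 24, 16, 12]


def assign_levels(box):
    """Indices of the levels whose scale range contains the box scale."""
    x1, y1, x2, y2 = box
    scale = math.sqrt((x2 - x1) * (y2 - y1))
    return [level for level, (low, high) in enumerate(SCALE_RANGES)
            if low <= scale <= high]


def center_cell(box, img_w, img_h, num_grid):
    """(row, col) of the grid cell containing the box centre (simplification)."""
    cx = (box[0] + box[2]) / 2
    cy = (box[1] + box[3]) / 2
    col = min(int(cx / img_w * num_grid), num_grid - 1)
    row = min(int(cy / img_h * num_grid), num_grid - 1)
    return row, col


box = (100, 100, 200, 200)  # a 100 x 100 object
for level in assign_levels(box):
    print(level, center_cell(box, 800, 800, NUM_GRIDS[level]))
```

Because neighbouring ranges overlap, a 100-pixel object lands on two levels here; `pos_scale` would additionally widen the positive region around the centre cell, which this sketch leaves out.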
25 changes: 25 additions & 0 deletions configs/solo/decoupled_solo_r50_fpn_3x_coco.py
@@ -0,0 +1,25 @@
_base_ = './solo_r50_fpn_3x_coco.py'

# model settings
model = dict(
    mask_head=dict(
        type='DecoupledSOLOHead',
        num_classes=80,
        in_channels=256,
        stacked_convs=7,
        feat_channels=256,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
        pos_scale=0.2,
        num_grids=[40, 36, 24, 16, 12],
        cls_down_index=0,
        loss_mask=dict(
            type='DiceLoss', use_sigmoid=True, activate=False,
            loss_weight=3.0),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)))
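
For readers unfamiliar with the `DiceLoss` used for `loss_mask`, the following is a minimal sketch of the Dice formulation described in the SOLO paper, written over flat lists of probabilities. It is not the implementation added in this commit; the `eps` value is an assumption, and `activate=False` in the config presumably signals that the decoupled head has already applied the sigmoid, so the sketch expects probabilities rather than logits.

```python
def dice_loss(pred, target, eps=1e-3):
    """Dice loss for one mask.

    pred:   predicted probabilities in [0, 1] (already activated).
    target: binary ground-truth values (0 or 1).
    eps:    small constant for numerical stability (assumed value).
    """
    intersection = sum(p * t for p, t in zip(pred, target))
    pred_energy = sum(p * p for p in pred) + eps
    target_energy = sum(t * t for t in target) + eps
    dice = 2 * intersection / (pred_energy + target_energy)
    return 1 - dice


print(dice_loss([1.0, 0.0, 1.0], [1, 0, 1]))  # close to 0 for a perfect match
print(dice_loss([1.0, 0.0], [0, 1]))          # 1.0 when there is no overlap
```

The config then scales this term with `loss_weight=3.0` relative to the classification loss.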
115 changes: 115 additions & 0 deletions configs/solo/metafile.yml
@@ -0,0 +1,115 @@
Collections:
  - Name: SOLO
    Metadata:
      Training Data: COCO
      Training Techniques:
        - SGD with Momentum
        - Weight Decay
      Training Resources: 8x V100 GPUs
      Architecture:
        - FPN
        - Convolution
        - ResNet
    Paper: https://arxiv.org/abs/1912.04488
    README: configs/solo/README.md

Models:
  - Name: decoupled_solo_r50_fpn_1x_coco
    In Collection: SOLO
    Config: configs/solo/decoupled_solo_r50_fpn_1x_coco.py
    Metadata:
      Training Memory (GB): 7.8
      Epochs: 12
      inference time (ms/im):
        - value: 116.4
          hardware: V100
          backend: PyTorch
          batch size: 1
          mode: FP32
          resolution: (1333, 800)
    Results:
      - Task: Instance Segmentation
        Dataset: COCO
        Metrics:
          mask AP: 33.9
    Weights: https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_r50_fpn_1x_coco/decoupled_solo_r50_fpn_1x_coco_20210820_233348-6337c589.pth

  - Name: decoupled_solo_r50_fpn_3x_coco
    In Collection: SOLO
    Config: configs/solo/decoupled_solo_r50_fpn_3x_coco.py
    Metadata:
      Training Memory (GB): 7.9
      Epochs: 36
      inference time (ms/im):
        - value: 117.2
          hardware: V100
          backend: PyTorch
          batch size: 1
          mode: FP32
          resolution: (1333, 800)
    Results:
      - Task: Instance Segmentation
        Dataset: COCO
        Metrics:
          mask AP: 36.7
    Weights: https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_r50_fpn_3x_coco/decoupled_solo_r50_fpn_3x_coco_20210821_042504-7b3301ec.pth

  - Name: decoupled_solo_light_r50_fpn_3x_coco
    In Collection: SOLO
    Config: configs/solo/decoupled_solo_light_r50_fpn_3x_coco.py
    Metadata:
      Training Memory (GB): 2.2
      Epochs: 36
      inference time (ms/im):
        - value: 35.0
          hardware: V100
          backend: PyTorch
          batch size: 1
          mode: FP32
          resolution: (852, 512)
    Results:
      - Task: Instance Segmentation
        Dataset: COCO
        Metrics:
          mask AP: 32.9
    Weights: https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_light_r50_fpn_3x_coco/decoupled_solo_light_r50_fpn_3x_coco_20210906_142703-e70e226f.pth

  - Name: solo_r50_fpn_3x_coco
    In Collection: SOLO
    Config: configs/solo/solo_r50_fpn_3x_coco.py
    Metadata:
      Training Memory (GB): 7.4
      Epochs: 36
      inference time (ms/im):
        - value: 94.2
          hardware: V100
          backend: PyTorch
          batch size: 1
          mode: FP32
          resolution: (1333, 800)
    Results:
      - Task: Instance Segmentation
        Dataset: COCO
        Metrics:
          mask AP: 35.9
    Weights: https://download.openmmlab.com/mmdetection/v2.0/solo/solo_r50_fpn_3x_coco/solo_r50_fpn_3x_coco_20210901_012353-11d224d7.pth

  - Name: solo_r50_fpn_1x_coco
    In Collection: SOLO
    Config: configs/solo/solo_r50_fpn_1x_coco.py
    Metadata:
      Training Memory (GB): 8.0
      Epochs: 12
      inference time (ms/im):
        - value: 95.1
          hardware: V100
          backend: PyTorch
          batch size: 1
          mode: FP32
          resolution: (1333, 800)
    Results:
      - Task: Instance Segmentation
        Dataset: COCO
        Metrics:
          mask AP: 33.1
    Weights: https://download.openmmlab.com/mmdetection/v2.0/solo/solo_r50_fpn_1x_coco/solo_r50_fpn_1x_coco_20210821_035055-2290a6b8.pth
53 changes: 53 additions & 0 deletions configs/solo/solo_r50_fpn_1x_coco.py
@@ -0,0 +1,53 @@
_base_ = [
    '../_base_/datasets/coco_instance.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]

# model settings
model = dict(
    type='SOLO',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=0,
        num_outs=5),
    mask_head=dict(
        type='SOLOHead',
        num_classes=80,
        in_channels=256,
        stacked_convs=7,
        feat_channels=256,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
        pos_scale=0.2,
        num_grids=[40, 36, 24, 16, 12],
        cls_down_index=0,
        loss_mask=dict(type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),
    # model training and testing settings
    test_cfg=dict(
        nms_pre=500,
        score_thr=0.1,
        mask_thr=0.5,
        filter_thr=0.05,
        kernel='gaussian',  # gaussian/linear
        sigma=2.0,
        max_per_img=100))

# optimizer
optimizer = dict(type='SGD', lr=0.01)
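
The `kernel='gaussian'` and `sigma=2.0` entries in `test_cfg` parameterise Matrix NMS, which decays the scores of overlapping masks instead of discarding them outright. The sketch below illustrates the Gaussian variant as described in the SOLO papers, on plain Python lists. It is not the `mask_matrix_nms` function added in this commit, the function name is hypothetical, and it assumes the candidates are already sorted by descending score with a precomputed mask-IoU matrix.

```python
import math


def matrix_nms_decay(iou, scores, labels, sigma=2.0):
    """Return decayed scores for candidates sorted by descending score.

    iou[i][j] is the mask IoU between candidates i and j.
    """
    n = len(scores)

    def pair_iou(i, j):
        # Only a higher-scoring candidate of the same class can suppress j.
        return iou[i][j] if i < j and labels[i] == labels[j] else 0.0

    # Largest overlap of candidate i with any higher-scoring candidate.
    compensate = [max((pair_iou(k, i) for k in range(n)), default=0.0)
                  for i in range(n)]

    decayed = []
    for j in range(n):
        coefficient = 1.0
        for i in range(n):
            decay = math.exp(-sigma * pair_iou(i, j) ** 2)
            # A suppressor that is itself heavily overlapped counts for less.
            compensation = math.exp(-sigma * compensate[i] ** 2)
            coefficient = min(coefficient, decay / compensation)
        decayed.append(scores[j] * coefficient)
    return decayed


iou = [[1.0, 0.5], [0.5, 1.0]]
print(matrix_nms_decay(iou, [0.9, 0.8], labels=[0, 0]))
print(matrix_nms_decay(iou, [0.9, 0.8], labels=[0, 1]))
```

After the decay, candidates whose score falls below `filter_thr` would presumably be dropped and at most `max_per_img` kept; that bookkeeping is left out of the sketch.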
28 changes: 28 additions & 0 deletions configs/solo/solo_r50_fpn_3x_coco.py
@@ -0,0 +1,28 @@
_base_ = './solo_r50_fpn_1x_coco.py'

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(
        type='Resize',
        img_scale=[(1333, 800), (1333, 768), (1333, 736), (1333, 704),
                   (1333, 672), (1333, 640)],
        multiscale_mode='value',
        keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
data = dict(train=dict(pipeline=train_pipeline))

lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[27, 33])
runner = dict(type='EpochBasedRunner', max_epochs=36)
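
The `lr_config` above describes a step schedule with linear warmup. As a rough sketch of what that implies, the function below computes the learning rate for a given epoch and iteration, assuming the base rate of 0.01 inherited from `solo_r50_fpn_1x_coco.py`, a decay factor of 0.1 at each step (an assumed default, not stated in the config), and the usual linear warmup interpolation from `warmup_ratio * lr` up to `lr`. The function name is hypothetical.

```python
def learning_rate(epoch, cur_iter, base_lr=0.01, steps=(27, 33), gamma=0.1,
                  warmup_iters=500, warmup_ratio=1.0 / 3):
    """Learning rate at a 0-based epoch and a global iteration count."""
    # Step policy: multiply by gamma once for every milestone already reached.
    lr = base_lr * gamma ** sum(epoch >= step for step in steps)
    # Linear warmup: ramp from warmup_ratio * lr to lr over warmup_iters.
    if cur_iter < warmup_iters:
        k = (1 - cur_iter / warmup_iters) * (1 - warmup_ratio)
        lr *= 1 - k
    return lr


print(learning_rate(epoch=0, cur_iter=0))       # start of warmup
print(learning_rate(epoch=0, cur_iter=500))     # warmup finished
print(learning_rate(epoch=27, cur_iter=10**5))  # after the first step
print(learning_rate(epoch=33, cur_iter=10**5))  # after the second step
```

Under these assumptions the rate starts at one third of 0.01, reaches 0.01 after 500 iterations, and drops tenfold at epochs 27 and 33 of the 36-epoch run.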
8 changes: 8 additions & 0 deletions docs/model_zoo.md
@@ -230,6 +230,14 @@ Please refer to [CenterNet](https://github.com/open-mmlab/mmdetection/blob/maste

Please refer to [YOLOX](https://github.com/open-mmlab/mmdetection/blob/master/configs/yolox) for details.

### PVT

Please refer to [PVT](https://github.com/open-mmlab/mmdetection/blob/master/configs/pvt) for details.

### SOLO

Please refer to [SOLO](https://github.com/open-mmlab/mmdetection/blob/master/configs/solo) for details.

### Other datasets

We also benchmark some methods on [PASCAL VOC](https://github.com/open-mmlab/mmdetection/blob/master/configs/pascal_voc), [Cityscapes](https://github.com/open-mmlab/mmdetection/blob/master/configs/cityscapes) and [WIDER FACE](https://github.com/open-mmlab/mmdetection/blob/master/configs/wider_face).
3 changes: 2 additions & 1 deletion mmdet/core/post_processing/__init__.py
@@ -1,9 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .bbox_nms import fast_nms, multiclass_nms
+from .matrix_nms import mask_matrix_nms
 from .merge_augs import (merge_aug_bboxes, merge_aug_masks,
                          merge_aug_proposals, merge_aug_scores)

 __all__ = [
     'multiclass_nms', 'merge_aug_proposals', 'merge_aug_bboxes',
-    'merge_aug_scores', 'merge_aug_masks', 'fast_nms'
+    'merge_aug_scores', 'merge_aug_masks', 'mask_matrix_nms', 'fast_nms'
 ]