forked from open-mmlab/mmdetection
[Feature] support mask2former for vis (open-mmlab#10245)
1 parent 858902f · commit 3924a46

Showing 33 changed files with 3,607 additions and 10 deletions.
66 changes: 66 additions & 0 deletions
configs/_base_/datasets/youtube_vis.py

@@ -0,0 +1,66 @@
# dataset settings
train_pipeline = [
    dict(
        type='UniformRefFrameSample',
        num_ref_imgs=1,
        frame_range=100,
        filter_key_img=True),
    dict(
        type='TransformBroadcaster',
        share_random_params=True,
        transforms=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadTrackAnnotations', with_mask=True),
            dict(type='Resize', scale=(640, 360), keep_ratio=True),
            dict(type='RandomFlip', prob=0.5),
        ]),
    dict(type='PackTrackInputs')
]

test_pipeline = [
    dict(
        type='TransformBroadcaster',
        transforms=[
            dict(type='LoadImageFromFile'),
            dict(type='Resize', scale=(640, 360), keep_ratio=True),
            dict(type='LoadTrackAnnotations', with_mask=True),
        ]),
    dict(type='PackTrackInputs')
]

dataset_type = 'YouTubeVISDataset'
data_root = 'data/youtube_vis_2019/'
dataset_version = data_root[-5:-1]  # 2019 or 2021
# dataloader
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
    # YouTubeVISDataset is a video-based dataset, so we don't need
    # "AspectRatioBatchSampler"
    # batch_sampler=dict(type='AspectRatioBatchSampler'),
    # sampler=dict(type='TrackImgSampler'),  # image-based sampling
    sampler=dict(type='DefaultSampler', shuffle=True),
    batch_sampler=dict(type='TrackAspectRatioBatchSampler'),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        dataset_version=dataset_version,
        ann_file='annotations/youtube_vis_2019_train.json',
        data_prefix=dict(img_path='train/JPEGImages'),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        dataset_version=dataset_version,
        ann_file='annotations/youtube_vis_2019_valid.json',
        data_prefix=dict(img_path='valid/JPEGImages'),
        test_mode=True,
        pipeline=test_pipeline))
test_dataloader = val_dataloader
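
For illustration, a minimal sketch of how this dataset config is consumed (not part of the commit; it assumes an mmdetection v3.x environment run from the repository root, with YouTube-VIS 2019 unpacked under `data/youtube_vis_2019/`):

```python
# Sketch: build the training dataloader defined above with mmengine.
# register_all_modules() registers the datasets, transforms and samplers
# that the config references by name.
from mmengine.config import Config
from mmengine.runner import Runner
from mmdet.utils import register_all_modules

register_all_modules()
cfg = Config.fromfile('configs/_base_/datasets/youtube_vis.py')
train_loader = Runner.build_dataloader(cfg.train_dataloader)
batch = next(iter(train_loader))  # packed key + reference frame samples
```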
65 changes: 65 additions & 0 deletions
configs/mask2former_vis/README.md

@@ -0,0 +1,65 @@
# Mask2Former for Video Instance Segmentation

## Abstract

<!-- [ABSTRACT] -->

We find Mask2Former also achieves state-of-the-art performance on video instance segmentation without modifying the architecture, the loss or even the training pipeline. In this report, we show universal image segmentation architectures trivially generalize to video segmentation by directly predicting 3D segmentation volumes. Specifically, Mask2Former sets a new state-of-the-art of 60.4 AP on YouTubeVIS-2019 and 52.6 AP on YouTubeVIS-2021. We believe Mask2Former is also capable of handling video semantic and panoptic segmentation, given its versatility in image segmentation. We hope this will make state-of-the-art video segmentation research more accessible and bring more attention to designing universal image and video segmentation architectures.

<!-- [IMAGE] -->

<div align="center">
<img src="https://user-images.githubusercontent.com/46072190/188271377-164634a5-4d65-4161-8a69-2d0eaf2791f8.png"/>
</div>

## Citation

<!-- [ALGORITHM] -->

```latex
@inproceedings{cheng2021mask2former,
  title={Masked-attention Mask Transformer for Universal Image Segmentation},
  author={Bowen Cheng and Ishan Misra and Alexander G. Schwing and Alexander Kirillov and Rohit Girdhar},
  booktitle={CVPR},
  year={2022}
}
```

## Results and models of Mask2Former on YouTube-VIS 2021 validation dataset

Note: CodaLab has closed the evaluation portal for `YouTube-VIS 2019`, so we do not provide `YouTube-VIS 2019` results at present. To evaluate results on `YouTube-VIS 2021`, you can currently submit them to the `YouTube-VIS 2022` evaluation portal; the `AP_S` value it reports is the `YouTube-VIS 2021` result.

| Method | Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | AP | Config | Download |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| Mask2Former | R-50 | pytorch | 8e | 6.0 | - | 41.3 | [config](mask2former_r50_8xb2-8e_youtubevis2021.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/mask2former_vis/mask2former_r50_8xb2-8e_youtubevis2021/mask2former_r50_8xb2-8e_youtubevis2021_20230426_131833-5d215283.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/mask2former_vis/mask2former_r50_8xb2-8e_youtubevis2021/mask2former_r50_8xb2-8e_youtubevis2021_20230426_131833.json) |
| Mask2Former | R-101 | pytorch | 8e | 7.5 | - | 42.3 | [config](mask2former_r101_8xb2-8e_youtubevis2021.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/mask2former_vis/mask2former_r101_8xb2-8e_youtubevis2021/mask2former_r101_8xb2-8e_youtubevis2021_20220823_092747-8077d115.pth) \| [log](https://download.openmmlab.com/mmtracking/vis/mask2former/mask2former_r101_8xb2-8e_youtubevis2021_20220823_092747.json) |
| Mask2Former (200 queries) | Swin-L | pytorch | 8e | 18.5 | - | 52.3 | [config](mask2former_swin-l-p4-w12-384-in21k_8xb2-8e_youtubevis2021.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/mask2former_vis/mask2former_swin-l-p4-w12-384-in21k_8xb2-8e_youtubevis2021/mask2former_swin-l-p4-w12-384-in21k_8xb2-8e_youtubevis2021_20220907_124752-48252603.pth) \| [log](https://download.openmmlab.com/mmtracking/vis/mask2former/mask2former_swin-l-p4-w12-384-in21k_8xb2-8e_youtubevis2021_20220907_124752.json) |

## Get started

### 1. Training

Hyperparameters such as the learning rate in the default configuration file are tuned for 8 GPUs, so we recommend training with 8 GPUs to reproduce the reported accuracy. You can start training with the following command.

```shell
# Train Mask2Former on the YouTube-VIS 2021 dataset with the following command.
# The number after the config file is the number of GPUs used. Here we use 8 GPUs.
bash tools/dist_train.sh configs/mask2former_vis/mask2former_r50_8xb2-8e_youtubevis2021.py 8
```
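
If you only want to smoke-test the pipeline on a single GPU, the non-distributed entry point can be used (a sketch; with this setup the learning rate no longer matches the 8-GPU schedule, so the reported AP is not expected to be reproduced):

```shell
# Single-GPU debug run; adjust the learning rate if you keep this setup.
python tools/train.py configs/mask2former_vis/mask2former_r50_8xb2-8e_youtubevis2021.py
```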

### 2. Testing and evaluation

If you want to get the results of the [YouTube-VIS](https://youtube-vos.org/dataset/vis/) val/test set, please use the following command to generate result files that can be used for submission. The results will be stored in `./youtube_vis_results.submission_file.zip`; you can modify the save path in `test_evaluator` of the config.

```shell
# The number after the config file is the number of GPUs used. Here we use 8 GPUs.
bash tools/dist_test_tracking.sh configs/mask2former_vis/mask2former_r50_8xb2-8e_youtubevis2021.py 8 --checkpoint {CHECKPOINT_PATH}
```
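
The save path above comes from the evaluator's `outfile_prefix` field. A minimal sketch of overriding it in a user config that inherits the released one (the output prefix below is a hypothetical example):

```python
# user_config.py (hypothetical): change only where the submission file is written.
_base_ = './mask2former_r50_8xb2-8e_youtubevis2021.py'

test_evaluator = dict(outfile_prefix='./work_dirs/yt_vis_2021/results')
```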

### 3. Inference

Use a single GPU to run inference on a video and save the result as a video.

```shell
python demo/mot_demo.py demo/demo_mot.mp4 configs/mask2former_vis/mask2former_r50_8xb2-8e_youtubevis2021.py --checkpoint {CHECKPOINT_PATH} --out vis.mp4
```
12 changes: 12 additions & 0 deletions
configs/mask2former_vis/mask2former_r101_8xb2-8e_youtubevis2019.py

@@ -0,0 +1,12 @@
_base_ = './mask2former_r50_8xb2-8e_youtubevis2019.py'

model = dict(
    backbone=dict(
        depth=101,
        init_cfg=dict(type='Pretrained',
                      checkpoint='torchvision://resnet101')),
    init_cfg=dict(
        type='Pretrained',
        checkpoint='https://download.openmmlab.com/mmdetection/v3.0/'
        'mask2former/mask2former_r101_8xb2-lsj-50e_coco/'
        'mask2former_r101_8xb2-lsj-50e_coco_20220426_100250-ecf181e2.pth'))
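
This delta config only overrides the backbone depth and the two `init_cfg` checkpoints; every other field is merged from the R-50 base file by mmengine's config inheritance. A quick check (a sketch, run from the repository root):

```python
from mmengine.config import Config

cfg = Config.fromfile(
    'configs/mask2former_vis/mask2former_r101_8xb2-8e_youtubevis2019.py')
print(cfg.model.backbone.depth)          # 101, overridden in this file
print(cfg.model.track_head.num_queries)  # 100, inherited from the R-50 base
```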
12 changes: 12 additions & 0 deletions
configs/mask2former_vis/mask2former_r101_8xb2-8e_youtubevis2021.py

@@ -0,0 +1,12 @@
_base_ = './mask2former_r50_8xb2-8e_youtubevis2021.py'

model = dict(
    backbone=dict(
        depth=101,
        init_cfg=dict(type='Pretrained',
                      checkpoint='torchvision://resnet101')),
    init_cfg=dict(
        type='Pretrained',
        checkpoint='https://download.openmmlab.com/mmdetection/v3.0/'
        'mask2former/mask2former_r101_8xb2-lsj-50e_coco/'
        'mask2former_r101_8xb2-lsj-50e_coco_20220426_100250-ecf181e2.pth'))
174 changes: 174 additions & 0 deletions
configs/mask2former_vis/mask2former_r50_8xb2-8e_youtubevis2019.py

@@ -0,0 +1,174 @@
_base_ = ['../_base_/datasets/youtube_vis.py', '../_base_/default_runtime.py']

num_classes = 40
num_frames = 2
model = dict(
    type='Mask2FormerVideo',
    data_preprocessor=dict(
        type='TrackDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_mask=True,
        pad_size_divisor=32),
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=-1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    track_head=dict(
        type='Mask2FormerTrackHead',
        in_channels=[256, 512, 1024, 2048],  # pass to pixel_decoder inside
        strides=[4, 8, 16, 32],
        feat_channels=256,
        out_channels=256,
        num_classes=num_classes,
        num_queries=100,
        num_frames=num_frames,
        num_transformer_feat_level=3,
        pixel_decoder=dict(
            type='MSDeformAttnPixelDecoder',
            num_outs=3,
            norm_cfg=dict(type='GN', num_groups=32),
            act_cfg=dict(type='ReLU'),
            encoder=dict(  # DeformableDetrTransformerEncoder
                num_layers=6,
                layer_cfg=dict(  # DeformableDetrTransformerEncoderLayer
                    self_attn_cfg=dict(  # MultiScaleDeformableAttention
                        embed_dims=256,
                        num_heads=8,
                        num_levels=3,
                        num_points=4,
                        im2col_step=128,
                        dropout=0.0,
                        batch_first=True),
                    ffn_cfg=dict(
                        embed_dims=256,
                        feedforward_channels=1024,
                        num_fcs=2,
                        ffn_drop=0.0,
                        act_cfg=dict(type='ReLU', inplace=True)))),
            positional_encoding=dict(num_feats=128, normalize=True)),
        enforce_decoder_input_project=False,
        positional_encoding=dict(
            type='SinePositionalEncoding3D', num_feats=128, normalize=True),
        transformer_decoder=dict(  # Mask2FormerTransformerDecoder
            return_intermediate=True,
            num_layers=9,
            layer_cfg=dict(  # Mask2FormerTransformerDecoderLayer
                self_attn_cfg=dict(  # MultiheadAttention
                    embed_dims=256,
                    num_heads=8,
                    dropout=0.0,
                    batch_first=True),
                cross_attn_cfg=dict(  # MultiheadAttention
                    embed_dims=256,
                    num_heads=8,
                    dropout=0.0,
                    batch_first=True),
                ffn_cfg=dict(
                    embed_dims=256,
                    feedforward_channels=2048,
                    num_fcs=2,
                    ffn_drop=0.0,
                    act_cfg=dict(type='ReLU', inplace=True))),
            init_cfg=None),
        loss_cls=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=2.0,
            reduction='mean',
            class_weight=[1.0] * num_classes + [0.1]),
        loss_mask=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=5.0),
        loss_dice=dict(
            type='DiceLoss',
            use_sigmoid=True,
            activate=True,
            reduction='mean',
            naive_dice=True,
            eps=1.0,
            loss_weight=5.0),
        train_cfg=dict(
            num_points=12544,
            oversample_ratio=3.0,
            importance_sample_ratio=0.75,
            assigner=dict(
                type='HungarianAssigner',
                match_costs=[
                    dict(type='ClassificationCost', weight=2.0),
                    dict(
                        type='CrossEntropyLossCost',
                        weight=5.0,
                        use_sigmoid=True),
                    dict(type='DiceCost', weight=5.0, pred_act=True, eps=1.0)
                ]),
            sampler=dict(type='MaskPseudoSampler'))),
    init_cfg=dict(
        type='Pretrained',
        checkpoint='https://download.openmmlab.com/mmdetection/v3.0/'
        'mask2former/mask2former_r50_8xb2-lsj-50e_coco/'
        'mask2former_r50_8xb2-lsj-50e_coco_20220506_191028-41b088b6.pth'))

# optimizer
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW',
        lr=0.0001,
        weight_decay=0.05,
        eps=1e-8,
        betas=(0.9, 0.999)),
    paramwise_cfg=dict(
        custom_keys={
            'backbone': dict(lr_mult=0.1, decay_mult=1.0),
            'query_embed': embed_multi,
            'query_feat': embed_multi,
            'level_embed': embed_multi,
        },
        norm_decay_mult=0.0),
    clip_grad=dict(max_norm=0.01, norm_type=2))

# learning policy
max_iters = 6000
param_scheduler = dict(
    type='MultiStepLR',
    begin=0,
    end=max_iters,
    by_epoch=False,
    milestones=[4000],
    gamma=0.1)
# runtime settings
train_cfg = dict(
    type='IterBasedTrainLoop', max_iters=max_iters, val_interval=6001)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='TrackLocalVisualizer', vis_backends=vis_backends, name='visualizer')

default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook', by_epoch=False, save_last=True, interval=2000),
    visualization=dict(type='TrackVisualizationHook', draw=False))
log_processor = dict(type='LogProcessor', window_size=50, by_epoch=False)

# evaluator
val_evaluator = dict(
    type='YouTubeVISMetric',
    metric='youtube_vis_ap',
    outfile_prefix='./youtube_vis_results',
    format_only=True)
test_evaluator = val_evaluator
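
As a sanity check that the assembled config builds, here is a sketch assuming mmdet v3.x with the VIS modules from this commit available; building the model does not download the pretrained checkpoints, which are only loaded by `init_weights()`:

```python
from mmengine.config import Config
from mmdet.registry import MODELS
from mmdet.utils import register_all_modules

register_all_modules()
cfg = Config.fromfile(
    'configs/mask2former_vis/mask2former_r50_8xb2-8e_youtubevis2019.py')
model = MODELS.build(cfg.model)  # Mask2FormerVideo with Mask2FormerTrackHead
print(f'{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters')
```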
37 changes: 37 additions & 0 deletions
configs/mask2former_vis/mask2former_r50_8xb2-8e_youtubevis2021.py

@@ -0,0 +1,37 @@
_base_ = './mask2former_r50_8xb2-8e_youtubevis2019.py'

dataset_type = 'YouTubeVISDataset'
data_root = 'data/youtube_vis_2021/'
dataset_version = data_root[-5:-1]  # 2019 or 2021

train_dataloader = dict(
    dataset=dict(
        data_root=data_root,
        dataset_version=dataset_version,
        ann_file='annotations/youtube_vis_2021_train.json'))

val_dataloader = dict(
    dataset=dict(
        data_root=data_root,
        dataset_version=dataset_version,
        ann_file='annotations/youtube_vis_2021_valid.json'))
test_dataloader = val_dataloader

# learning policy
max_iters = 8000
param_scheduler = dict(
    type='MultiStepLR',
    begin=0,
    end=max_iters,
    by_epoch=False,
    milestones=[5500],
    gamma=0.1)
# runtime settings
train_cfg = dict(
    type='IterBasedTrainLoop', max_iters=max_iters, val_interval=8001)

default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook', by_epoch=False, save_last=True, interval=500))
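
Note the `dataset_version` trick shared by both YouTube-VIS configs: the version string is sliced out of `data_root` rather than hard-coded. A one-line check of the slice:

```python
data_root = 'data/youtube_vis_2021/'
assert data_root[-5:-1] == '2021'  # last path component without the trailing '/'
```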