[Feature] support sort and deepsort (open-mmlab#10240)
Co-authored-by: zhangwenhua <[email protected]>
Commit ce40e83 (1 parent: 63bb822). Showing 65 changed files with 4,209 additions and 44 deletions.
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/MOT17/'

train_pipeline = [
    dict(type='LoadImageFromFile', to_float32=True),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='RandomResize',
        scale=(1088, 1088),
        ratio_range=(0.8, 1.2),
        keep_ratio=True,
        clip_object_border=False),
    dict(type='PhotoMetricDistortion'),
    dict(type='RandomCrop', crop_size=(1088, 1088), bbox_clip_border=False),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PackDetInputs')
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(1088, 1088), keep_ratio=True),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]

train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    batch_sampler=dict(type='AspectRatioBatchSampler'),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/half-train_cocoformat.json',
        data_prefix=dict(img='train/'),
        metainfo=dict(classes=('pedestrian', )),
        filter_cfg=dict(filter_empty_gt=True, min_size=32),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/half-val_cocoformat.json',
        data_prefix=dict(img='train/'),
        metainfo=dict(classes=('pedestrian', )),
        test_mode=True,
        pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(
    type='CocoMetric',
    ann_file=data_root + 'annotations/half-val_cocoformat.json',
    metric='bbox',
    format_only=False)
test_evaluator = val_evaluator
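A dataset config like the one above can be loaded and inspected with mmengine's `Config` API before launching training. The sketch below assumes the config sits at a path such as `configs/_base_/datasets/mot_challenge_det.py`; that path is an assumption, so substitute the actual location in your checkout.

```python
# Minimal sketch: load and inspect a dataset config with mmengine.
# The file path below is an assumption; point it at the real config.
from mmengine.config import Config

cfg = Config.fromfile('configs/_base_/datasets/mot_challenge_det.py')

# Configs behave like nested attribute dictionaries, so fields can be
# read or overridden before training.
print(cfg.train_dataloader.dataset.ann_file)  # annotations/half-train_cocoformat.json
print(cfg.val_evaluator.metric)               # bbox

# Example override: relocate data_root and propagate it to the dataloaders.
cfg.data_root = '/data/MOT17/'
cfg.train_dataloader.dataset.data_root = cfg.data_root
cfg.val_dataloader.dataset.data_root = cfg.data_root
```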
# dataset settings
dataset_type = 'ReIDDataset'
data_root = 'data/MOT17/'

# data pipeline
train_pipeline = [
    dict(
        type='TransformBroadcaster',
        share_random_params=False,
        transforms=[
            dict(type='LoadImageFromFile', to_float32=True),
            dict(
                type='Resize',
                scale=(128, 256),
                keep_ratio=False,
                clip_object_border=False),
            dict(type='RandomFlip', prob=0.5, direction='horizontal'),
        ]),
    dict(type='PackReIDInputs', meta_keys=('flip', 'flip_direction'))
]
test_pipeline = [
    dict(type='LoadImageFromFile', to_float32=True),
    dict(type='Resize', scale=(128, 256), keep_ratio=False),
    dict(type='PackReIDInputs')
]

# dataloader
train_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        triplet_sampler=dict(num_ids=8, ins_per_id=4),
        data_prefix=dict(img_path='reid/imgs'),
        ann_file='reid/meta/train_80.txt',
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        triplet_sampler=None,
        data_prefix=dict(img_path='reid/imgs'),
        ann_file='reid/meta/val_20.txt',
        pipeline=test_pipeline))
test_dataloader = val_dataloader

# evaluator
val_evaluator = dict(type='ReIDMetrics', metric=['mAP', 'CMC'])
test_evaluator = val_evaluator
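The `triplet_sampler=dict(num_ids=8, ins_per_id=4)` setting groups each training sample as 8 identities with 4 crops per identity (32 crops in total), which is presumably why `batch_size=1` suffices here. The snippet below is a simplified, standalone illustration of that grouping idea, not the actual `ReIDDataset` implementation; the function name `sample_triplet_batch` is hypothetical.

```python
import random

def sample_triplet_batch(labels, num_ids=8, ins_per_id=4):
    """Simplified illustration of triplet sampling: pick `num_ids`
    identities and `ins_per_id` image indices for each of them."""
    # Group image indices by identity label.
    by_id = {}
    for idx, pid in enumerate(labels):
        by_id.setdefault(pid, []).append(idx)
    # Keep only identities with enough instances, then sample.
    eligible = [pid for pid, idxs in by_id.items() if len(idxs) >= ins_per_id]
    chosen_ids = random.sample(eligible, num_ids)
    batch = []
    for pid in chosen_ids:
        batch.extend(random.sample(by_id[pid], ins_per_id))
    return batch  # 8 ids x 4 instances = 32 crops per group

# Toy usage: 20 identities with 10 crops each.
labels = [pid for pid in range(20) for _ in range(10)]
indices = sample_triplet_batch(labels)
assert len(indices) == 32
```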
# Simple online and realtime tracking with a deep association metric

## Abstract

<!-- [ABSTRACT] -->

Simple Online and Realtime Tracking (SORT) is a pragmatic approach to multiple object tracking with a focus on simple, effective algorithms. In this paper, we integrate appearance information to improve the performance of SORT. Due to this extension we are able to track objects through longer periods of occlusions, effectively reducing the number of identity switches. In spirit of the original framework we place much of the computational complexity into an offline pre-training stage where we learn a deep association metric on a large-scale person re-identification dataset. During online application, we establish measurement-to-track associations using nearest neighbor queries in visual appearance space. Experimental evaluation shows that our extensions reduce the number of identity switches by 45%, achieving overall competitive performance at high frame rates.

<!-- [IMAGE] -->

<div align="center">
  <img src="https://user-images.githubusercontent.com/26813582/145542023-22950508-b35f-41b6-bc78-33d6a82bc3c3.png"/>
</div>

## Results and models on MOT17

Currently we do not support training ReID models for DeepSORT.
We directly use the ReID model from [Tracktor](https://github.com/phil-bergmann/tracking_wo_bnw). The missing features will be supported in the future.

| Method | Detector | ReID | Train Set | Test Set | Public | Inf time (fps) | HOTA | MOTA | IDF1 | FP | FN | IDSw. | Config | Download |
| :------: | :----------------: | :--: | :--------: | :------: | :----: | :------------: | :--: | :--: | :--: | :---: | :---: | :---: | :--------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| DeepSORT | R50-FasterRCNN-FPN | R50 | half-train | half-val | N | 13.8 | 57.0 | 63.7 | 69.5 | 15063 | 40323 | 3276 | [config](deepsort_faster-rcnn_r50_fpn_8xb2-4e_mot17halftrain_test-mot17halfval.py) | [detector](https://download.openmmlab.com/mmtracking/mot/faster_rcnn/faster-rcnn_r50_fpn_4e_mot17-half-64ee2ed4.pth) [reid](https://download.openmmlab.com/mmtracking/mot/reid/tracktor_reid_r50_iter25245-a452f51f.pth) |

## Get started

### 1. Training

We implement DeepSORT with independent detector and ReID models.
Note that hyperparameters such as the learning rate in the default configuration file are tuned for 8 GPUs,
so we recommend training with 8 GPUs in order to reproduce the reported accuracy.

You can train the detector as follows.

```shell
# Train Faster R-CNN on the mot17-half-train dataset with the following command.
# The number after the config file represents the number of GPUs used. Here we use 8 GPUs.
bash tools/dist_train.sh configs/sort/faster-rcnn_r50_fpn_8xb2-4e_mot17halftrain_test-mot17halfval.py 8
```

### 2. Testing and evaluation

**2.1 Example on MOTxx-halfval dataset**

```shell
# Example 1: Test on the motXX-half-val set.
# The number after the config file represents the number of GPUs used. Here we use 8 GPUs.
bash tools/dist_test_tracking.sh configs/deepsort/deepsort_faster-rcnn_r50_fpn_8xb2-4e_mot17halftrain_test-mot17halfval.py 8 --detector ${DETECTOR_CHECKPOINT_PATH} --reid ${REID_CHECKPOINT_PATH}
```

**2.2 Example on MOTxx-test dataset**

If you want to get the results of the [MOT Challenge](https://motchallenge.net/) test set,
please use the following command to generate result files that can be used for submission.
The results will be stored in `./mot_17_test_res`; you can change the save path by modifying `outfile_prefix` in the `test_evaluator` of the config.

```shell
# Example 2: Test on the motxx-test set.
# The number after the config file represents the number of GPUs used.
bash tools/dist_test_tracking.sh configs/deepsort/deepsort_faster-rcnn_r50_fpn_8xb2-4e_mot17train_test-mot17test.py 8 --detector ${DETECTOR_CHECKPOINT_PATH} --reid ${REID_CHECKPOINT_PATH}
```

### 3. Inference

Use a single GPU to run inference on a video and save the output as a video.

```shell
python demo/mot_demo.py demo/demo_mot.mp4 configs/deepsort/deepsort_faster-rcnn_r50_fpn_8xb2-4e_mot17train_test-mot17test.py --detector ${DETECTOR_CHECKPOINT_PATH} --reid ${REID_CHECKPOINT_PATH} --out mot.mp4
```

## Citation

<!-- [ALGORITHM] -->

```latex
@inproceedings{wojke2017simple,
  title={Simple online and realtime tracking with a deep association metric},
  author={Wojke, Nicolai and Bewley, Alex and Paulus, Dietrich},
  booktitle={2017 IEEE international conference on image processing (ICIP)},
  pages={3645--3649},
  year={2017},
  organization={IEEE}
}
```
configs/deepsort/deepsort_faster-rcnn_r50_fpn_8xb2-4e_mot17halftrain_test-mot17halfval.py (86 additions, 0 deletions)
_base_ = [
    '../_base_/models/faster-rcnn_r50_fpn.py',
    '../_base_/datasets/mot_challenge.py', '../_base_/default_runtime.py'
]

default_hooks = dict(
    logger=dict(type='LoggerHook', interval=1),
    visualization=dict(type='TrackVisualizationHook', draw=False))

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='TrackLocalVisualizer', vis_backends=vis_backends, name='visualizer')
# custom hooks
custom_hooks = [
    # Synchronize model buffers such as running_mean and running_var in BN
    # at the end of each epoch
    dict(type='SyncBuffersHook')
]

detector = _base_.model
detector.pop('data_preprocessor')
detector.rpn_head.bbox_coder.update(dict(clip_border=False))
detector.roi_head.bbox_head.update(dict(num_classes=1))
detector.roi_head.bbox_head.bbox_coder.update(dict(clip_border=False))
detector['init_cfg'] = dict(
    type='Pretrained',
    checkpoint=  # noqa: E251
    'https://download.openmmlab.com/mmtracking/mot/faster_rcnn/'
    'faster-rcnn_r50_fpn_4e_mot17-half-64ee2ed4.pth')
del _base_.model

model = dict(
    type='DeepSORT',
    data_preprocessor=dict(
        type='TrackDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        rgb_to_bgr=False,
        pad_size_divisor=32),
    detector=detector,
    reid=dict(
        type='BaseReID',
        data_preprocessor=None,
        backbone=dict(
            type='mmcls.ResNet',
            depth=50,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='GlobalAveragePooling', kernel_size=(8, 4), stride=1),
        head=dict(
            type='LinearReIDHead',
            num_fcs=1,
            in_channels=2048,
            fc_channels=1024,
            out_channels=128,
            num_classes=380,
            loss_cls=dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0),
            loss_triplet=dict(type='TripletLoss', margin=0.3, loss_weight=1.0),
            norm_cfg=dict(type='BN1d'),
            act_cfg=dict(type='ReLU')),
        init_cfg=dict(
            type='Pretrained',
            checkpoint=  # noqa: E251
            'https://download.openmmlab.com/mmtracking/mot/reid/tracktor_reid_r50_iter25245-a452f51f.pth'  # noqa: E501
        )),
    tracker=dict(
        type='SORTTracker',
        motion=dict(type='KalmanFilter', center_only=False),
        obj_score_thr=0.5,
        reid=dict(
            num_samples=10,
            img_scale=(256, 128),
            img_norm_cfg=None,
            match_score_thr=2.0),
        match_iou_thr=0.5,
        momentums=None,
        num_tentatives=2,
        num_frames_retain=100))

train_dataloader = None

train_cfg = None
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
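If you want to experiment with the tracker hyperparameters defined above (for example `obj_score_thr` or the ReID `match_score_thr`), one option is to load this config, override the fields, and dump a modified copy. A minimal sketch follows; the override values and the output filename are illustrative assumptions, not recommended settings.

```python
# Minimal sketch: load the DeepSORT config and tweak tracker
# hyperparameters before testing. Override values are examples only.
from mmengine.config import Config

cfg = Config.fromfile(
    'configs/deepsort/deepsort_faster-rcnn_r50_fpn_8xb2-4e_'
    'mot17halftrain_test-mot17halfval.py')

# Keep only higher-confidence detections and be stricter on appearance
# matching; both fields come straight from the tracker dict above.
cfg.model.tracker.obj_score_thr = 0.6
cfg.model.tracker.reid.match_score_thr = 1.8

# Dump the modified config so it can be passed to the test scripts.
cfg.dump('deepsort_custom.py')
```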
configs/deepsort/deepsort_faster-rcnn_r50_fpn_8xb2-4e_mot17train_test-mot17test.py (22 additions, 0 deletions)
_base_ = [
    './deepsort_faster-rcnn_r50_fpn_8xb2-4e_mot17halftrain'
    '_test-mot17halfval.py'
]
model = dict(
    detector=dict(
        init_cfg=dict(
            type='Pretrained',
            checkpoint=  # noqa: E251
            'https://download.openmmlab.com/mmtracking/mot/faster_rcnn/faster-rcnn_r50_fpn_4e_mot17-ffa52ae7.pth'  # noqa: E501
        )))

# dataloader
val_dataloader = dict(
    dataset=dict(ann_file='annotations/train_cocoformat.json'))
test_dataloader = dict(
    dataset=dict(
        ann_file='annotations/test_cocoformat.json',
        data_prefix=dict(img_path='test')))

# evaluator
test_evaluator = dict(format_only=True, outfile_prefix='./mot_17_test_res')
Collections:
  - Name: DeepSORT
    Metadata:
      Training Techniques:
        - SGD with Momentum
      Training Resources: 8x V100 GPUs
      Architecture:
        - ResNet
        - FPN
    Paper:
      URL: https://arxiv.org/abs/1703.07402
      Title: Simple Online and Realtime Tracking with a Deep Association Metric
    README: configs/mot/deepsort/README.md

Models:
  - Name: deepsort_faster-rcnn_r50_fpn_8xb2-4e_mot17halftrain_test-mot17halfval
    In Collection: DeepSORT
    Config: configs/deepsort/deepsort_faster-rcnn_r50_fpn_8xb2-4e_mot17halftrain_test-mot17halfval.py
    Metadata:
      Training Data: MOT17-half-train
      inference time (ms/im):
        - value: 72.5
          hardware: V100
          backend: PyTorch
          batch size: 1
          mode: FP32
          resolution: (640, 1088)
    Results:
      - Task: Multiple Object Tracking
        Dataset: MOT17-half-val
        Metrics:
          MOTA: 63.7
          IDF1: 69.5
          HOTA: 57.0
    Weights:
      - https://download.openmmlab.com/mmtracking/mot/faster_rcnn/faster-rcnn_r50_fpn_4e_mot17-half-64ee2ed4.pth
      - https://download.openmmlab.com/mmtracking/mot/reid/tracktor_reid_r50_iter25245-a452f51f.pth