[Refactor] Support DEKR (open-mmlab#1834)
Showing 29 changed files with 2,566 additions and 29 deletions.
@@ -0,0 +1,22 @@
# Bottom-up Human Pose Estimation via Disentangled Keypoint Regression (DEKR)

<!-- [ALGORITHM] -->

<details>
<summary align="right"><a href="https://arxiv.org/abs/2104.02300">DEKR (CVPR'2021)</a></summary>

```bibtex
@inproceedings{geng2021bottom,
  title={Bottom-up human pose estimation via disentangled keypoint regression},
  author={Geng, Zigang and Sun, Ke and Xiao, Bin and Zhang, Zhaoxiang and Wang, Jingdong},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={14676--14686},
  year={2021}
}
```

</details>

DEKR is a popular 2D bottom-up pose estimation approach that simultaneously detects all person instances and regresses the offsets from each instance center to its joints.

To predict the offsets more accurately, the offsets of different joints are regressed by separate branches equipped with deformable convolutional layers, so that convolution kernels with different shapes extract features dedicated to the corresponding joint.
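
To make that description concrete, here is a toy PyTorch sketch of disentangled, per-keypoint offset regression. It is not the MMPose `DEKRHead` implementation; the module and parameter names are illustrative, and a deformable convolution stands in for the adaptive convolutions used in the paper.

```python
# Toy sketch (not the MMPose `DEKRHead`): one regression branch per keypoint,
# each with a deformable conv whose sampling locations adapt to that joint.
import torch
import torch.nn as nn
from torchvision.ops import DeformConv2d


class PerKeypointOffsetBranch(nn.Module):

    def __init__(self, in_channels: int, mid_channels: int = 32):
        super().__init__()
        # 2 * 3 * 3 = 18 sampling offsets for a single 3x3 deformable conv
        self.offset_conv = nn.Conv2d(in_channels, 18, 3, padding=1)
        self.deform_conv = DeformConv2d(in_channels, mid_channels, 3, padding=1)
        self.out_conv = nn.Conv2d(mid_channels, 2, 1)  # per-pixel (dx, dy)

    def forward(self, x):
        offsets = self.offset_conv(x)
        feat = torch.relu(self.deform_conv(x, offsets))
        return self.out_conv(feat)


class DisentangledKeypointRegressor(nn.Module):
    """Each keypoint gets its own branch, i.e. its own adapted kernels."""

    def __init__(self, in_channels: int, num_keypoints: int = 17):
        super().__init__()
        self.branches = nn.ModuleList(
            PerKeypointOffsetBranch(in_channels) for _ in range(num_keypoints))

    def forward(self, feats):
        # (N, 2 * num_keypoints, H, W): displacement field toward every joint
        return torch.cat([branch(feats) for branch in self.branches], dim=1)


displacements = DisentangledKeypointRegressor(32)(torch.randn(1, 32, 128, 128))
print(displacements.shape)  # torch.Size([1, 34, 128, 128])
```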
configs/body_2d_keypoint/dekr/coco/dekr_hrnet-w32_8xb10-140e_coco-512x512.py (183 additions, 0 deletions)
@@ -0,0 +1,183 @@
_base_ = ['../../../_base_/default_runtime.py']

# runtime
train_cfg = dict(max_epochs=140, val_interval=10)

# optimizer
optim_wrapper = dict(optimizer=dict(
    type='Adam',
    lr=1e-3,
))

# learning policy
param_scheduler = [
    dict(
        type='LinearLR', begin=0, end=500, start_factor=0.001,
        by_epoch=False),  # warm-up
    dict(
        type='MultiStepLR',
        begin=0,
        end=140,
        milestones=[90, 120],
        gamma=0.1,
        by_epoch=True)
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=80)

# hooks
default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater'))

# codec settings
codec = dict(
    type='RootDisplacement',
    input_size=(512, 512),
    heatmap_size=(128, 128),
    sigma=(4, 2),
    generate_keypoint_heatmaps=True,
    decode_max_instances=30)

# model settings
model = dict(
    type='BottomupPoseEstimator',
    data_preprocessor=dict(
        type='PoseDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True),
    backbone=dict(
        type='HRNet',
        in_channels=3,
        extra=dict(
            stage1=dict(
                num_modules=1,
                num_branches=1,
                block='BOTTLENECK',
                num_blocks=(4, ),
                num_channels=(64, )),
            stage2=dict(
                num_modules=1,
                num_branches=2,
                block='BASIC',
                num_blocks=(4, 4),
                num_channels=(32, 64)),
            stage3=dict(
                num_modules=4,
                num_branches=3,
                block='BASIC',
                num_blocks=(4, 4, 4),
                num_channels=(32, 64, 128)),
            stage4=dict(
                num_modules=3,
                num_branches=4,
                block='BASIC',
                num_blocks=(4, 4, 4, 4),
                num_channels=(32, 64, 128, 256),
                multiscale_output=True)),
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmpose/'
            'pretrain_models/hrnet_w32-36af842e.pth'),
    ),
    head=dict(
        type='DEKRHead',
        in_channels=(32, 64, 128, 256),
        num_keypoints=17,
        input_transform='resize_concat',
        input_index=(0, 1, 2, 3),
        heatmap_loss=dict(type='KeypointMSELoss', use_target_weight=True),
        displacement_loss=dict(
            type='SoftWeightSmoothL1Loss',
            use_target_weight=True,
            supervise_empty=False,
            beta=1 / 9,
            loss_weight=0.002,
        ),
        decoder=codec,
        rescore_cfg=dict(
            in_channels=74,
            norm_indexes=(5, 6),
            init_cfg=dict(
                type='Pretrained',
                checkpoint='https://download.openmmlab.com/mmpose/'
                'pretrain_models/kpt_rescore_coco-33d58c5c.pth')),
    ),
    test_cfg=dict(
        multiscale_test=False,
        flip_test=True,
        nms_dist_thr=0.05,
        shift_heatmap=False,
        align_corners=False))

# enable DDP training when rescore net is used
find_unused_parameters = True

# base dataset settings
dataset_type = 'CocoDataset'
data_mode = 'bottomup'
data_root = 'data/coco/'

# pipelines
train_pipeline = [
    dict(type='LoadImage', file_client_args={{_base_.file_client_args}}),
    dict(type='BottomupRandomAffine', input_size=codec['input_size']),
    dict(type='RandomFlip', direction='horizontal'),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='BottomupGetHeatmapMask'),
    dict(type='PackPoseInputs'),
]
val_pipeline = [
    dict(type='LoadImage', file_client_args={{_base_.file_client_args}}),
    dict(
        type='BottomupResize',
        input_size=codec['input_size'],
        size_factor=32,
        resize_mode='expand'),
    dict(
        type='PackPoseInputs',
        meta_keys=('id', 'img_id', 'img_path', 'crowd_index', 'ori_shape',
                   'img_shape', 'input_size', 'input_center', 'input_scale',
                   'flip', 'flip_direction', 'flip_indices', 'raw_ann_info',
                   'skeleton_links'))
]

# data loaders
train_dataloader = dict(
    batch_size=10,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file='annotations/person_keypoints_train2017.json',
        data_prefix=dict(img='train2017/'),
        pipeline=train_pipeline,
    ))
val_dataloader = dict(
    batch_size=1,
    num_workers=1,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file='annotations/person_keypoints_val2017.json',
        data_prefix=dict(img='val2017/'),
        test_mode=True,
        pipeline=val_pipeline,
    ))
test_dataloader = val_dataloader

# evaluators
val_evaluator = dict(
    type='CocoMetric',
    ann_file=data_root + 'annotations/person_keypoints_val2017.json',
    nms_mode='none',
    score_mode='keypoint',
)
test_evaluator = val_evaluator
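
A quick way to sanity-check a config like the one above is to load it with mmengine and inspect the resolved fields. This is a minimal sketch, assuming it is run from the root of an MMPose checkout so that the relative `_base_` path (and the `file_client_args` it defines) resolves:

```python
# Minimal sketch: load and inspect the DEKR config with mmengine.
# Assumes the working directory is an MMPose checkout so `_base_` resolves.
from mmengine.config import Config

cfg = Config.fromfile(
    'configs/body_2d_keypoint/dekr/coco/'
    'dekr_hrnet-w32_8xb10-140e_coco-512x512.py')

# Top-level names in the config file become attributes of `cfg`.
print(cfg.codec['input_size'])             # (512, 512)
print(cfg.model['head']['type'])           # 'DEKRHead'
print(cfg.train_dataloader['batch_size'])  # 10 per GPU (8 GPUs -> 80 total)

# Training could then be launched with the usual mmengine runner, e.g.
# from mmengine.runner import Runner; Runner.from_cfg(cfg).train()
```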
configs/body_2d_keypoint/dekr/coco/dekr_hrnet-w48_8xb5-140e_coco-640x640.py (183 additions, 0 deletions)
@@ -0,0 +1,183 @@
_base_ = ['../../../_base_/default_runtime.py']

# runtime
train_cfg = dict(max_epochs=140, val_interval=10)

# optimizer
optim_wrapper = dict(optimizer=dict(
    type='Adam',
    lr=1e-3,
))

# learning policy
param_scheduler = [
    dict(
        type='LinearLR', begin=0, end=500, start_factor=0.001,
        by_epoch=False),  # warm-up
    dict(
        type='MultiStepLR',
        begin=0,
        end=140,
        milestones=[90, 120],
        gamma=0.1,
        by_epoch=True)
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=40)

# hooks
default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater'))

# codec settings
codec = dict(
    type='RootDisplacement',
    input_size=(640, 640),
    heatmap_size=(160, 160),
    sigma=(4, 2),
    generate_keypoint_heatmaps=True,
    decode_max_instances=30)

# model settings
model = dict(
    type='BottomupPoseEstimator',
    data_preprocessor=dict(
        type='PoseDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True),
    backbone=dict(
        type='HRNet',
        in_channels=3,
        extra=dict(
            stage1=dict(
                num_modules=1,
                num_branches=1,
                block='BOTTLENECK',
                num_blocks=(4, ),
                num_channels=(64, )),
            stage2=dict(
                num_modules=1,
                num_branches=2,
                block='BASIC',
                num_blocks=(4, 4),
                num_channels=(48, 96)),
            stage3=dict(
                num_modules=4,
                num_branches=3,
                block='BASIC',
                num_blocks=(4, 4, 4),
                num_channels=(48, 96, 192)),
            stage4=dict(
                num_modules=3,
                num_branches=4,
                block='BASIC',
                num_blocks=(4, 4, 4, 4),
                num_channels=(48, 96, 192, 384),
                multiscale_output=True)),
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmpose/'
            'pretrain_models/hrnet_w32-36af842e.pth'),
    ),
    head=dict(
        type='DEKRHead',
        in_channels=(48, 96, 192, 384),
        num_keypoints=17,
        input_transform='resize_concat',
        input_index=(0, 1, 2, 3),
        num_heatmap_filters=48,
        heatmap_loss=dict(type='KeypointMSELoss', use_target_weight=True),
        displacement_loss=dict(
            type='SoftWeightSmoothL1Loss',
            use_target_weight=True,
            supervise_empty=False,
            beta=1 / 9,
            loss_weight=0.002,
        ),
        decoder=codec,
        rescore_cfg=dict(
            in_channels=74,
            norm_indexes=(5, 6),
            init_cfg=dict(
                type='Pretrained',
                checkpoint='https://download.openmmlab.com/mmpose/'
                'pretrain_models/kpt_rescore_coco-33d58c5c.pth')),
    ),
    test_cfg=dict(
        multiscale_test=False,
        flip_test=True,
        nms_dist_thr=0.05,
        shift_heatmap=False,
        align_corners=False))

# enable DDP training when rescore net is used
find_unused_parameters = True

# base dataset settings
dataset_type = 'CocoDataset'
data_mode = 'bottomup'
data_root = 'data/coco/'

# pipelines
train_pipeline = [
    dict(type='LoadImage', file_client_args={{_base_.file_client_args}}),
    dict(type='RandomFlip', direction='horizontal'),
    dict(type='BottomupRandomAffine', input_size=codec['input_size']),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs'),
]
val_pipeline = [
    dict(type='LoadImage', file_client_args={{_base_.file_client_args}}),
    dict(
        type='BottomupResize',
        input_size=codec['input_size'],
        size_factor=32,
        resize_mode='expand'),
    dict(
        type='PackPoseInputs',
        meta_keys=('id', 'img_id', 'img_path', 'crowd_index', 'ori_shape',
                   'img_shape', 'input_size', 'input_center', 'input_scale',
                   'flip', 'flip_direction', 'flip_indices', 'raw_ann_info',
                   'skeleton_links'))
]

# data loaders
train_dataloader = dict(
    batch_size=5,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file='annotations/person_keypoints_train2017.json',
        data_prefix=dict(img='train2017/'),
        pipeline=train_pipeline,
    ))
val_dataloader = dict(
    batch_size=1,
    num_workers=1,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file='annotations/person_keypoints_val2017.json',
        data_prefix=dict(img='val2017/'),
        test_mode=True,
        pipeline=val_pipeline,
    ))
test_dataloader = val_dataloader

# evaluators
val_evaluator = dict(
    type='CocoMetric',
    ann_file=data_root + 'annotations/person_keypoints_val2017.json',
    nms_mode='none',
    score_mode='keypoint',
)
test_evaluator = val_evaluator
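
Both configs encode the same linear learning-rate scaling convention: `base_batch_size` is the total batch size (8 GPUs x the per-GPU `batch_size`) at which `lr=1e-3` was chosen, and when auto LR scaling is enabled the runner rescales the learning rate in proportion to the actual total batch size. A back-of-the-envelope sketch of that rule follows; the helper name and arguments are illustrative, not part of MMPose.

```python
# Sketch of the linear scaling rule that `auto_scale_lr` is based on
# (assumption: the runner multiplies `lr` by actual/base batch size when
# auto LR scaling is turned on).
def scaled_lr(base_lr: float, base_batch_size: int,
              num_gpus: int, samples_per_gpu: int) -> float:
    actual_batch_size = num_gpus * samples_per_gpu
    return base_lr * actual_batch_size / base_batch_size


# w32 config: tuned at 8 x 10 = 80 samples per step.
print(scaled_lr(1e-3, 80, num_gpus=4, samples_per_gpu=10))  # 0.0005
# w48 config: tuned at 8 x 5 = 40 samples per step.
print(scaled_lr(1e-3, 40, num_gpus=8, samples_per_gpu=5))   # 0.001
```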