Skip to content

Commit

Permalink
[Refactor] Support DEKR (open-mmlab#1834)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben-Louis authored Dec 7, 2022
1 parent 62fd9d0 commit 8e61de6
Show file tree
Hide file tree
Showing 29 changed files with 2,566 additions and 29 deletions.
22 changes: 22 additions & 0 deletions configs/body_2d_keypoint/dekr/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Bottom-up Human Pose Estimation via Disentangled Keypoint Regression (DEKR)

<!-- [ALGORITHM] -->

<details>
<summary align="right"><a href="https://arxiv.org/abs/2104.02300">DEKR (CVPR'2021)</a></summary>

```bibtex
@inproceedings{geng2021bottom,
title={Bottom-up human pose estimation via disentangled keypoint regression},
author={Geng, Zigang and Sun, Ke and Xiao, Bin and Zhang, Zhaoxiang and Wang, Jingdong},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={14676--14686},
year={2021}
}
```

</details>

DEKR is a popular 2D bottom-up pose estimation approach that simultaneously detects all the instances and regresses the offsets from the instance centers to joints.

In order to predict the offsets more accurately, the offsets of different joints are regressed using separated branches with deformable convolutional layers. Thus convolution kernels with different shapes are adopted to extract features for the corresponding joint.
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
_base_ = ['../../../_base_/default_runtime.py']

# runtime
train_cfg = dict(max_epochs=140, val_interval=10)

# optimizer
optim_wrapper = dict(optimizer=dict(
type='Adam',
lr=1e-3,
))

# learning policy
param_scheduler = [
dict(
type='LinearLR', begin=0, end=500, start_factor=0.001,
by_epoch=False), # warm-up
dict(
type='MultiStepLR',
begin=0,
end=140,
milestones=[90, 120],
gamma=0.1,
by_epoch=True)
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=80)

# hooks
default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater'))

# codec settings
codec = dict(
type='RootDisplacement',
input_size=(512, 512),
heatmap_size=(128, 128),
sigma=(4, 2),
generate_keypoint_heatmaps=True,
decode_max_instances=30)

# model settings
model = dict(
type='BottomupPoseEstimator',
data_preprocessor=dict(
type='PoseDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True),
backbone=dict(
type='HRNet',
in_channels=3,
extra=dict(
stage1=dict(
num_modules=1,
num_branches=1,
block='BOTTLENECK',
num_blocks=(4, ),
num_channels=(64, )),
stage2=dict(
num_modules=1,
num_branches=2,
block='BASIC',
num_blocks=(4, 4),
num_channels=(32, 64)),
stage3=dict(
num_modules=4,
num_branches=3,
block='BASIC',
num_blocks=(4, 4, 4),
num_channels=(32, 64, 128)),
stage4=dict(
num_modules=3,
num_branches=4,
block='BASIC',
num_blocks=(4, 4, 4, 4),
num_channels=(32, 64, 128, 256),
multiscale_output=True)),
init_cfg=dict(
type='Pretrained',
checkpoint='https://download.openmmlab.com/mmpose/'
'pretrain_models/hrnet_w32-36af842e.pth'),
),
head=dict(
type='DEKRHead',
in_channels=(32, 64, 128, 256),
num_keypoints=17,
input_transform='resize_concat',
input_index=(0, 1, 2, 3),
heatmap_loss=dict(type='KeypointMSELoss', use_target_weight=True),
displacement_loss=dict(
type='SoftWeightSmoothL1Loss',
use_target_weight=True,
supervise_empty=False,
beta=1 / 9,
loss_weight=0.002,
),
decoder=codec,
rescore_cfg=dict(
in_channels=74,
norm_indexes=(5, 6),
init_cfg=dict(
type='Pretrained',
checkpoint='https://download.openmmlab.com/mmpose/'
'pretrain_models/kpt_rescore_coco-33d58c5c.pth')),
),
test_cfg=dict(
multiscale_test=False,
flip_test=True,
nms_dist_thr=0.05,
shift_heatmap=False,
align_corners=False))

# enable DDP training when rescore net is used
find_unused_parameters = True

# base dataset settings
dataset_type = 'CocoDataset'
data_mode = 'bottomup'
data_root = 'data/coco/'

# pipelines
train_pipeline = [
dict(type='LoadImage', file_client_args={{_base_.file_client_args}}),
dict(type='BottomupRandomAffine', input_size=codec['input_size']),
dict(type='RandomFlip', direction='horizontal'),
dict(type='GenerateTarget', encoder=codec),
dict(type='BottomupGetHeatmapMask'),
dict(type='PackPoseInputs'),
]
val_pipeline = [
dict(type='LoadImage', file_client_args={{_base_.file_client_args}}),
dict(
type='BottomupResize',
input_size=codec['input_size'],
size_factor=32,
resize_mode='expand'),
dict(
type='PackPoseInputs',
meta_keys=('id', 'img_id', 'img_path', 'crowd_index', 'ori_shape',
'img_shape', 'input_size', 'input_center', 'input_scale',
'flip', 'flip_direction', 'flip_indices', 'raw_ann_info',
'skeleton_links'))
]

# data loaders
train_dataloader = dict(
batch_size=10,
num_workers=2,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/person_keypoints_train2017.json',
data_prefix=dict(img='train2017/'),
pipeline=train_pipeline,
))
val_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/person_keypoints_val2017.json',
data_prefix=dict(img='val2017/'),
test_mode=True,
pipeline=val_pipeline,
))
test_dataloader = val_dataloader

# evaluators
val_evaluator = dict(
type='CocoMetric',
ann_file=data_root + 'annotations/person_keypoints_val2017.json',
nms_mode='none',
score_mode='keypoint',
)
test_evaluator = val_evaluator
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
_base_ = ['../../../_base_/default_runtime.py']

# runtime
train_cfg = dict(max_epochs=140, val_interval=10)

# optimizer
optim_wrapper = dict(optimizer=dict(
type='Adam',
lr=1e-3,
))

# learning policy
param_scheduler = [
dict(
type='LinearLR', begin=0, end=500, start_factor=0.001,
by_epoch=False), # warm-up
dict(
type='MultiStepLR',
begin=0,
end=140,
milestones=[90, 120],
gamma=0.1,
by_epoch=True)
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=40)

# hooks
default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater'))

# codec settings
codec = dict(
type='RootDisplacement',
input_size=(640, 640),
heatmap_size=(160, 160),
sigma=(4, 2),
generate_keypoint_heatmaps=True,
decode_max_instances=30)

# model settings
model = dict(
type='BottomupPoseEstimator',
data_preprocessor=dict(
type='PoseDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True),
backbone=dict(
type='HRNet',
in_channels=3,
extra=dict(
stage1=dict(
num_modules=1,
num_branches=1,
block='BOTTLENECK',
num_blocks=(4, ),
num_channels=(64, )),
stage2=dict(
num_modules=1,
num_branches=2,
block='BASIC',
num_blocks=(4, 4),
num_channels=(48, 96)),
stage3=dict(
num_modules=4,
num_branches=3,
block='BASIC',
num_blocks=(4, 4, 4),
num_channels=(48, 96, 192)),
stage4=dict(
num_modules=3,
num_branches=4,
block='BASIC',
num_blocks=(4, 4, 4, 4),
num_channels=(48, 96, 192, 384),
multiscale_output=True)),
init_cfg=dict(
type='Pretrained',
checkpoint='https://download.openmmlab.com/mmpose/'
'pretrain_models/hrnet_w32-36af842e.pth'),
),
head=dict(
type='DEKRHead',
in_channels=(48, 96, 192, 384),
num_keypoints=17,
input_transform='resize_concat',
input_index=(0, 1, 2, 3),
num_heatmap_filters=48,
heatmap_loss=dict(type='KeypointMSELoss', use_target_weight=True),
displacement_loss=dict(
type='SoftWeightSmoothL1Loss',
use_target_weight=True,
supervise_empty=False,
beta=1 / 9,
loss_weight=0.002,
),
decoder=codec,
rescore_cfg=dict(
in_channels=74,
norm_indexes=(5, 6),
init_cfg=dict(
type='Pretrained',
checkpoint='https://download.openmmlab.com/mmpose/'
'pretrain_models/kpt_rescore_coco-33d58c5c.pth')),
),
test_cfg=dict(
multiscale_test=False,
flip_test=True,
nms_dist_thr=0.05,
shift_heatmap=False,
align_corners=False))

# enable DDP training when rescore net is used
find_unused_parameters = True

# base dataset settings
dataset_type = 'CocoDataset'
data_mode = 'bottomup'
data_root = 'data/coco/'

# pipelines
train_pipeline = [
dict(type='LoadImage', file_client_args={{_base_.file_client_args}}),
dict(type='RandomFlip', direction='horizontal'),
dict(type='BottomupRandomAffine', input_size=codec['input_size']),
dict(type='GenerateTarget', encoder=codec),
dict(type='PackPoseInputs'),
]
val_pipeline = [
dict(type='LoadImage', file_client_args={{_base_.file_client_args}}),
dict(
type='BottomupResize',
input_size=codec['input_size'],
size_factor=32,
resize_mode='expand'),
dict(
type='PackPoseInputs',
meta_keys=('id', 'img_id', 'img_path', 'crowd_index', 'ori_shape',
'img_shape', 'input_size', 'input_center', 'input_scale',
'flip', 'flip_direction', 'flip_indices', 'raw_ann_info',
'skeleton_links'))
]

# data loaders
train_dataloader = dict(
batch_size=5,
num_workers=2,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/person_keypoints_train2017.json',
data_prefix=dict(img='train2017/'),
pipeline=train_pipeline,
))
val_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/person_keypoints_val2017.json',
data_prefix=dict(img='val2017/'),
test_mode=True,
pipeline=val_pipeline,
))
test_dataloader = val_dataloader

# evaluators
val_evaluator = dict(
type='CocoMetric',
ann_file=data_root + 'annotations/person_keypoints_val2017.json',
nms_mode='none',
score_mode='keypoint',
)
test_evaluator = val_evaluator
Loading

0 comments on commit 8e61de6

Please sign in to comment.