Skip to content

Commit

Permalink
[Refactor] refactor Associative Embedding codec (open-mmlab#1603)
Browse files Browse the repository at this point in the history
* add associative embedding codec

* refactor decoding [wip]

* refactor decoding process

* add associative embedding codec

* refactor decoding refinements

* add missing keypoint complement and unit test

* support dynamic input_size in decoding

* add unit test for decoding with dynamic size
  • Loading branch information
ly015 committed Oct 14, 2022
1 parent 19d20a7 commit 557f0d6
Show file tree
Hide file tree
Showing 12 changed files with 1,110 additions and 167 deletions.
3 changes: 2 additions & 1 deletion mmpose/codecs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .associative_embedding import AssociativeEmbedding
from .megvii_heatmap import MegviiHeatmap
from .msra_heatmap import MSRAHeatmap
from .regression_label import RegressionLabel
Expand All @@ -7,5 +8,5 @@

__all__ = [
'MSRAHeatmap', 'MegviiHeatmap', 'UDPHeatmap', 'RegressionLabel',
'SimCCLabel'
'SimCCLabel', 'AssociativeEmbedding'
]
521 changes: 521 additions & 0 deletions mmpose/codecs/associative_embedding.py

Large diffs are not rendered by default.

70 changes: 11 additions & 59 deletions mmpose/codecs/msra_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

from mmpose.registry import KEYPOINT_CODECS
from .base import BaseKeypointCodec
from .utils import (gaussian_blur, generate_gaussian_heatmaps,
get_heatmap_maximum)
from .utils.gaussian_heatmap import generate_unbiased_gaussian_heatmaps
from .utils.gaussian_heatmap import (generate_gaussian_heatmaps,
generate_unbiased_gaussian_heatmaps)
from .utils.post_processing import get_heatmap_maximum
from .utils.refinement import refine_keypoints, refine_keypoints_dark


@KEYPOINT_CODECS.register_module()
Expand Down Expand Up @@ -126,69 +127,20 @@ def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:

keypoints, scores = get_heatmap_maximum(heatmaps)

# Unsqueeze the instance dimension for single-instance results
keypoints, scores = keypoints[None], scores[None]

if self.unbiased:
# Alleviate biased coordinate
# Apply Gaussian distribution modulation.
heatmaps = gaussian_blur(heatmaps, kernel=self.blur_kernel_size)
heatmaps = np.log(np.maximum(heatmaps, 1e-10))
for k in range(K):
keypoints[k] = self._taylor_decode(
heatmap=heatmaps[k], keypoint=keypoints[k])
keypoints = refine_keypoints_dark(
keypoints, heatmaps, blur_kernel_size=self.blur_kernel_size)

else:
# Add +/-0.25 shift to the predicted locations for higher acc.
for k in range(K):
heatmap = heatmaps[k]
px = int(keypoints[k, 0])
py = int(keypoints[k, 1])
if 1 < px < W - 1 and 1 < py < H - 1:
diff = np.array([
heatmap[py][px + 1] - heatmap[py][px - 1],
heatmap[py + 1][px] - heatmap[py - 1][px]
])
keypoints[k] += np.sign(diff) * 0.25
keypoints = refine_keypoints(keypoints, heatmaps)

# Unsqueeze the instance dimension for single-instance results
# and restore the keypoint scales
keypoints = keypoints[None] * self.scale_factor
scores = scores[None]

return keypoints, scores

@staticmethod
def _taylor_decode(heatmap: np.ndarray,
                   keypoint: np.ndarray) -> np.ndarray:
    """Refine a single keypoint location using local heatmap derivatives.

    The gradient and Hessian of ``heatmap`` at the integer keypoint
    location are estimated with finite differences, and the keypoint is
    shifted by ``-inv(Hessian) @ gradient``.

    Note:
        - heatmap height: H
        - heatmap width: W

    Args:
        heatmap (np.ndarray[H, W]): Heatmap of a particular keypoint type.
        keypoint (np.ndarray[2,]): Coordinates (x, y) of the predicted
            keypoint. It is updated in place when a refinement is applied.

    Returns:
        np.ndarray[2,]: The (possibly updated) keypoint coordinates. This
        is the same array object that was passed in.
    """
    height, width = heatmap.shape[:2]
    col, row = int(keypoint[0]), int(keypoint[1])

    # The stencil below reads pixels up to 2 steps away from (row, col);
    # locations too close to the border are returned untouched.
    if not (1 < col < width - 2 and 1 < row < height - 2):
        return keypoint

    # First-order finite differences.
    dx = 0.5 * (heatmap[row][col + 1] - heatmap[row][col - 1])
    dy = 0.5 * (heatmap[row + 1][col] - heatmap[row - 1][col])

    # Second-order finite differences.
    dxx = 0.25 * (
        heatmap[row][col + 2] - 2 * heatmap[row][col] +
        heatmap[row][col - 2])
    dxy = 0.25 * (
        heatmap[row + 1][col + 1] - heatmap[row - 1][col + 1] -
        heatmap[row + 1][col - 1] + heatmap[row - 1][col - 1])
    dyy = 0.25 * (
        heatmap[row + 2][col] - 2 * heatmap[row][col] +
        heatmap[row - 2][col])

    # A zero determinant means the Hessian cannot be inverted; keep the
    # original location in that case.
    if dxx * dyy - dxy**2 == 0:
        return keypoint

    gradient = np.array([[dx], [dy]])
    curvature = np.array([[dxx, dxy], [dxy, dyy]])
    shift = -np.linalg.inv(curvature) @ gradient  # shape (2, 1)
    keypoint += np.squeeze(np.array(shift.T), axis=0)
    return keypoint
69 changes: 12 additions & 57 deletions mmpose/codecs/udp_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mmpose.registry import KEYPOINT_CODECS
from .base import BaseKeypointCodec
from .utils import (generate_offset_heatmap, generate_udp_gaussian_heatmaps,
get_heatmap_maximum)
get_heatmap_maximum, refine_keypoints_dark_udp)


@KEYPOINT_CODECS.register_module()
Expand Down Expand Up @@ -139,8 +139,13 @@ def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:

if self.heatmap_type == 'gaussian':
keypoints, scores = get_heatmap_maximum(heatmaps)
keypoints = self._postprocess_dark_udp(heatmaps, keypoints,
self.blur_kernel_size)
# unsqueeze the instance dimension for single-instance results
keypoints = keypoints[None]
scores = scores[None]

keypoints = refine_keypoints_dark_udp(
keypoints, heatmaps, blur_kernel_size=self.blur_kernel_size)

elif self.heatmap_type == 'combined':
_K, H, W = heatmaps.shape
K = _K // 3
Expand All @@ -163,61 +168,11 @@ def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
index += W * H * np.arange(0, K)
index = index.astype(int)
keypoints += np.stack((x_offset[index], y_offset[index]), axis=-1)
# unsqueeze the instance dimension for single-instance results
keypoints = keypoints[None].astype(np.float32)
scores = scores[None]

# Unsqueeze the instance dimension for single-instance results
W, H = self.heatmap_size
keypoints = keypoints[None] / [W - 1, H - 1] * self.input_size
scores = scores[None]
keypoints = keypoints / [W - 1, H - 1] * self.input_size

return keypoints, scores

@staticmethod
def _postprocess_dark_udp(heatmaps: np.ndarray, keypoints: np.ndarray,
                          kernel_size: int) -> np.ndarray:
    """Distribution aware post-processing for UDP.

    The heatmaps are Gaussian-blurred, clipped to ``[0.001, 50]`` and
    log-transformed, then each keypoint is shifted by
    ``-inv(Hessian) @ gradient`` where both are estimated from the
    neighbouring pixels of the keypoint location.

    Note:
        Both ``heatmaps`` and ``keypoints`` are modified in place.

    Args:
        heatmaps (np.ndarray): Heatmaps in shape (K, H, W)
        keypoints (np.ndarray): Keypoint coordinates in shape (K, D)
        kernel_size (int): The Gaussian blur kernel size of the heatmap
            modulation

    Returns:
        np.ndarray: Post-processed keypoint coordinates
    """
    K, H, W = heatmaps.shape
    # Row length of a heatmap after the 1-pixel padding applied below.
    padded_w = W + 2

    # Blur every heatmap, writing the result back into the same buffer.
    for heatmap in heatmaps:
        cv2.GaussianBlur(heatmap, (kernel_size, kernel_size), 0, heatmap)

    np.clip(heatmaps, 0.001, 50., out=heatmaps)
    np.log(heatmaps, out=heatmaps)

    # Edge-pad so that the neighbours of border pixels exist, then flatten
    # so that neighbours can be addressed with a single linear index.
    flat = np.pad(
        heatmaps, ((0, 0), (1, 1), (1, 1)), mode='edge').flatten()

    # Linear index of each keypoint inside its own padded heatmap ...
    index = keypoints[..., 0] + 1 + (keypoints[..., 1] + 1) * padded_w
    # ... offset by the start of the k-th padded heatmap.
    index += padded_w * (H + 2) * np.arange(0, K)
    index = index.astype(int).reshape(-1, 1)

    center = flat[index]
    x_next = flat[index + 1]
    y_next = flat[index + padded_w]
    xy_next = flat[index + padded_w + 1]
    xy_prev = flat[index - padded_w - 1]
    x_prev = flat[index - 1]
    y_prev = flat[index - padded_w]

    # First-order differences, stacked as (K, 2, 1) column vectors.
    dx = 0.5 * (x_next - x_prev)
    dy = 0.5 * (y_next - y_prev)
    derivative = np.concatenate([dx, dy], axis=1).reshape(K, 2, 1)

    # Second-order differences, stacked as (K, 2, 2) matrices.
    dxx = x_next - 2 * center + x_prev
    dyy = y_next - 2 * center + y_prev
    dxy = 0.5 * (
        xy_next - x_next - y_next + center + center - x_prev - y_prev +
        xy_prev)
    hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1).reshape(K, 2, 2)

    # A tiny multiple of the identity is added before inversion.
    hessian_inv = np.linalg.inv(
        hessian + np.finfo(np.float32).eps * np.eye(2))
    keypoints -= np.einsum('imn,ink->imk', hessian_inv,
                           derivative).squeeze()

    return keypoints
10 changes: 7 additions & 3 deletions mmpose/codecs/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
generate_udp_gaussian_heatmaps,
generate_unbiased_gaussian_heatmaps)
from .offset_heatmap import generate_offset_heatmap
from .post_processing import (gaussian_blur, get_heatmap_maximum,
get_simcc_maximum)
from .post_processing import (batch_heatmap_nms, gaussian_blur,
get_heatmap_maximum, get_simcc_maximum)
from .refinement import (refine_keypoints, refine_keypoints_dark,
refine_keypoints_dark_udp)

__all__ = [
'generate_gaussian_heatmaps', 'generate_udp_gaussian_heatmaps',
'generate_unbiased_gaussian_heatmaps', 'gaussian_blur',
'get_heatmap_maximum', 'get_simcc_maximum', 'generate_offset_heatmap'
'get_heatmap_maximum', 'get_simcc_maximum', 'generate_offset_heatmap',
'batch_heatmap_nms', 'refine_keypoints', 'refine_keypoints_dark',
'refine_keypoints_dark_udp'
]
28 changes: 28 additions & 0 deletions mmpose/codecs/utils/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor


def get_simcc_maximum(simcc_x: np.ndarray,
Expand Down Expand Up @@ -136,3 +139,28 @@ def gaussian_blur(heatmaps: np.ndarray, kernel: int = 11) -> np.ndarray:
heatmaps[k] = dr[border:-border, border:-border].copy()
heatmaps[k] *= origin_max / np.max(heatmaps[k])
return heatmaps


def batch_heatmap_nms(batch_heatmaps: Tensor, kernel_size: int = 5):
    """Apply NMS on a batch of heatmaps.

    Every location is compared with the maximum of the
    ``kernel_size x kernel_size`` window centred on it (stride 1, padded so
    that the spatial size is unchanged). Locations that are not equal to
    their window maximum are set to zero; all others keep their value.

    Args:
        batch_heatmaps (Tensor): batch heatmaps in shape (B, K, H, W)
        kernel_size (int): The kernel size of the NMS which should be
            an odd integer. Defaults to 5

    Returns:
        Tensor: The batch heatmaps after NMS, with the same shape and
        dtype as the input. The input tensor itself is not modified.
    """

    assert isinstance(kernel_size, int) and kernel_size % 2 == 1, \
        f'The kernel_size should be an odd integer, got {kernel_size}'

    # "Same" padding so the pooled map aligns with the input pixel by pixel.
    padding = (kernel_size - 1) // 2

    maximum = F.max_pool2d(
        batch_heatmaps, kernel_size, stride=1, padding=padding)
    maximum_indicator = torch.eq(batch_heatmaps, maximum)

    # Cast the mask to the input dtype (instead of always float32) so that
    # half-precision heatmaps are not promoted to float32 by the product.
    batch_heatmaps = batch_heatmaps * maximum_indicator.to(
        batch_heatmaps.dtype)

    return batch_heatmaps
Loading

0 comments on commit 557f0d6

Please sign in to comment.