Skip to content

Commit

Permalink
Add keypoint decoding process
Browse files Browse the repository at this point in the history
  • Loading branch information
Jae-Hyun Park authored and Jae-Hyun Park committed Jan 31, 2023
1 parent a985ce7 commit d07bc6f
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 29 deletions.
Binary file modified examples/img1_result.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 5 additions & 23 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,9 @@
from models.model import ViTPose
from utils.visualization import draw_points_and_skeleton, joints_dict
from utils.dist_util import get_dist_info, init_dist
from utils.top_down_eval import keypoints_from_heatmaps

__all__ = ['inference']

def heatmap2coords(heatmaps: np.ndarray, original_resolution: tuple[int, int]=(256, 192)) -> np.ndarray:
__, __, heatmap_h, heatmap_w = heatmaps.shape
output = []
for heatmap in heatmaps:
keypoint_coords = []
for joint in heatmap:
keypoint_coord = np.unravel_index(np.argmax(joint), (heatmap_h, heatmap_w))
"""
- 0: coord_y / (height//4) * bbox_height + bb_y1
- 1: coord_x / (width//4) * bbox_width + bb_x1
- 2: confidences
"""
coord_y = keypoint_coord[0] / heatmap_h*original_resolution[0]
coord_x = keypoint_coord[1] / heatmap_w*original_resolution[1]
prob = joint[keypoint_coord]
keypoint_coords.append([coord_y, coord_x, prob])
output.append(keypoint_coords)

return np.array(output).astype(float)



@torch.no_grad()
Expand All @@ -64,12 +44,14 @@ def inference(img_path: Path, img_size: tuple[int, int],

# Feed to model
tic = time()
print(vit_pose.forward_features(img_tensor).shape)
heatmaps = vit_pose(img_tensor).detach().cpu().numpy() # N, 17, h/4, w/4
elapsed_time = time()-tic
print(f">>> Output size: {heatmaps.shape} ---> {elapsed_time:.4f} sec. elapsed [{elapsed_time**-1: .1f} fps]\n")

points = heatmap2coords(heatmaps=heatmaps, original_resolution=(org_h, org_w))
# points = heatmap2coords(heatmaps=heatmaps, original_resolution=(org_h, org_w))
points, prob = keypoints_from_heatmaps(heatmaps=heatmaps, center=np.array([[org_w//2, org_h//2]]), scale=np.array([[org_w, org_h]]),
unbiased=True, use_udp=True)
points = np.concatenate([points[:, :, ::-1], prob], axis=2)

# Visualization
if save_result:
Expand Down
4 changes: 0 additions & 4 deletions models/head/topdown_heatmap_simple_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,9 @@ def get_accuracy(self, output, target, target_weight):

def forward(self, x):
"""Forward function."""
print(x.shape)
x = self._transform_inputs(x)
print(x.shape)
x = self.deconv_layers(x)
print(x.shape)
x = self.final_layer(x)
print(x.shape)
return x

def inference_model(self, x, flip_pairs=None):
Expand Down
2 changes: 1 addition & 1 deletion utils/post_processing/post_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def transform_preds(coords, center, scale, output_size, use_udp=False):
assert len(output_size) == 2

# Recover the scale which is normalized by a factor of 200.
scale = scale * 200.0
# scale = scale * 200.0

if use_udp:
scale_x = scale[0] / (output_size[0] - 1.0)
Expand Down
20 changes: 19 additions & 1 deletion utils/top_down_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,25 @@

from .post_processing import transform_preds


def heatmap2coords(heatmaps: np.ndarray, original_resolution: tuple[int, int]=(256, 192)) -> np.ndarray:
__, __, heatmap_h, heatmap_w = heatmaps.shape
output = []
for heatmap in heatmaps:
keypoint_coords = []
for joint in heatmap:
keypoint_coord = np.unravel_index(np.argmax(joint), (heatmap_h, heatmap_w))
"""
- 0: coord_y / (height//4) * bbox_height + bb_y1
- 1: coord_x / (width//4) * bbox_width + bb_x1
- 2: confidences
"""
coord_y = keypoint_coord[0] / heatmap_h*original_resolution[0]
coord_x = keypoint_coord[1] / heatmap_w*original_resolution[1]
prob = joint[keypoint_coord]
keypoint_coords.append([coord_y, coord_x, prob])
output.append(keypoint_coords)

return np.array(output).astype(float)

def _calc_distances(preds, targets, mask, normalize):
"""Calculate the normalized distances between preds and target.
Expand Down

0 comments on commit d07bc6f

Please sign in to comment.