Skip to content

Commit

Permalink
Add result of the example
Browse files Browse the repository at this point in the history
  • Loading branch information
jaehyunnn committed Dec 19, 2022
1 parent 7c296dc commit e22cfbd
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# ViTPose (simple version)
An unofficial implementation of ViTPose [Y. Xu et al., 2022]
An unofficial implementation of ViTPose [Y. Xu et al., 2022] <br>
![result_image](./examples/img1_result.jpg "Result Image")

## Usage
```
Expand Down
Binary file modified examples/img1_result.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9 changes: 7 additions & 2 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,12 @@ def inference(img_path: Path, img_size: tuple[int, int],

# Feed to model
tic = time()
heatmaps = vit_pose(img_tensor).detach().cpu().numpy()
heatmaps = vit_pose(img_tensor).detach().cpu().numpy() # N, 17, h/4, w/4
elapsed_time = time()-tic
print(f">>> Output size: {heatmaps.shape} ---> {elapsed_time:.4f} sec. elapsed [{elapsed_time**-1: .1f} fps]\n")



kpts, probs = keypoints_from_heatmaps(heatmaps=heatmaps,
center=np.array([[org_w//2, org_h//2]]), # x, y
scale=np.array([[org_h/img_size[1], org_w/img_size[0]]]), # h, w
Expand All @@ -54,7 +56,8 @@ def inference(img_path: Path, img_size: tuple[int, int],
valid_radius_factor=0.0546875,
use_udp=False,
target_type='GaussianHeatmap')
points = np.concatenate([kpts[:, :, ::-1], probs], axis=2) # batch, num_people, (2+1)
print(kpts.shape, probs.shape)
points = np.concatenate([kpts[:, :, ::-1], probs], axis=2) # N, 17, (2+1)

# Visualization
if save_result:
Expand All @@ -81,6 +84,8 @@ def inference(img_path: Path, img_size: tuple[int, int],
CKPT_PATH = f"{CUR_DIR}/vitpose-b-multi-coco.pth"

img_size = data_cfg['image_size']
if type(args.image_path) != list:
args.image_path = [args.image_path]
for img_path in args.image_path:
print(img_path)
keypoints = inference(img_path=img_path, img_size=img_size, model_cfg=model_cfg, ckpt_path=CKPT_PATH,
Expand Down

0 comments on commit e22cfbd

Please sign in to comment.