Skip to content

Commit

Permalink
Use new Mediapipe API in human pose tracking example (rerun-io#5903)
Browse files Browse the repository at this point in the history
### What

Closes rerun-io#5859.

### Checklist
* [x] I have read and agree to [Contributor
Guide](https://github.com/rerun-io/rerun/blob/main/CONTRIBUTING.md) and
the [Code of
Conduct](https://github.com/rerun-io/rerun/blob/main/CODE_OF_CONDUCT.md)
* [x] I've included a screenshot or gif (if applicable)
* [x] I have tested the web demo (if applicable):
* Using newly built examples:
[rerun.io/viewer](https://rerun.io/viewer/pr/5903)
* Using examples from latest `main` build:
[rerun.io/viewer](https://rerun.io/viewer/pr/5903?manifest_url=https://app.rerun.io/version/main/examples_manifest.json)
* Using full set of examples from `nightly` build:
[rerun.io/viewer](https://rerun.io/viewer/pr/5903?manifest_url=https://app.rerun.io/version/nightly/examples_manifest.json)
* [x] The PR title and labels are set such as to maximize their
usefulness for the next release's CHANGELOG
* [x] If applicable, add a new check to the [release
checklist](https://github.com/rerun-io/rerun/blob/main/tests/python/release_checklist)!

- [PR Build Summary](https://build.rerun.io/pr/5903)
- [Recent benchmark results](https://build.rerun.io/graphs/crates.html)
- [Wasm size tracking](https://build.rerun.io/graphs/sizes.html)
  • Loading branch information
roym899 authored Apr 11, 2024
1 parent 171aad7 commit a8cf9ec
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 34 deletions.
2 changes: 1 addition & 1 deletion examples/python/face_tracking/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# no 3.12 version yet (https://pypi.org/project/mediapipe/)
# 0.10.10 no longer supports the legacy Pose model: https://github.com/rerun-io/rerun/issues/5859
mediapipe==0.10.9 ; python_version <= '3.11'
mediapipe==0.10.11 ; python_version <= '3.11'

numpy
opencv-python>4.6 # Avoid opencv-4.6 since it rotates images incorrectly (https://github.com/opencv/opencv/issues/22088)
Expand Down
2 changes: 1 addition & 1 deletion examples/python/gesture_detection/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# no 3.12 version yet (https://pypi.org/project/mediapipe/)
# 0.10.10 no longer supports the legacy Pose model: https://github.com/rerun-io/rerun/issues/5859
mediapipe==0.10.9 ; python_version <= '3.11'
mediapipe==0.10.11 ; python_version <= '3.11'

numpy
opencv-python>4.9
Expand Down
21 changes: 8 additions & 13 deletions examples/python/human_pose_tracking/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ rr.log(
### Segmentation mask

The segmentation result is logged through a combination of two archetypes. The segmentation
image itself is logged as an
image itself is logged as a
[`SegmentationImage`](https://www.rerun.io/docs/reference/types/archetypes/segmentation_image) and
contains the id for each pixel. The color is determined by the
[`AnnotationContext`](https://www.rerun.io/docs/reference/types/archetypes/annotation_context) which is
Expand All @@ -77,22 +77,15 @@ rr.log(
#### Segmentation image

```python
rr.log(
"video/mask",
rr.SegmentationImage(segmentation_mask.astype(np.uint8))
)
rr.log("video/mask", rr.SegmentationImage(binary_segmentation_mask.astype(np.uint8)))
```

### Body pose points
Logging the body pose landmarks involves specifying connections between the points, extracting pose landmark points and logging them to the Rerun SDK.
The 2D points are visualized over the image/video for a better understanding and visualization of the body pose. The 3D points allows the creation of a 3D model of the body posture for a more comprehensive representation of the human pose.


Logging the body pose as a skeleton involves specifying the connectivity of its keypoints (i.e., pose landmarks), extracting the pose landmarks, and logging them as points to Rerun. In this example, both the 2D and 3D estimates from Mediapipe are visualized.

The 2D and 3D points are logged through a combination of two archetypes. First, a timeless
The skeletons are logged through a combination of two archetypes. First, a timeless
[`ClassDescription`](https://www.rerun.io/docs/reference/types/datatypes/class_description) is logged, that contains the information which maps keypoint ids to labels and how to connect
the keypoints.
Defining these connections automatically renders lines between them. Mediapipe provides the `POSE_CONNECTIONS` variable which contains the list of `(from, to)` landmark indices that define the connections. Second, the actual keypoint positions are logged in 2D
the keypoints. By defining these connections Rerun will automatically add lines between them. Mediapipe provides the `POSE_CONNECTIONS` variable which contains the list of `(from, to)` landmark indices that define the connections. Second, the actual keypoint positions are logged in 2D
and 3D as [`Points2D`](https://www.rerun.io/docs/reference/types/archetypes/points2d) and
[`Points3D`](https://www.rerun.io/docs/reference/types/archetypes/points3d) archetypes, respectively.

Expand All @@ -104,7 +97,9 @@ rr.log(
rr.AnnotationContext(
rr.ClassDescription(
info=rr.AnnotationInfo(id=1, label="Person"),
keypoint_annotations=[rr.AnnotationInfo(id=lm.value, label=lm.name) for lm in mp_pose.PoseLandmark],
keypoint_annotations=[
rr.AnnotationInfo(id=lm.value, label=lm.name) for lm in mp_pose.PoseLandmark
],
keypoint_connections=mp_pose.POSE_CONNECTIONS,
)
),
Expand Down
81 changes: 64 additions & 17 deletions examples/python/human_pose_tracking/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import cv2
import mediapipe as mp
import mediapipe.python.solutions.pose as mp_pose
import numpy as np
import numpy.typing as npt
import requests
Expand All @@ -30,11 +31,19 @@

EXAMPLE_DIR: Final = Path(os.path.dirname(__file__))
DATASET_DIR: Final = EXAMPLE_DIR / "dataset" / "pose_movement"
MODEL_DIR: Final = EXAMPLE_DIR / "model" / "pose_movement"
DATASET_URL_BASE: Final = "https://storage.googleapis.com/rerun-example-datasets/pose_movement"
MODEL_URL_TEMPLATE: Final = "https://storage.googleapis.com/mediapipe-models/pose_landmarker/pose_landmarker_{model_name}/float16/latest/pose_landmarker_{model_name}.task"


def track_pose(video_path: str, *, segment: bool, max_frame_count: int | None) -> None:
mp_pose = mp.solutions.pose
def track_pose(video_path: str, model_path: str, *, segment: bool, max_frame_count: int | None) -> None:
options = mp.tasks.vision.PoseLandmarkerOptions(
base_options=mp.tasks.BaseOptions(
model_asset_path=model_path,
),
running_mode=mp.tasks.vision.RunningMode.VIDEO,
output_segmentation_masks=True,
)

rr.log("description", rr.TextDocument(DESCRIPTION, media_type=rr.MediaType.MARKDOWN), static=True)

Expand Down Expand Up @@ -62,19 +71,23 @@ def track_pose(video_path: str, *, segment: bool, max_frame_count: int | None) -
)
rr.log("person", rr.ViewCoordinates.RIGHT_HAND_Y_DOWN, static=True)

with closing(VideoSource(video_path)) as video_source, mp_pose.Pose(enable_segmentation=segment) as pose:
pose_landmarker = mp.tasks.vision.PoseLandmarker.create_from_options(options)

with closing(VideoSource(video_path)) as video_source:
for idx, bgr_frame in enumerate(video_source.stream_bgr()):
if max_frame_count is not None and idx >= max_frame_count:
break

mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=bgr_frame.data)
rgb = cv2.cvtColor(bgr_frame.data, cv2.COLOR_BGR2RGB)
rr.set_time_seconds("time", bgr_frame.time)
rr.set_time_sequence("frame_idx", bgr_frame.idx)
rr.log("video/rgb", rr.Image(rgb).compress(jpeg_quality=75))

results = pose.process(rgb)
results = pose_landmarker.detect_for_video(mp_image, int(bgr_frame.time * 1000))
h, w, _ = rgb.shape
landmark_positions_2d = read_landmark_positions_2d(results, w, h)

rr.log("video/rgb", rr.Image(rgb).compress(jpeg_quality=75))
if landmark_positions_2d is not None:
rr.log(
"video/pose/points",
Expand All @@ -88,30 +101,33 @@ def track_pose(video_path: str, *, segment: bool, max_frame_count: int | None) -
rr.Points3D(landmark_positions_3d, class_ids=1, keypoint_ids=mp_pose.PoseLandmark),
)

segmentation_mask = results.segmentation_mask
if segmentation_mask is not None:
rr.log("video/mask", rr.SegmentationImage(segmentation_mask.astype(np.uint8)))
if results.segmentation_masks is not None:
segmentation_mask = results.segmentation_masks[0].numpy_view()
binary_segmentation_mask = segmentation_mask > 0.5
rr.log("video/mask", rr.SegmentationImage(binary_segmentation_mask.astype(np.uint8)))


def read_landmark_positions_2d(
results: Any,
image_width: int,
image_height: int,
) -> npt.NDArray[np.float32] | None:
if results.pose_landmarks is None:
if results.pose_landmarks is None or len(results.pose_landmarks) == 0:
return None
else:
normalized_landmarks = [results.pose_landmarks.landmark[lm] for lm in mp.solutions.pose.PoseLandmark]
pose_landmarks = results.pose_landmarks[0]
normalized_landmarks = [pose_landmarks[lm] for lm in mp_pose.PoseLandmark]
return np.array([(image_width * lm.x, image_height * lm.y) for lm in normalized_landmarks])


def read_landmark_positions_3d(
results: Any,
) -> npt.NDArray[np.float32] | None:
if results.pose_landmarks is None:
if results.pose_landmarks is None or len(results.pose_landmarks) == 0:
return None
else:
landmarks = [results.pose_world_landmarks.landmark[lm] for lm in mp.solutions.pose.PoseLandmark]
pose_landmarks = results.pose_landmarks[0]
landmarks = [pose_landmarks[lm] for lm in mp_pose.PoseLandmark]
return np.array([(lm.x, lm.y, lm.z) for lm in landmarks])


Expand Down Expand Up @@ -144,7 +160,7 @@ def stream_bgr(self) -> Iterator[VideoFrame]:
yield VideoFrame(data=bgr, time=time_ms * 1e-3, idx=idx)


def get_downloaded_path(dataset_dir: Path, video_name: str) -> str:
def get_downloaded_video_path(dataset_dir: Path, video_name: str) -> str:
video_file_name = f"{video_name}.mp4"
destination_path = dataset_dir / video_file_name
if destination_path.exists():
Expand All @@ -155,12 +171,30 @@ def get_downloaded_path(dataset_dir: Path, video_name: str) -> str:

logging.info("Downloading video from %s to %s", source_path, destination_path)
os.makedirs(dataset_dir.absolute(), exist_ok=True)
with requests.get(source_path, stream=True) as req:
download(source_path, destination_path)
return str(destination_path)


def get_downloaded_model_path(model_dir: Path, model_name: str) -> str:
model_file_name = f"{model_name}.task"
destination_path = model_dir / model_file_name
if destination_path.exists():
logging.info("%s already exists. No need to download", destination_path)
return str(destination_path)

model_url = MODEL_URL_TEMPLATE.format(model_name=model_name)
logging.info("Downloading model from %s to %s", model_url, destination_path)
download(model_url, destination_path)
return str(destination_path)


def download(url: str, destination_path: Path) -> None:
os.makedirs(destination_path.parent, exist_ok=True)
with requests.get(url, stream=True) as req:
req.raise_for_status()
with open(destination_path, "wb") as f:
for chunk in req.iter_content(chunk_size=8192):
f.write(chunk)
return str(destination_path)


def main() -> None:
Expand All @@ -179,6 +213,15 @@ def main() -> None:
parser.add_argument("--dataset-dir", type=Path, default=DATASET_DIR, help="Directory to save example videos to.")
parser.add_argument("--video-path", type=str, default="", help="Full path to video to run on. Overrides `--video`.")
parser.add_argument("--no-segment", action="store_true", help="Don't run person segmentation.")
parser.add_argument(
"--model",
type=str,
default="heavy",
choices=["lite", "full", "heavy"],
help="The mediapipe model to use (see https://developers.google.com/mediapipe/solutions/vision/pose_landmarker).",
)
parser.add_argument("--model-dir", type=Path, default=MODEL_DIR, help="Directory to save downloaded model to.")
parser.add_argument("--model-path", type=str, default="", help="Full path of mediapipe model. Overrides `--model`.")
parser.add_argument(
"--max-frame",
type=int,
Expand Down Expand Up @@ -206,9 +249,13 @@ def main() -> None:

video_path = args.video_path # type: str
if not video_path:
video_path = get_downloaded_path(args.dataset_dir, args.video)
video_path = get_downloaded_video_path(args.dataset_dir, args.video)

model_path = args.model_path # type: str
if not args.model_path:
model_path = get_downloaded_model_path(args.model_dir, args.model)

track_pose(video_path, segment=not args.no_segment, max_frame_count=args.max_frame)
track_pose(video_path, model_path, segment=not args.no_segment, max_frame_count=args.max_frame)

rr.script_teardown(args)

Expand Down
3 changes: 1 addition & 2 deletions examples/python/human_pose_tracking/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# no 3.12 version yet (https://pypi.org/project/mediapipe/)
# 0.10.10 no longer supports the legacy Pose model: https://github.com/rerun-io/rerun/issues/5859
mediapipe==0.10.9 ; python_version <= '3.11'
mediapipe==0.10.11 ; python_version <= '3.11'

numpy
opencv-python>4.6 # Avoid opencv-4.6 since it rotates images incorrectly (https://github.com/opencv/opencv/issues/22088)
Expand Down

0 comments on commit a8cf9ec

Please sign in to comment.