support custom video tracking demo with local gd1.0 model
rentainhe committed Sep 5, 2024
1 parent 834de44 commit 379e35c
Showing 2 changed files with 226 additions and 0 deletions.
6 changes: 6 additions & 0 deletions README.md
@@ -200,6 +200,12 @@ Users can upload their own video file (e.g. `assets/hippopotamus.mp4`) and speci
python grounded_sam2_tracking_demo_custom_video_input_gd1.0_hf_model.py
```

If it is not convenient for you to use the Hugging Face demo, you can also run the tracking demo with a local Grounding DINO model using the following script:

```bash
python grounded_sam2_tracking_demo_custom_video_input_gd1.0_local_model.py
```
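
The input video, text prompt, detection thresholds, and local Grounding DINO config/checkpoint paths are set as constants at the top of this script. As a minimal sketch, these are the values you would typically edit for your own data (constant names taken from the script added in this commit):

```python
# tunable constants at the top of
# grounded_sam2_tracking_demo_custom_video_input_gd1.0_local_model.py
GROUNDING_DINO_CONFIG = "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swint_ogc.pth"
BOX_THRESHOLD = 0.35    # minimum box confidence kept from Grounding DINO
TEXT_THRESHOLD = 0.25   # minimum text-token confidence kept
VIDEO_PATH = "./assets/hippopotamus.mp4"  # replace with your own video file
TEXT_PROMPT = "hippopotamus."             # your text prompt (note the trailing ".")
OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
PROMPT_TYPE_FOR_VIDEO = "box"             # choose from ["point", "box", "mask"]
```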

### Grounded SAM 2 Video Object Tracking Demo with Custom Video Input (with Grounding DINO 1.5 & 1.6)

Users can upload their own video file (e.g. `assets/hippopotamus.mp4`) and specify their custom text prompts for grounding and tracking with Grounding DINO 1.5 and SAM 2 by using the following scripts:
220 changes: 220 additions & 0 deletions grounded_sam2_tracking_demo_custom_video_input_gd1.0_local_model.py
@@ -0,0 +1,220 @@
import os
import cv2
import torch
import numpy as np
import supervision as sv
from torchvision.ops import box_convert
from pathlib import Path
from tqdm import tqdm
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor, build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
from utils.track_utils import sample_points_from_masks
from utils.video_utils import create_video_from_images

"""
Hyperparameters for Grounding and Tracking
"""
GROUNDING_DINO_CONFIG = "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swint_ogc.pth"
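# NOTE: the config and checkpoint paths above assume the Grounding DINO Swin-T weights have already been downloaded into gdino_checkpoints/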
BOX_THRESHOLD = 0.35
TEXT_THRESHOLD = 0.25
VIDEO_PATH = "./assets/hippopotamus.mp4"
TEXT_PROMPT = "hippopotamus."
OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

"""
Step 1: Environment settings and model initialization for Grounding DINO and SAM 2
"""
# build grounding dino model from local path
grounding_model = load_model(
    model_config_path=GROUNDING_DINO_CONFIG,
    model_checkpoint_path=GROUNDING_DINO_CHECKPOINT,
    device=DEVICE
)


# init sam image predictor and video predictor model
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"

video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
image_predictor = SAM2ImagePredictor(sam2_image_model)


"""
Custom video input directly using video files
"""
video_info = sv.VideoInfo.from_video_path(VIDEO_PATH) # get video info
print(video_info)
frame_generator = sv.get_video_frames_generator(VIDEO_PATH, stride=1, start=0, end=None)

# saving video to frames
source_frames = Path(SOURCE_VIDEO_FRAME_DIR)
source_frames.mkdir(parents=True, exist_ok=True)

with sv.ImageSink(
    target_dir_path=source_frames,
    overwrite=True,
    image_name_pattern="{:05d}.jpg"
) as sink:
    for frame in tqdm(frame_generator, desc="Saving Video Frames"):
        sink.save_image(frame)

# scan all the JPEG frame names in this directory
frame_names = [
p for p in os.listdir(SOURCE_VIDEO_FRAME_DIR)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
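# sort frames by their numeric index so they are processed in temporal order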
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

# init video predictor state
inference_state = video_predictor.init_state(video_path=SOURCE_VIDEO_FRAME_DIR)

ann_frame_idx = 0 # the frame index we interact with
"""
Step 2: Prompt the local Grounding DINO model for box coordinates
"""

# prompt grounding dino to get the box coordinates on specific frame
img_path = os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[ann_frame_idx])
image_source, image = load_image(img_path)

boxes, confidences, labels = predict(
    model=grounding_model,
    image=image,
    caption=TEXT_PROMPT,
    box_threshold=BOX_THRESHOLD,
    text_threshold=TEXT_THRESHOLD,
)

# process the box prompt for SAM 2
h, w, _ = image_source.shape
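# Grounding DINO returns boxes normalized to [0, 1] in cxcywh format; scale them to pixel coordinates and convert to xyxy for SAM 2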
boxes = boxes * torch.Tensor([w, h, w, h])
input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
confidences = confidences.numpy().tolist()
class_names = labels

print(input_boxes)

# prompt SAM image predictor to get the mask for the object
image_predictor.set_image(image_source)

# process the detection results
OBJECTS = class_names

print(OBJECTS)

# FIXME: figure out how this autocast setting influences the G-DINO model
torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()

if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# prompt SAM 2 image predictor to get the mask for the object
masks, scores, logits = image_predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_boxes,
    multimask_output=False,
)
# convert the mask shape to (n, H, W)
if masks.ndim == 4:
    masks = masks.squeeze(1)

"""
Step 3: Register each object's prompt with the video predictor via a separate add_new_points_or_box (or add_new_mask) call
"""

assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only supports point/box/mask prompts"

# If you are using point prompts, we uniformly sample positive points based on the mask
if PROMPT_TYPE_FOR_VIDEO == "point":
    # sample the positive points from the mask for each object
    all_sample_points = sample_points_from_masks(masks=masks, num_points=10)

    for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
        labels = np.ones((points.shape[0]), dtype=np.int32)
        _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
            inference_state=inference_state,
            frame_idx=ann_frame_idx,
            obj_id=object_id,
            points=points,
            labels=labels,
        )
# Using box prompts
elif PROMPT_TYPE_FOR_VIDEO == "box":
    for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
        _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
            inference_state=inference_state,
            frame_idx=ann_frame_idx,
            obj_id=object_id,
            box=box,
        )
# Using mask prompts is a more straightforward way
elif PROMPT_TYPE_FOR_VIDEO == "mask":
    for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
        labels = np.ones((1), dtype=np.int32)
        _, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
            inference_state=inference_state,
            frame_idx=ann_frame_idx,
            obj_id=object_id,
            mask=mask
        )
else:
    raise NotImplementedError("SAM 2 video predictor only supports point/box/mask prompts")


"""
Step 4: Propagate the video predictor to get the segmentation results for each frame
"""
video_segments = {} # video_segments contains the per-frame segmentation results
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }

"""
Step 5: Visualize the segmentation results across the video and save them
"""

if not os.path.exists(SAVE_TRACKING_RESULTS_DIR):
    os.makedirs(SAVE_TRACKING_RESULTS_DIR)

ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}

for frame_idx, segments in video_segments.items():
    img = cv2.imread(os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[frame_idx]))

    object_ids = list(segments.keys())
    masks = list(segments.values())
    masks = np.concatenate(masks, axis=0)

    detections = sv.Detections(
        xyxy=sv.mask_to_xyxy(masks),  # (n, 4)
        mask=masks,  # (n, h, w)
        class_id=np.array(object_ids, dtype=np.int32),
    )
    box_annotator = sv.BoxAnnotator()
    annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
    label_annotator = sv.LabelAnnotator()
    annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
    mask_annotator = sv.MaskAnnotator()
    annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
    cv2.imwrite(os.path.join(SAVE_TRACKING_RESULTS_DIR, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)


"""
Step 6: Convert the annotated frames to video
"""

create_video_from_images(SAVE_TRACKING_RESULTS_DIR, OUTPUT_VIDEO_PATH)
