forked from IDEA-Research/Grounded-SAM-2
Commit: support custom video tracking demo with local gd1.0 model
Showing 2 changed files with 226 additions and 0 deletions.
grounded_sam2_tracking_demo_custom_video_input_gd1.0_local_model.py (220 additions, 0 deletions)
@@ -0,0 +1,220 @@
import os
import cv2
import torch
import numpy as np
import supervision as sv
from torchvision.ops import box_convert
from pathlib import Path
from tqdm import tqdm
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor, build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
from utils.track_utils import sample_points_from_masks
from utils.video_utils import create_video_from_images

""" | ||
Hyperparam for Ground and Tracking | ||
""" | ||
GROUNDING_DINO_CONFIG = "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py" | ||
GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swint_ogc.pth" | ||
BOX_THRESHOLD = 0.35 | ||
TEXT_THRESHOLD = 0.25 | ||
VIDEO_PATH = "./assets/hippopotamus.mp4" | ||
TEXT_PROMPT = "hippopotamus." | ||
OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4" | ||
SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames" | ||
SAVE_TRACKING_RESULTS_DIR = "./tracking_results" | ||
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"] | ||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | ||
|
||
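# Note: BOX_THRESHOLD and TEXT_THRESHOLD filter Grounding DINO detections by box
# confidence and phrase-token confidence respectively; lower values yield more
# (and noisier) boxes. PROMPT_TYPE_FOR_VIDEO controls how the first-frame
# detections are handed to the SAM 2 video predictor: points sampled from the
# masks, the raw boxes, or the masks themselves.
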
""" | ||
Step 1: Environment settings and model initialization for Grounding DINO and SAM 2 | ||
""" | ||
# build grounding dino model from local path | ||
grounding_model = load_model( | ||
model_config_path=GROUNDING_DINO_CONFIG, | ||
model_checkpoint_path=GROUNDING_DINO_CHECKPOINT, | ||
device=DEVICE | ||
) | ||
|
||
|
||
# init sam image predictor and video predictor model
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"

video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
image_predictor = SAM2ImagePredictor(sam2_image_model)


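# Note: the image predictor is only used to turn the first-frame boxes into masks,
# while the video predictor carries those prompts forward through the rest of the
# video during propagation (Step 4).
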
""" | ||
Custom video input directly using video files | ||
""" | ||
video_info = sv.VideoInfo.from_video_path(VIDEO_PATH) # get video info | ||
print(video_info) | ||
frame_generator = sv.get_video_frames_generator(VIDEO_PATH, stride=1, start=0, end=None) | ||
|
||
# saving video to frames | ||
source_frames = Path(SOURCE_VIDEO_FRAME_DIR) | ||
source_frames.mkdir(parents=True, exist_ok=True) | ||
|
||
with sv.ImageSink( | ||
target_dir_path=source_frames, | ||
overwrite=True, | ||
image_name_pattern="{:05d}.jpg" | ||
) as sink: | ||
for frame in tqdm(frame_generator, desc="Saving Video Frames"): | ||
sink.save_image(frame) | ||
|
||
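# Note: SAM 2's video predictor is initialized from a directory of per-frame
# JPEGs (see init_state below), so the input video is first dumped to
# zero-padded, index-named frames that can be sorted back into order.
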
# scan all the JPEG frame names in this directory
frame_names = [
    p for p in os.listdir(SOURCE_VIDEO_FRAME_DIR)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

# init video predictor state
inference_state = video_predictor.init_state(video_path=SOURCE_VIDEO_FRAME_DIR)

ann_frame_idx = 0  # the frame index we interact with
"""
Step 2: Prompt Grounding DINO with the local model for box coordinates
"""

# prompt grounding dino to get the box coordinates on specific frame
img_path = os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[ann_frame_idx])
image_source, image = load_image(img_path)

boxes, confidences, labels = predict(
    model=grounding_model,
    image=image,
    caption=TEXT_PROMPT,
    box_threshold=BOX_THRESHOLD,
    text_threshold=TEXT_THRESHOLD,
)

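# Note: predict() returns boxes in normalized (cx, cy, w, h) coordinates relative
# to the image size, so they are rescaled to pixels and converted to
# (x1, y1, x2, y2) below before being used as SAM 2 box prompts.
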
# process the box prompt for SAM 2
h, w, _ = image_source.shape
boxes = boxes * torch.Tensor([w, h, w, h])
input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
confidences = confidences.numpy().tolist()
class_names = labels

print(input_boxes)

# prompt SAM image predictor to get the mask for the object
image_predictor.set_image(image_source)

# process the detection results
OBJECTS = class_names

print(OBJECTS)

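# Note: the autocast context entered below makes the following SAM 2 forward
# passes run in bfloat16 mixed precision; the TF32 flags additionally speed up
# float32 matmuls/convolutions on Ampere or newer GPUs.
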
# FIXME: figure out how this influences the G-DINO model
torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()

if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# prompt SAM 2 image predictor to get the mask for the object
masks, scores, logits = image_predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_boxes,
    multimask_output=False,
)
# convert the mask shape to (n, H, W)
if masks.ndim == 4:
    masks = masks.squeeze(1)

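# Note: with box prompts and multimask_output=False the image predictor returns
# one mask per box; the ndim check above normalizes the output to shape
# (n, H, W) so each row corresponds to one detected object.
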
""" | ||
Step 3: Register each object's positive points to video predictor with seperate add_new_points call | ||
""" | ||
|
||
assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only support point/box/mask prompt" | ||
|
||
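# Note: object ids are assigned starting from 1 in every branch below; the same
# ids come back from propagate_in_video and are mapped to label text via
# ID_TO_OBJECTS in Step 5, so the numbering must stay consistent.
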
# If you are using point prompts, we uniformly sample positive points based on the mask
if PROMPT_TYPE_FOR_VIDEO == "point":
    # sample the positive points from mask for each object
    all_sample_points = sample_points_from_masks(masks=masks, num_points=10)

    for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
        labels = np.ones((points.shape[0]), dtype=np.int32)
        _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
            inference_state=inference_state,
            frame_idx=ann_frame_idx,
            obj_id=object_id,
            points=points,
            labels=labels,
        )
# Using box prompt
elif PROMPT_TYPE_FOR_VIDEO == "box":
    for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
        _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
            inference_state=inference_state,
            frame_idx=ann_frame_idx,
            obj_id=object_id,
            box=box,
        )
# Using mask prompt is a more straightforward way
elif PROMPT_TYPE_FOR_VIDEO == "mask":
    for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
        labels = np.ones((1), dtype=np.int32)
        _, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
            inference_state=inference_state,
            frame_idx=ann_frame_idx,
            obj_id=object_id,
            mask=mask
        )
else:
    raise NotImplementedError("SAM 2 video predictor only supports point/box/mask prompts")


""" | ||
Step 4: Propagate the video predictor to get the segmentation results for each frame | ||
""" | ||
video_segments = {} # video_segments contains the per-frame segmentation results | ||
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state): | ||
video_segments[out_frame_idx] = { | ||
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() | ||
for i, out_obj_id in enumerate(out_obj_ids) | ||
} | ||
|
||
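# Note: video_segments maps frame_idx -> {obj_id: boolean mask array}; each mask
# keeps a leading singleton dimension from the predictor's logits, which is why
# the masks for one frame can simply be concatenated along axis 0 in Step 5.
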
""" | ||
Step 5: Visualize the segment results across the video and save them | ||
""" | ||
|
||
if not os.path.exists(SAVE_TRACKING_RESULTS_DIR): | ||
os.makedirs(SAVE_TRACKING_RESULTS_DIR) | ||
|
||
ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)} | ||
|
||
for frame_idx, segments in video_segments.items():
    img = cv2.imread(os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[frame_idx]))

    object_ids = list(segments.keys())
    masks = list(segments.values())
    masks = np.concatenate(masks, axis=0)

    detections = sv.Detections(
        xyxy=sv.mask_to_xyxy(masks),  # (n, 4)
        mask=masks,  # (n, h, w)
        class_id=np.array(object_ids, dtype=np.int32),
    )
    box_annotator = sv.BoxAnnotator()
    annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
    label_annotator = sv.LabelAnnotator()
    annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
    mask_annotator = sv.MaskAnnotator()
    annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
    cv2.imwrite(os.path.join(SAVE_TRACKING_RESULTS_DIR, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)


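# Note: supervision's annotators pick colors by class_id by default, so reusing
# the object id as class_id keeps each tracked object's color consistent across
# all annotated frames.
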
""" | ||
Step 6: Convert the annotated frames to video | ||
""" | ||
|
||
create_video_from_images(SAVE_TRACKING_RESULTS_DIR, OUTPUT_VIDEO_PATH) |
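
# Usage sketch (assuming the SAM 2 and Grounding DINO checkpoints configured above
# have been downloaded and ./assets/hippopotamus.mp4 exists):
#   python grounded_sam2_tracking_demo_custom_video_input_gd1.0_local_model.py
# Annotated frames are written to ./tracking_results and the stitched video to
# ./hippopotamus_tracking_demo.mp4.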