Skip to content

Commit

Permalink
add grounding-dino
Browse files Browse the repository at this point in the history
  • Loading branch information
yamy-cheng committed Apr 26, 2023
1 parent 4ea9d59 commit 4e8320f
Show file tree
Hide file tree
Showing 9 changed files with 394 additions and 236 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,3 @@ assets/*gif
*.pyc
debug
cym_utils
# tutorial
126 changes: 78 additions & 48 deletions SegTracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np
import torch
from tool.segmentor import Segmentor
from tool.detector import Detector

import cv2
import os
from PIL import Image
Expand All @@ -27,25 +29,26 @@


class SegTracker():
def __init__(self,segtracker_args, sam_args,aot_args) -> None:
def __init__(self,segtracker_args, sam_args, aot_args) -> None:
"""
Initialize SAM and AOT.
"""
self.sam = Segmentor(sam_args)
self.tracker = get_aot(aot_args)
self.detector = Detector(self.sam.device)
self.sam_gap = segtracker_args['sam_gap']
self.min_area = segtracker_args['min_area']
self.max_obj_num = segtracker_args['max_obj_num']
self.min_new_obj_iou = segtracker_args['min_new_obj_iou']
self.reference_objs_list = []
self.object_idx = 1
self.origin_merged_mask = None # init with 0 or segment-everthing
self.refined_merged_mask = None # interactively refine by user
self.first_frame_mask = None

# debug
self.everything_points = []
self.everything_labels = []

def seg(self,frame):
'''
Arguments:
Expand Down Expand Up @@ -82,9 +85,17 @@ def seg(self,frame):
self.origin_merged_mask[self.origin_merged_mask==id] = self.object_idx
self.object_idx += 1

self.refined_merged_mask = self.origin_merged_mask
self.first_frame_mask = self.origin_merged_mask
return self.origin_merged_mask


def update_origin_merged_mask(self, updated_merged_mask):
self.origin_merged_mask = updated_merged_mask
self.object_idx += 1

def reset_origin_merged_mask(self):
self.origin_merged_mask = None
self.object_idx = 1

def add_reference(self,frame,mask,frame_step=0):
'''
Add objects in a mask for tracking.
Expand Down Expand Up @@ -155,15 +166,15 @@ def seg_acc_bbox(self, origin_frame: np.ndarray, bbox: np.ndarray,):

# get interactive_mask
interactive_mask = self.sam.segment_with_box(origin_frame, bbox)[0]
self.refined_merged_mask = self.add_mask(interactive_mask)
refined_merged_mask = self.add_mask(interactive_mask)

# draw mask
masked_frame = draw_mask(origin_frame.copy(), self.refined_merged_mask)
masked_frame = draw_mask(origin_frame.copy(), refined_merged_mask)

# draw bbox
masked_frame = cv2.rectangle(masked_frame, bbox[0], bbox[1], (0, 0, 255))

return self.refined_merged_mask, masked_frame
return refined_merged_mask, masked_frame

def refine_first_frame_click(self, origin_frame: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
'''
Expand All @@ -173,10 +184,10 @@ def refine_first_frame_click(self, origin_frame: np.ndarray, points:np.ndarray,
# get interactive_mask
interactive_mask, logit, outline = self.sam.segment_with_click(origin_frame, points, labels, multimask)

self.refined_merged_mask = self.add_mask(interactive_mask)
refined_merged_mask = self.add_mask(interactive_mask)

# draw mask
masked_frame = draw_mask(origin_frame.copy(), self.refined_merged_mask)
masked_frame = draw_mask(origin_frame.copy(), refined_merged_mask)

# draw points
# self.everything_labels = np.array(self.everything_labels).astype(np.int64)
Expand All @@ -188,60 +199,79 @@ def refine_first_frame_click(self, origin_frame: np.ndarray, points:np.ndarray,
# draw outline
masked_frame = np.where(outline > 0, outline, masked_frame)

return self.refined_merged_mask, masked_frame
return refined_merged_mask, masked_frame

def add_mask(self, interactive_mask):

def add_mask(self, interactive_mask, cover_origin_objects=True, single_object=True):
# if cover_origin_objects == Ture: interactive_mask will cover original object
# if single_object == True: added mask is belong to single object
if not cover_origin_objects:
empty_mask = np.where(self.origin_merged_mask == 0, 1, 0)
interactive_mask = interactive_mask * empty_mask

if self.origin_merged_mask is None:
self.origin_merged_mask = np.zeros(interactive_mask.shape,dtype=np.uint8)

refined_merged_mask = self.origin_merged_mask.copy()
refined_merged_mask[interactive_mask > 0] = self.object_idx

if not single_object:
self.object_idx += 1

return refined_merged_mask

def detect_and_seg(self, origin_frame, grounding_caption, box_threshold, text_threshold):

# get annotated_frame and boxes
annotated_frame, boxes = self.detector.run_grounding(origin_frame, grounding_caption, box_threshold, text_threshold)
for i in range(len(boxes)):
bbox = boxes[i]
interactive_mask = self.sam.segment_with_box(origin_frame, bbox)[0]
refined_merged_mask = self.add_mask(interactive_mask)
self.update_origin_merged_mask(refined_merged_mask)

# reset origin_mask
self.reset_origin_merged_mask()

return refined_merged_mask, annotated_frame

if __name__ == '__main__':
from model_args import segtracker_args,sam_args,aot_args

# ------------ draw point test --------------------------
Seg_Tracker = SegTracker(segtracker_args, sam_args, aot_args)
# Seg_Tracker.restart_tracker()

# origin_frame = cv2.imread('/data2/cym/Seg_Tra_any/Segment-and-Track-Anything/debug/point.png')
# origin_frame = cv2.cvtColor(origin_frame, cv2.COLOR_BGR2RGB)

# merged_mask = Seg_Tracker.seg(origin_frame)
# cv2.imwrite('./debug/merged_mask.png', -1)

# # two positive point
# point = np.array([[250, 370], [300, 420], [480, 150]])
# label = np.array([1, 0, 1])

# prompt = {
# "prompt_type":["click"],
# "input_point":point,
# "input_label":label,
# "multimask_output":"True",
# }

# predicted_mask, masked_frame = Seg_Tracker.refine_first_frame_click(
# origin_frame=origin_frame,
# points=np.array(prompt["input_point"]),
# labels=np.array(prompt["input_label"]),
# multimask=prompt["multimask_output"],
# )

# masked_frame = Image.fromarray(masked_frame)
# masked_frame.save('./debug/masked_frame.png')

# ------------------ detect test ----------------------

Seg_Tracker.init_detector()
origin_frame = cv2.imread('/data2/cym/Seg_Tra_any/Segment-and-Track-Anything/debug/point.png')
origin_frame = cv2.cvtColor(origin_frame, cv2.COLOR_BGR2RGB)
grounding_caption = "swan.water"
box_threshold = 0.25
text_threshold = 0.25

merged_mask = Seg_Tracker.seg(origin_frame)
cv2.imwrite('./debug/merged_mask.png', -1)

# one positive point
# point = np.array([[300, 420]])
# label = np.array([0])

# two positive point
point = np.array([[250, 370], [300, 420], [480, 150]])
label = np.array([1, 0, 1])

prompt = {
"prompt_type":["click"],
"input_point":point,
"input_label":label,
"multimask_output":"True",
}

predicted_mask, masked_frame = Seg_Tracker.refine_first_frame_click(
origin_frame=origin_frame,
points=np.array(prompt["input_point"]),
labels=np.array(prompt["input_label"]),
multimask=prompt["multimask_output"],
)

masked_frame = Image.fromarray(masked_frame)
masked_frame.save('./debug/masked_frame.png')
predicted_mask, annotated_frame = Seg_Tracker.detect_and_seg(origin_frame, grounding_caption, box_threshold, text_threshold)
masked_frame = draw_mask(annotated_frame, predicted_mask)
origin_frame = cv2.cvtColor(origin_frame, cv2.COLOR_RGB2BGR)

cv2.imwrite('./debug/masked_frame.png', masked_frame)
cv2.imwrite('./debug/x.png', annotated_frame)
Loading

0 comments on commit 4e8320f

Please sign in to comment.