Commit
neoguojing committed Aug 8, 2024
1 parent d461382 commit a6e8d96
Showing 2 changed files with 392 additions and 1 deletion.
8 changes: 7 additions & 1 deletion .gitignore
@@ -1,4 +1,10 @@
.ipynb_checkpoints/*
__pycache__/
neo_albert_model/*
alber_infence.csv
detectron/demo/yolov8n-cls.pt
detectron/demo/yolov8n-pose.pt
detectron/demo/yolov8n-seg.pt
detectron/demo/yolov8n.pt
detectron/demo/0_0.jpg
detectron/demo/sam_vit_b_01ec64.pth
385 changes: 385 additions & 0 deletions detectron/demo/yolo_predictor.py
@@ -0,0 +1,385 @@

from pytorch_model_factory import TorchModelFactory
from PIL import Image
from typing import Any, Dict
import sys
sys.path.append("..")
from detectron2.structures import Instances
import torch
import threading
import gc
from ultralytics import YOLO, solutions
from ultralytics.utils.plotting import Annotator, colors
from detectron2.config import get_cfg
import cv2

class YOLOPredictor:
_instances = {}
_lock = threading.Lock()

def __new__(cls, cfg=None):
if cfg is None:
raise ValueError("Configuration must be provided")

with cls._lock:
if cfg.TASK_TYPE not in cls._instances:
cls._instances[cfg.TASK_TYPE] = super(YOLOPredictor, cls).__new__(cls)
cls._instances[cfg.TASK_TYPE]._initialize(cfg)
return cls._instances[cfg.TASK_TYPE]
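
    # One predictor instance is cached per TASK_TYPE (a per-task singleton),
    # so constructing YOLOPredictor twice with the same cfg.TASK_TYPE returns
    # the same object and skips reloading the weights.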

    def _initialize(self, cfg):
        self.cfg = cfg
        self.crop_enable = False
        self.dist_obj = None  # only created for the "detect" task

if cfg.TASK_TYPE == "classification":
print("classification")
self.model = YOLO("yolov8n-cls.pt")
elif cfg.TASK_TYPE == "detect":
print("detect")
self.model = YOLO("yolov8n.pt")
self.dist_obj = solutions.DistanceCalculation(names=self.model.names, view_img=False)
self.crop_enable = True
elif cfg.TASK_TYPE == "pose":
print("pose")
self.model = YOLO("yolov8n-pose.pt")
elif cfg.TASK_TYPE == "obb":
print("obb")
self.model = YOLO("yolov8n-obb.pt")
elif cfg.TASK_TYPE == "instance":
print("instance")
self.model = YOLO("yolov8n-seg.pt")
else:
print("detect")
self.model = YOLO("yolov8n.pt")

@classmethod
def get_instance(cls, task_type):
with cls._lock:
return cls._instances.get(task_type)

    def set_dis_obj(self, a, b):
        # Preselect two track ids for the distance calculator.
        if self.dist_obj is not None:
            self.dist_obj.selected_boxes = {a: None, b: None}
            print("set_dis_obj", self.dist_obj.selected_boxes)

# def __new__(cls, cfg=None):
# if cls._instance is None:
# with cls._lock:
# if cls._instance is None:
# cls._instance = super(YOLOPredictor, cls).__new__(cls)
# cls._instance._initialize(cfg)
# return cls._instance

# def _initialize(self, cfg=None):
# self.model = TorchModelFactory.create_yolo_detect_model()

    def __call__(self, image):
        """
        Args:
            image (PIL.Image or list[PIL.Image]): input image(s).
        Returns:
            (dict, list): detectron2 ``Instances`` keyed as ``instances_{i}``,
            plus the annotated images as PIL images.
        """

if self.model is None:
return None

if not isinstance(image,list):
image = [image]

predictions = self.model(image)
return self._post_processor(predictions)
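
    # A minimal single-image usage sketch, mirroring the commented-out
    # __main__ block at the bottom of this file ("./test/test.png" is a
    # hypothetical local path):
    #
    #   cfg = get_cfg()
    #   cfg.TASK_TYPE = "detect"
    #   predictor = YOLOPredictor(cfg)
    #   result, pil_images = predictor(Image.open("./test/test.png"))
    #   print(result["instances_0"].pred_boxes)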

    def _video_processor(self, video_path, callback=None):

        cap = cv2.VideoCapture(video_path)
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30  # fall back when FPS metadata is missing
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        output_path = get_file_path_without_extension(video_path) + "_after_inference.mp4"
        print("track:", output_path)
        video = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
        # Loop through the video frames
        frame_count = 0
        output = None
        annotated_frame = None
        while cap.isOpened():
            # Read a frame from the video
            success, frame = cap.read()
            frame_count += 1
            if success:

                annotated_frame, tracks = callback(frame)
                if annotated_frame is None:
                    continue

                if tracks is not None:
                    output, _ = self._post_processor(tracks)
                # Display the annotated frame
                # cv2.imshow("YOLOv8 Tracking", annotated_frame)
                video.write(annotated_frame)
                if frame_count % fps == 0:
                    yield output, annotated_frame, None, None

                # Break the loop if 'q' is pressed
                if cv2.waitKey(1) & 0xFF == ord("q"):
                    break
            else:
                # Break the loop if the end of the video is reached
                break

        # Release the video capture object and close the display window
        video.release()
        cap.release()
        # cv2.destroyAllWindows()
        yield output, annotated_frame, output_path, None
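
    # _video_processor is a generator: it yields (result, annotated_frame,
    # output_path, None) once every `fps` frames (roughly once per second of
    # video), with output_path filled in only on the final yield. A
    # hypothetical consumer of any of the video methods below:
    #
    #   for result, frame, path, _ in predictor.track("input.mp4"):
    #       if path is not None:
    #           print("annotated video written to", path)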

def track(self,video_path):
def do_track(frame):
# Run YOLOv8 tracking on the frame, persisting tracks between frames
tracks = self.model.track(frame, persist=True)

# Visualize the results on the frame
return tracks[0].plot(),tracks

yield from self._video_processor(video_path,do_track)

def track_with_seg(self,video_path):
def do_track(frame):
annotator = Annotator(frame, line_width=2)
results = self.model.track(frame, persist=True)
if results[0].boxes.id is not None and results[0].masks is not None:
masks = results[0].masks.xy
track_ids = results[0].boxes.id.int().cpu().tolist()

for mask, track_id in zip(masks, track_ids):
annotator.seg_bbox(mask=mask, mask_color=colors(track_id, True), track_label=str(track_id))
return frame,results

yield from self._video_processor(video_path,do_track)

def counting(self,video_path,region_points=None):
# Init Object Counter
counter = solutions.ObjectCounter(
view_img=False,
reg_pts=region_points,
classes_names=self.model.names,
draw_tracks=True,
line_thickness=2,
)

def do_count(frame):
tracks = self.model.track(frame, persist=True, show=False)
return counter.start_counting(frame, tracks),tracks

yield from self._video_processor(video_path,do_count)
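
    # region_points follows the ultralytics convention: a list of (x, y) pixel
    # coordinates, where two points define a counting line and four or more a
    # polygonal counting region. A sketch assuming 1280x720 frames:
    #
    #   line = [(20, 400), (1260, 400)]
    #   for result, frame, path, _ in predictor.counting("input.mp4", line):
    #       pass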

def crop(self,frame,results):
boxes = results[0].boxes.xyxy.cpu().tolist()
clss = results[0].boxes.cls.cpu().tolist()
annotator = Annotator(frame, line_width=2, example=self.model.names)
pil_images = []
if boxes is not None:
for box, cls in zip(boxes, clss):
annotator.box_label(box, color=colors(int(cls), True), label=self.model.names[int(cls)])
crop_obj = frame[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])]
pil_images.append(Image.fromarray(crop_obj[..., ::-1]))

return pil_images

def gym_monitor(self,video_path,pose_type="pushup"):
gym_object = solutions.AIGym(
line_thickness=2,
view_img=False,
pose_type=pose_type,
kpts_to_check=[6, 8, 10],
)
        def do_gym(frame):
            try:
                tracks = self.model.track(frame, persist=True, show=False, verbose=False)
                return gym_object.start_counting(frame, tracks, frame_count=20), tracks
            except TypeError as e:
                # Catch the TypeError and print the error message
                print(f"TypeError: {e}")
                # Alternatively, return a default value or take other appropriate action
                return None, None

yield from self._video_processor(video_path,do_gym)

def heatmap(self,video_path,classes_for_heatmap = None,region_points=None):
heatmap_obj = solutions.Heatmap(
colormap=cv2.COLORMAP_PARULA,
view_img=False,
shape="circle",
classes_names=self.model.names,
count_reg_pts=region_points,
)

        def do_draw(frame):
            try:
                tracks = self.model.track(frame, persist=True, show=False, classes=classes_for_heatmap)
                if tracks is None:
                    # Handle the missing-tracks case; raise here, or return a sentinel value instead
                    raise ValueError("No tracks found")

                return heatmap_obj.generate_heatmap(frame, tracks), tracks

            except AttributeError as e:
                # Catch the AttributeError and print the error message
                print(f"AttributeError: {e}")
                # Alternatively, return a default value or take other appropriate action
                return None, None

            except ValueError as e:
                # Catch the ValueError and print the error message
                print(f"ValueError: {e}")
                # Alternatively, return a default heatmap or some other value
                return None, None

yield from self._video_processor(video_path,do_draw)

def vision_eye(self,video_path,center_point=None):
import math

pixel_per_meter = 10
txt_color, txt_background, bbox_clr = ((0, 0, 0), (255, 255, 255), (255, 0, 255))

def vision_distance(frame):
annotator = Annotator(frame, line_width=2)

nonlocal center_point
if center_point is None:
height = frame.shape[0]
center_point = (0, height)

results = self.model.track(frame, persist=True)
boxes = results[0].boxes.xyxy.cpu()

if results[0].boxes.id is not None:
track_ids = results[0].boxes.id.int().cpu().tolist()

for box, track_id in zip(boxes, track_ids):
annotator.box_label(box, label=str(track_id), color=bbox_clr)
annotator.visioneye(box, center_point)

x1, y1 = int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2) # Bounding box centroid

distance = (math.sqrt((x1 - center_point[0]) ** 2 + (y1 - center_point[1]) ** 2)) / pixel_per_meter

text_size, _ = cv2.getTextSize(f"Distance: {distance:.2f} m", cv2.FONT_HERSHEY_SIMPLEX, 1.2, 3)
cv2.rectangle(frame, (x1, y1 - text_size[1] - 10), (x1 + text_size[0] + 10, y1), txt_background, -1)
cv2.putText(frame, f"Distance: {distance:.2f} m", (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 1.2, txt_color, 3)
return frame,results

yield from self._video_processor(video_path,vision_distance)
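
    # The distance drawn above is the Euclidean pixel distance from each box
    # centroid to center_point, divided by the hard-coded pixel_per_meter = 10;
    # for real-world metres that constant would need per-camera calibration:
    #
    #   distance_m = math.sqrt((x1 - cx) ** 2 + (y1 - cy) ** 2) / pixel_per_meter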

def speed(self,video_path):
line_pts = [(0, 360), (1280, 360)]
# Init speed-estimation obj
speed_obj = solutions.SpeedEstimator(
reg_pts=line_pts,
names=self.model.names,
view_img=False,
)

def do_speed_cal(frame):
tracks = self.model.track(frame, persist=True, show=False)
return speed_obj.estimate_speed(frame, tracks),tracks

yield from self._video_processor(video_path,do_speed_cal)

def distance(self,video_path):

def do_dist_cal(frame):
tracks = self.model.track(frame, persist=True, show=False, verbose=False)
print("tracks:",len(tracks),self.dist_obj.trk_ids)
return self.dist_obj.start_process(frame, tracks),tracks

yield from self._video_processor(video_path,do_dist_cal)
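
    # Note: distance() depends on self.dist_obj, which is created only for the
    # "detect" task, and expects two track ids to have been preselected via
    # set_dis_obj(a, b) before meaningful distances can be reported.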

    def queue_manager(self, video_path, queue_region=None):
        queue_obj = solutions.QueueManager(
            classes_names=self.model.names,
            reg_pts=queue_region,
            line_thickness=3,
            fontsize=1.0,
            region_color=(255, 144, 31),
        )

        # The inner callback must not shadow the QueueManager instance name,
        # otherwise queue_obj.process_queue would resolve to the function itself.
        def do_queue(frame):
            tracks = self.model.track(frame, show=False, persist=True, verbose=False)
            queue_obj.process_queue(frame, tracks)
            return frame, tracks

        yield from self._video_processor(video_path, do_queue)

def _post_processor(self, output):
# print("-------yolo------------\n", output)
pil_images = []

result: Dict[str, Instances] = {}

        # TODO: only a single image per call is currently supported
        for i, o in enumerate(output):
            im_bgr = o.plot()
            im_rgb = Image.fromarray(im_bgr[..., ::-1])
            pil_images.append(im_rgb)
            if self.crop_enable:
                pil_images += self.crop(o.orig_img, [o])  # crop against the current result, not always the first

inst_key = f"instances_{i}"
result[inst_key] = Instances(o.orig_shape)

if o.boxes is not None:
result[inst_key].pred_boxes = o.boxes.xywh
if o.boxes.id is not None:
result[inst_key].trk_ids = o.boxes.id.int().cpu().tolist()

if o.masks is not None:
result[inst_key].pred_masks = o.masks.xyn

if o.probs is not None:
result[inst_key].scores = o.probs.top5

if o.keypoints is not None:
result[inst_key].pred_keypoints = o.keypoints.xyn

if o.obb is not None:
result[inst_key].pred_obb = o.obb.xywhr

return result, pil_images
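
    # The returned dict maps "instances_{i}" to a detectron2 Instances object
    # whose fields depend on the task: pred_boxes (xywh) and trk_ids for
    # detection/tracking, pred_masks (normalized polygons) for segmentation,
    # scores (top-5 class indices) for classification, pred_keypoints
    # (normalized) for pose, and pred_obb (xywhr) for oriented boxes.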

    def release(self):
        # Delete the model object
        del self.model
        # Clear the GPU cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Manually trigger garbage collection
        gc.collect()

def get_file_path_without_extension(file_path):
    import os
    # Directory containing the file
    directory = os.path.dirname(file_path)

    # File name, including the extension
    filename_with_extension = os.path.basename(file_path)

    # Split the file name and the extension
    filename, extension = os.path.splitext(filename_with_extension)

    # Join the directory and the file name without its extension
    return os.path.join(directory, filename)
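
# A terser equivalent of the helper above (a sketch; os.path.splitext already
# preserves the directory part for ordinary paths):
#
#   def get_file_path_without_extension(file_path):
#       import os
#       return os.path.splitext(file_path)[0]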

# if __name__ == "__main__":
# cfg = get_cfg()
# cfg.TASK_TYPE = "detect"
# f = YOLOPredictor(cfg)
# # from PIL import Image
# # img = Image.open("./test/test.png")
# f.track("/home/neo/Videos/trafic.webm")

