Commit a6e8d96 (1 parent: d461382)
Showing 2 changed files with 392 additions and 1 deletion.
@@ -1,4 +1,10 @@
.ipynb_checkpoints/*
__pycache__/
neo_albert_model/*
alber_infence.csv
detectron/demo/yolov8n-cls.pt
detectron/demo/yolov8n-pose.pt
detectron/demo/yolov8n-seg.pt
detectron/demo/yolov8n.pt
detectron/demo/0_0.jpg
detectron/demo/sam_vit_b_01ec64.pth
@@ -0,0 +1,385 @@
from pytorch_model_factory import TorchModelFactory
from PIL import Image
from typing import Any, Dict
import sys
sys.path.append("..")
from detectron2.structures import Instances
import torch
import threading
import gc
from ultralytics import YOLO, solutions
from ultralytics.utils.plotting import Annotator, colors
from detectron2.config import get_cfg
import cv2

class YOLOPredictor:
    _instances = {}
    _lock = threading.Lock()

    def __new__(cls, cfg=None):
        if cfg is None:
            raise ValueError("Configuration must be provided")

        with cls._lock:
            if cfg.TASK_TYPE not in cls._instances:
                cls._instances[cfg.TASK_TYPE] = super(YOLOPredictor, cls).__new__(cls)
                cls._instances[cfg.TASK_TYPE]._initialize(cfg)
            return cls._instances[cfg.TASK_TYPE]

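    # Minimal usage sketch (not part of the commit): instances are keyed by
    # TASK_TYPE, so constructing the predictor twice for the same task returns
    # the same object:
    #
    #   cfg = get_cfg()
    #   cfg.TASK_TYPE = "detect"
    #   assert YOLOPredictor(cfg) is YOLOPredictor(cfg)
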
    def _initialize(self, cfg):
        self.cfg = cfg
        self.crop_enable = False
        # dist_obj is only created for the "detect" task; default to None so
        # set_dis_obj() and distance() can check it safely for other tasks
        self.dist_obj = None

        if cfg.TASK_TYPE == "classification":
            print("classification")
            self.model = YOLO("yolov8n-cls.pt")
        elif cfg.TASK_TYPE == "detect":
            print("detect")
            self.model = YOLO("yolov8n.pt")
            self.dist_obj = solutions.DistanceCalculation(names=self.model.names, view_img=False)
            self.crop_enable = True
        elif cfg.TASK_TYPE == "pose":
            print("pose")
            self.model = YOLO("yolov8n-pose.pt")
        elif cfg.TASK_TYPE == "obb":
            print("obb")
            self.model = YOLO("yolov8n-obb.pt")
        elif cfg.TASK_TYPE == "instance":
            print("instance")
            self.model = YOLO("yolov8n-seg.pt")
        else:
            print("detect")
            self.model = YOLO("yolov8n.pt")

    @classmethod
    def get_instance(cls, task_type):
        with cls._lock:
            return cls._instances.get(task_type)

    def set_dis_obj(self, a, b):
        # Select the two track ids between which distance() should measure
        if self.dist_obj is not None:
            self.dist_obj.selected_boxes = {a: None, b: None}
            print("set_dis_obj", self.dist_obj.selected_boxes)

    # Earlier single-instance implementation, kept for reference:
    # def __new__(cls, cfg=None):
    #     if cls._instance is None:
    #         with cls._lock:
    #             if cls._instance is None:
    #                 cls._instance = super(YOLOPredictor, cls).__new__(cls)
    #                 cls._instance._initialize(cfg)
    #     return cls._instance

    # def _initialize(self, cfg=None):
    #     self.model = TorchModelFactory.create_yolo_detect_model()

    def __call__(self, image):
        """
        Args:
            image (PIL.Image or list of PIL.Image): the input image(s).
        Returns:
            (dict, list): the post-processed predictions keyed per image and
            the annotated PIL images.
            See :doc:`/tutorials/models` for details about the format.
        """
        if self.model is None:
            return None

        if not isinstance(image, list):
            image = [image]

        predictions = self.model(image)
        return self._post_processor(predictions)

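    # Usage sketch (an assumption, not in this commit): __call__ accepts one
    # PIL image or a list and returns (instances_dict, annotated_pil_images):
    #
    #   predictor = YOLOPredictor(cfg)  # cfg.TASK_TYPE == "detect"
    #   result, images = predictor(Image.open("./test/test.png"))
    #   print(result["instances_0"].pred_boxes)
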
    def _video_processor(self, video_path, callback=None):
        cap = cv2.VideoCapture(video_path)
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        output_path = get_file_path_without_extension(video_path) + "_after_inference.mp4"
        print("track:", output_path)
        video = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
        # Loop through the video frames
        frame_count = 0
        output = None          # initialized so the final yield cannot hit an unbound name
        annotated_frame = None
        while cap.isOpened():
            # Read a frame from the video
            success, frame = cap.read()
            frame_count += 1
            if success:
                annotated_frame, tracks = callback(frame)
                if annotated_frame is None:
                    continue

                if tracks is not None:
                    output, _ = self._post_processor(tracks)
                # Write the annotated frame; yield intermediate results once per second
                # cv2.imshow("YOLOv8 Tracking", annotated_frame)
                video.write(annotated_frame)
                if frame_count % fps == 0:
                    yield output, annotated_frame, None, None

                # Break the loop if 'q' is pressed
                if cv2.waitKey(1) & 0xFF == ord("q"):
                    break
            else:
                # Break the loop if the end of the video is reached
                break

        # Release the video capture object and the writer
        video.release()
        cap.release()
        # cv2.destroyAllWindows()
        yield output, annotated_frame, output_path, None

    def track(self, video_path):
        def do_track(frame):
            # Run YOLOv8 tracking on the frame, persisting tracks between frames
            tracks = self.model.track(frame, persist=True)

            # Visualize the results on the frame
            return tracks[0].plot(), tracks

        yield from self._video_processor(video_path, do_track)

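    # Consumption sketch (an assumption): track() is a generator that yields
    # (result_dict, annotated_frame, output_path, None) roughly once per
    # second, with output_path set only on the final yield:
    #
    #   for output, annotated, saved_path, _ in predictor.track("sample.mp4"):
    #       if saved_path is not None:
    #           print("done:", saved_path)
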
    def track_with_seg(self, video_path):
        def do_track(frame):
            annotator = Annotator(frame, line_width=2)
            results = self.model.track(frame, persist=True)
            if results[0].boxes.id is not None and results[0].masks is not None:
                masks = results[0].masks.xy
                track_ids = results[0].boxes.id.int().cpu().tolist()

                for mask, track_id in zip(masks, track_ids):
                    annotator.seg_bbox(mask=mask, mask_color=colors(track_id, True), track_label=str(track_id))
            return frame, results

        yield from self._video_processor(video_path, do_track)

    def counting(self, video_path, region_points=None):
        # Init Object Counter
        counter = solutions.ObjectCounter(
            view_img=False,
            reg_pts=region_points,
            classes_names=self.model.names,
            draw_tracks=True,
            line_thickness=2,
        )

        def do_count(frame):
            tracks = self.model.track(frame, persist=True, show=False)
            return counter.start_counting(frame, tracks), tracks

        yield from self._video_processor(video_path, do_count)

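    # Region sketch (an assumption about ObjectCounter): reg_pts takes line or
    # polygon vertices in pixel coordinates, e.g. a band across a 1280x720 frame:
    #
    #   region = [(20, 400), (1260, 400), (1260, 440), (20, 440)]
    #   for output, frame, path, _ in predictor.counting("sample.mp4", region):
    #       ...
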
    def crop(self, frame, results):
        boxes = results[0].boxes.xyxy.cpu().tolist()
        clss = results[0].boxes.cls.cpu().tolist()
        annotator = Annotator(frame, line_width=2, example=self.model.names)
        pil_images = []
        if boxes is not None:
            for box, cls in zip(boxes, clss):
                annotator.box_label(box, color=colors(int(cls), True), label=self.model.names[int(cls)])
                crop_obj = frame[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])]
                # The frame is BGR (OpenCV); reverse the channel axis for PIL
                pil_images.append(Image.fromarray(crop_obj[..., ::-1]))

        return pil_images

def gym_monitor(self,video_path,pose_type="pushup"): | ||
gym_object = solutions.AIGym( | ||
line_thickness=2, | ||
view_img=False, | ||
pose_type=pose_type, | ||
kpts_to_check=[6, 8, 10], | ||
) | ||
def do_gym(frame): | ||
try: | ||
tracks = self.model.track(frame, persist=True, show=False,verbose=False) | ||
return gym_object.start_counting(frame, tracks,frame_count=20),tracks | ||
except TypeError as e: | ||
# 捕获 AttributeError 异常,并打印错误信息 | ||
print(f"TypeError: {e}") | ||
# 或者你可以选择返回一个默认值或者执行其他恰当的操作 | ||
return None,None | ||
|
||
yield from self._video_processor(video_path,do_gym) | ||
|
||
    def heatmap(self, video_path, classes_for_heatmap=None, region_points=None):
        heatmap_obj = solutions.Heatmap(
            colormap=cv2.COLORMAP_PARULA,
            view_img=False,
            shape="circle",
            classes_names=self.model.names,
            count_reg_pts=region_points,
        )

        def do_draw(frame):
            try:
                tracks = self.model.track(frame, persist=True, show=False, classes=classes_for_heatmap)
                if tracks is None:
                    # Handle the case where no tracks are returned; raise here
                    # (a default heatmap could be returned instead)
                    raise ValueError("No tracks found")

                return heatmap_obj.generate_heatmap(frame, tracks), tracks

            except AttributeError as e:
                # Catch the AttributeError and print the error message
                print(f"AttributeError: {e}")
                return None, None

            except ValueError as e:
                # Catch the ValueError and print the error message
                print(f"ValueError: {e}")
                return None, None

        yield from self._video_processor(video_path, do_draw)

    def vision_eye(self, video_path, center_point=None):
        import math

        pixel_per_meter = 10
        txt_color, txt_background, bbox_clr = ((0, 0, 0), (255, 255, 255), (255, 0, 255))

        def vision_distance(frame):
            annotator = Annotator(frame, line_width=2)

            nonlocal center_point
            if center_point is None:
                height = frame.shape[0]
                center_point = (0, height)

            results = self.model.track(frame, persist=True)
            boxes = results[0].boxes.xyxy.cpu()

            if results[0].boxes.id is not None:
                track_ids = results[0].boxes.id.int().cpu().tolist()

                for box, track_id in zip(boxes, track_ids):
                    annotator.box_label(box, label=str(track_id), color=bbox_clr)
                    annotator.visioneye(box, center_point)

                    x1, y1 = int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2)  # Bounding box centroid

                    distance = (math.sqrt((x1 - center_point[0]) ** 2 + (y1 - center_point[1]) ** 2)) / pixel_per_meter

                    text_size, _ = cv2.getTextSize(f"Distance: {distance:.2f} m", cv2.FONT_HERSHEY_SIMPLEX, 1.2, 3)
                    cv2.rectangle(frame, (x1, y1 - text_size[1] - 10), (x1 + text_size[0] + 10, y1), txt_background, -1)
                    cv2.putText(frame, f"Distance: {distance:.2f} m", (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 1.2, txt_color, 3)
            return frame, results

        yield from self._video_processor(video_path, vision_distance)

    def speed(self, video_path):
        line_pts = [(0, 360), (1280, 360)]
        # Init speed-estimation obj
        speed_obj = solutions.SpeedEstimator(
            reg_pts=line_pts,
            names=self.model.names,
            view_img=False,
        )

        def do_speed_cal(frame):
            tracks = self.model.track(frame, persist=True, show=False)
            return speed_obj.estimate_speed(frame, tracks), tracks

        yield from self._video_processor(video_path, do_speed_cal)

    def distance(self, video_path):
        if self.dist_obj is None:
            raise ValueError("distance() requires TASK_TYPE == 'detect'")

        def do_dist_cal(frame):
            tracks = self.model.track(frame, persist=True, show=False, verbose=False)
            print("tracks:", len(tracks), self.dist_obj.trk_ids)
            return self.dist_obj.start_process(frame, tracks), tracks

        yield from self._video_processor(video_path, do_dist_cal)

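    # Pairing sketch (an assumption about DistanceCalculation internals):
    # select two track ids via set_dis_obj() before iterating:
    #
    #   predictor.set_dis_obj(1, 2)
    #   for output, frame, path, _ in predictor.distance("sample.mp4"):
    #       ...
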
    def queue_manager(self, video_path, queue_region=None):
        queue_obj = solutions.QueueManager(
            classes_names=self.model.names,
            reg_pts=queue_region,
            line_thickness=3,
            fontsize=1.0,
            region_color=(255, 144, 31),
        )

        # The inner function must not be named "queue": it would shadow the
        # manager instance and break the process_queue() call below
        def do_queue(frame):
            tracks = self.model.track(frame, show=False, persist=True, verbose=False)
            queue_obj.process_queue(frame, tracks)
            return frame, tracks

        yield from self._video_processor(video_path, do_queue)

    def _post_processor(self, output):
        # print("-------yolo------------\n", output)
        pil_images = []

        result: Dict[str, Instances] = {}

        # TODO: only a single image is supported
        for i, o in enumerate(output):
            im_bgr = o.plot()
            im_rgb = Image.fromarray(im_bgr[..., ::-1])  # BGR -> RGB
            pil_images.append(im_rgb)
            if self.crop_enable:
                pil_images += self.crop(o.orig_img, output)

            inst_key = f"instances_{i}"
            result[inst_key] = Instances(o.orig_shape)

            if o.boxes is not None:
                result[inst_key].pred_boxes = o.boxes.xywh
                if o.boxes.id is not None:
                    result[inst_key].trk_ids = o.boxes.id.int().cpu().tolist()

            if o.masks is not None:
                result[inst_key].pred_masks = o.masks.xyn

            if o.probs is not None:
                result[inst_key].scores = o.probs.top5

            if o.keypoints is not None:
                result[inst_key].pred_keypoints = o.keypoints.xyn

            if o.obb is not None:
                result[inst_key].pred_obb = o.obb.xywhr

        return result, pil_images

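    # Reading the result (a sketch): keys are "instances_{i}", one per input
    # image, each a detectron2 Instances object:
    #
    #   result, images = predictor(img)
    #   inst = result["instances_0"]
    #   if inst.has("pred_boxes"):
    #       print(inst.pred_boxes)  # xywh tensors from the YOLO output
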
    def release(self):
        # Drop the model object
        del self.model
        # Free cached GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Manually trigger garbage collection
        gc.collect()

def get_file_path_without_extension(file_path):
    import os
    # Directory containing the file
    directory = os.path.dirname(file_path)

    # File name including the extension
    filename_with_extension = os.path.basename(file_path)

    # Split the name from the extension
    filename, extension = os.path.splitext(filename_with_extension)

    # Join the directory and the extension-less file name
    return os.path.join(directory, filename)

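# For illustration (not in the commit):
#   get_file_path_without_extension("/home/neo/Videos/trafic.webm")
#   -> "/home/neo/Videos/trafic"
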
# if __name__ == "__main__":
#     cfg = get_cfg()
#     cfg.TASK_TYPE = "detect"
#     f = YOLOPredictor(cfg)
#     # from PIL import Image
#     # img = Image.open("./test/test.png")
#     f.track("/home/neo/Videos/trafic.webm")
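
# A hedged, runnable sketch (not part of the original commit). Note that
# track() is a generator: merely calling it, as the commented-out block above
# does, runs nothing until the generator is iterated. "sample.mp4" is a
# placeholder path.
if __name__ == "__main__":
    cfg = get_cfg()
    cfg.TASK_TYPE = "detect"
    predictor = YOLOPredictor(cfg)
    for output, annotated, saved_path, _ in predictor.track("sample.mp4"):
        if saved_path is not None:
            print("annotated video written to:", saved_path)
    predictor.release()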