Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
neoguojing committed Aug 25, 2024
1 parent 14f7a32 commit 42b44fa
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 221 deletions.
6 changes: 0 additions & 6 deletions detectron/demo/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,9 @@
from detectron2.utils.logger import setup_logger
setup_logger()
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.utils.visualizer import ColorMode
import detectron2.data.transforms as T
from predictor import InferenceBase
import torch
import torchvision.transforms as transforms
from PIL import Image
from detectron2.data.detection_utils import pil_image_handler

Expand Down
39 changes: 0 additions & 39 deletions detectron/demo/model_manager.py

This file was deleted.

2 changes: 1 addition & 1 deletion detectron/demo/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False,device="cp
self.cpu_device = torch.device("cpu")
self.instance_mode = instance_mode
self.cfg = cfg
self.cfg.MODEL.DEVICE = device
self.cfg.MODEL.DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

self.parallel = parallel

Expand Down
193 changes: 23 additions & 170 deletions detectron/demo/sam_everything.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,118 +10,13 @@
import threading
import cv2
import os
from scheduler.sched import ModelWrapper,LRUModelScheduler

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def seg_with_promp(imput_image,point_coords=None,box=None):
if isinstance(imput_image, Image.Image):
imput_image = pil_image_to_numpy(imput_image)
point_labels = None
if point_coords is not None:
point_labels = np.ones(point_coords.shape[0])
sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth").to(device)
predictor = SamPredictor(sam)
predictor.set_image(imput_image)

masks = None

if box is not None:
masks, _, _ = predictor.predict(box=box)
elif point_coords is not None and point_labels is not None:
masks, _, _ = predictor.predict(point_coords=point_coords,point_labels=point_labels)
print("seg_with_promp:",masks.shape)
pil_images = draw_bitmask(imput_image,masks)
return masks,pil_images

def seg_all(imput_image):
if isinstance(imput_image, Image.Image):
imput_image = pil_image_to_numpy(imput_image)

sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(imput_image)
pil_images = draw_bitmask(imput_image,masks)
# pil_images = draw_polygon(imput_image,masks)
# pil_images = draw_bitmask_split(imput_image,masks)
return masks,pil_images

# 为每个二值掩码生成一张图片
def draw_bitmask_split(np_image,masks):
for i,obj in enumerate(masks):
print("segmentation:",obj["segmentation"].shape)
view = Visualizer(np_image)
view.draw_binary_mask(obj["segmentation"])
vis_image = view.get_output()
pil_images = visimage_to_pil([vis_image],idx=i)
return pil_images

# 绘制二值掩码
def draw_bitmask(np_image,masks):
view = Visualizer(np_image)
for obj in masks:
if "segmentation" in obj:
print("segmentation:",obj["segmentation"].shape)
view.draw_binary_mask(obj["segmentation"])
else:
view.draw_binary_mask(obj)

vis_image = view.get_output()
pil_images = visimage_to_pil([vis_image])
return pil_images

# 绘制多边形掩码
def draw_polygon(np_image,masks):
view = Visualizer(np_image)
for obj in masks:
polygon = bitmask_to_polygon(obj["segmentation"])
view.draw_polygon(polygon,"k")
vis_image = view.get_output()
pil_images = visimage_to_pil([vis_image])
return pil_images


# 二值掩码转换为多边形掩码
def bitmask_to_polygon(mask):
col_mask = np.asfortranarray(mask)
contours = measure.find_contours(col_mask,0.5)
print("contours------",contours.shape)
for i,contour in enumerate(contours):
contour = np.flip(contour, axis=1)
print(f"polygon_{i}",contour.shape)
# polygon = contour.ravel().tolist()
# print(f"polygon_{i}",polygon)
return contour

# VIS图片转换为pil
def visimage_to_pil(visimages,need_save=True,idx=0):
pil_images = []
for i,visimage in enumerate(visimages):
visualized_image = visimage.get_image()
# [:, :, ::-1]
pil_image = Image.fromarray(visualized_image)
if need_save:
pil_image.save(f"{idx}_{i}.jpg")
pil_images.append(pil_image)
return pil_images

def image_to_mask(image, threshold=128):
# 将图像转换为灰度图像
if image.mode != 'L':
image = image.convert('L')

# 将像素值映射到二进制值
mask_array = np.array(image) > threshold

# 创建一个与原始图像大小相同的数组,用映射后的二进制值填充
mask_image = Image.fromarray(np.uint8(mask_array) * 255)

return mask_image

class SamAnything:
class SamAnything(ModelWrapper):
_instance = None
_lock = threading.Lock()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
_sched = LRUModelScheduler()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def __new__(cls, *args, **kwargs):
if cls._instance is None:
Expand Down Expand Up @@ -150,12 +45,12 @@ def seg_with_promp(self, input_image=None,video_dir=None, point_coords=None, box
pil_images = self.draw_bitmask(input_image, masks)
yield pil_images,None

def seg_all(self, iput_image):
if isinstance(iput_image, Image.Image):
iput_image = pil_image_to_numpy(iput_image)
def seg_all(self, input_image):
if isinstance(input_image, Image.Image):
input_image = pil_image_to_numpy(input_image)

masks = self.mask_generator.generate(iput_image)
pil_images = self.draw_bitmask(iput_image, masks)
masks = self.mask_generator.generate(input_image)
pil_images = self.draw_bitmask(input_image, masks)
yield pil_images,None

@staticmethod
Expand Down Expand Up @@ -221,6 +116,13 @@ def image_to_mask(image, threshold=128):
mask_array = np.array(image) > threshold
mask_image = Image.fromarray(np.uint8(mask_array) * 255)
return mask_image

def release(self):
import gc
# 删除模型对象
del self.sam
# 手动触发垃圾回收
gc.collect()


class SamAnything2:
Expand All @@ -241,9 +143,11 @@ def _initialize(self, checkpoint_path="./sam_vit_b_01ec64.pth"):
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.sam2_video_predictor import SAM2VideoPredictor
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

self.predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-small")
self.video_predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-small")
self.mask_generator = SAM2AutomaticMaskGenerator.from_pretrained("facebook/sam2-hiera-small")

def seg_with_promp(self, input_image=None, video_dir=None,point_coords=None, box=None):
point_labels = None
Expand Down Expand Up @@ -310,15 +214,12 @@ def seg_with_promp(self, input_image=None, video_dir=None,point_coords=None, box


def seg_all(self, input_image):
if input_image is not None:
# if isinstance(input_image, Image.Image):
# input_image = pil_image_to_numpy(input_image)
if isinstance(input_image, Image.Image):
input_image = pil_image_to_numpy(input_image)

with torch.inference_mode(), torch.autocast(self.device, dtype=torch.bfloat16):
self.predictor.set_image(input_image)
masks, _, _ = self.predictor.predict(multimask_output=False)
pil_images = self.draw_bitmask(input_image, masks)
yield pil_images,None
masks = self.mask_generator.generate(input_image)
pil_images = self.draw_bitmask(input_image, masks)
yield pil_images,None


def extract_frames(self,video_path):
Expand Down Expand Up @@ -350,17 +251,6 @@ def extract_frames(self,video_path):
cap.release()
return output_dir,frame_names,fps,frame_size

@staticmethod
def draw_bitmask_split(np_image, masks):
pil_images = []
for i, obj in enumerate(masks):
print("segmentation:", obj["segmentation"].shape)
view = Visualizer(np_image)
view.draw_binary_mask(obj["segmentation"])
vis_image = view.get_output()
pil_images.extend(SamAnything.visimage_to_pil([vis_image], idx=i))
return pil_images

@staticmethod
def draw_bitmask(np_image, masks,pil_image=True):
view = Visualizer(np_image)
Expand All @@ -378,44 +268,7 @@ def draw_bitmask(np_image, masks,pil_image=True):

return vis_image.get_image()

@staticmethod
def draw_polygon(np_image, masks):
view = Visualizer(np_image)
for obj in masks:
polygon = SamAnything.bitmask_to_polygon(obj["segmentation"])
view.draw_polygon(polygon, "k")
vis_image = view.get_output()
pil_images = SamAnything.visimage_to_pil([vis_image])
return pil_images

@staticmethod
def bitmask_to_polygon(mask):
col_mask = np.asfortranarray(mask)
contours = measure.find_contours(col_mask, 0.5)
print("contours------", len(contours))
for i, contour in enumerate(contours):
contour = np.flip(contour, axis=1)
print(f"polygon_{i}", contour.shape)
return contours

@staticmethod
def visimage_to_pil(visimages, need_save=True, idx=0):
pil_images = []
for i, visimage in enumerate(visimages):
visualized_image = visimage.get_image()
pil_image = Image.fromarray(visualized_image)
if need_save:
pil_image.save(f"{idx}_{i}.jpg")
pil_images.append(pil_image)
return pil_images

@staticmethod
def image_to_mask(image, threshold=128):
if image.mode != 'L':
image = image.convert('L')
mask_array = np.array(image) > threshold
mask_image = Image.fromarray(np.uint8(mask_array) * 255)
return mask_image

# if __name__ == "__main__":
# np_image = read_image("./test/face1.jpeg")
Expand Down
15 changes: 10 additions & 5 deletions detectron/scheduler/sched.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import collections
import weakref
import threading
from abc import ABC, abstractmethod

class ModelWrapper(ABC):

@abstractmethod
def release(self):
pass

class LRUModelScheduler:
_instance = None
Expand All @@ -25,7 +32,7 @@ def get_model(self, key):
self.cache.move_to_end(key)
return self.cache[key]

def put_model(self, key, model):
def put_model(self, key, model:ModelWrapper):
"""添加/更新模型到调度器"""
if key in self.cache:
# 如果模型已存在,则更新
Expand All @@ -38,12 +45,10 @@ def put_model(self, key, model):
# 添加新模型
self.cache[key] = weakref.ref(model, self._model_finalizer)

def _destroy_model(self, model):
def _destroy_model(self, model:ModelWrapper):
"""销毁模型(释放资源)"""
if model is not None:
# 执行模型资源释放的逻辑
del model # 实际销毁模型的代码可能因模型类型不同而有所不同
print(f"Model {model} destroyed")
model.release()

def _model_finalizer(self, weak_ref):
"""模型被销毁时的回调"""
Expand Down

0 comments on commit 42b44fa

Please sign in to comment.