forked from cvat-ai/cvat
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added Segment Anything interactor for GPU/CPU (cvat-ai#6008)
Idea of the PR is to finish this one cvat-ai#5990 Deploy for GPU: ``./deploy_gpu.sh pytorch/facebookresearch/sam/nuclio/`` Deploy for CPU: ``./deploy_cpu.sh pytorch/facebookresearch/sam/nuclio/`` If you want to use GPU, be sure you setup docker for this [guide](https://github.com/NVIDIA/nvidia-docker/blob/master/README.md#quickstart). Resolved issue cvat-ai#5984 But the interface probably can be improved Co-authored-by: Alx-Wo <[email protected]>
- Loading branch information
Showing
6 changed files
with
242 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
serverless/pytorch/facebookresearch/sam/nuclio/function-gpu.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# Copyright (C) 2023 CVAT.ai Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
metadata: | ||
name: pth.facebookresearch.sam.vit_h | ||
namespace: cvat | ||
annotations: | ||
name: Segment Anything | ||
version: 2 | ||
type: interactor | ||
spec: | ||
framework: pytorch | ||
min_pos_points: 1 | ||
min_neg_points: 0 | ||
animated_gif: https://raw.githubusercontent.com/opencv/cvat/develop/site/content/en/images/hrnet_example.gif | ||
help_message: The interactor allows to get a mask of an object using at least one positive, and any negative points inside it | ||
|
||
spec: | ||
description: Interactive object segmentation with Segment-Anything | ||
runtime: 'python:3.8' | ||
handler: main:handler | ||
eventTimeout: 30s | ||
env: | ||
- name: PYTHONPATH | ||
value: /opt/nuclio/sam | ||
|
||
build: | ||
image: cvat.pth.facebookresearch.sam.vit_h | ||
baseImage: ubuntu:22.04 | ||
|
||
directives: | ||
preCopy: | ||
# disable interactive frontend | ||
- kind: ENV | ||
value: DEBIAN_FRONTEND=noninteractive | ||
# set workdir | ||
- kind: WORKDIR | ||
value: /opt/nuclio/sam | ||
# install basic deps | ||
- kind: RUN | ||
value: apt-get update && apt-get -y install curl git python3 python3-pip ffmpeg libsm6 libxext6 | ||
# install sam deps | ||
- kind: RUN | ||
value: pip3 install torch torchvision torchaudio opencv-python pycocotools matplotlib onnxruntime onnx | ||
# install sam code | ||
- kind: RUN | ||
value: pip3 install git+https://github.com/facebookresearch/segment-anything.git | ||
# download sam weights | ||
- kind: RUN | ||
value: curl -O https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth | ||
# map pip3 and python3 to pip and python | ||
- kind: RUN | ||
value: ln -s /usr/bin/pip3 /usr/local/bin/pip && ln -s /usr/bin/python3 /usr/bin/python | ||
triggers: | ||
myHttpTrigger: | ||
maxWorkers: 1 | ||
kind: 'http' | ||
workerAvailabilityTimeoutMilliseconds: 10000 | ||
attributes: | ||
maxRequestBodySize: 33554432 # 32MB | ||
resources: | ||
limits: | ||
nvidia.com/gpu: 1 | ||
|
||
platform: | ||
attributes: | ||
restartPolicy: | ||
name: always | ||
maximumRetryCount: 3 | ||
mountMode: volume |
68 changes: 68 additions & 0 deletions
68
serverless/pytorch/facebookresearch/sam/nuclio/function.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Copyright (C) 2023 CVAT.ai Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
metadata: | ||
name: pth.facebookresearch.sam.vit_h | ||
namespace: cvat | ||
annotations: | ||
name: Segment Anything | ||
version: 2 | ||
type: interactor | ||
spec: | ||
framework: pytorch | ||
min_pos_points: 1 | ||
min_neg_points: 0 | ||
animated_gif: https://raw.githubusercontent.com/opencv/cvat/develop/site/content/en/images/hrnet_example.gif | ||
help_message: The interactor allows to get a mask of an object using at least one positive, and any negative points inside it | ||
|
||
spec: | ||
description: Interactive object segmentation with Segment-Anything | ||
runtime: 'python:3.8' | ||
handler: main:handler | ||
eventTimeout: 30s | ||
env: | ||
- name: PYTHONPATH | ||
value: /opt/nuclio/sam | ||
|
||
build: | ||
image: cvat.pth.facebookresearch.sam.vit_h | ||
baseImage: ubuntu:22.04 | ||
|
||
directives: | ||
preCopy: | ||
# disable interactive frontend | ||
- kind: ENV | ||
value: DEBIAN_FRONTEND=noninteractive | ||
# set workdir | ||
- kind: WORKDIR | ||
value: /opt/nuclio/sam | ||
# install basic deps | ||
- kind: RUN | ||
value: apt-get update && apt-get -y install curl git python3 python3-pip ffmpeg libsm6 libxext6 | ||
# install sam deps | ||
- kind: RUN | ||
value: pip3 install torch torchvision torchaudio opencv-python pycocotools matplotlib onnxruntime onnx | ||
# install sam code | ||
- kind: RUN | ||
value: pip3 install git+https://github.com/facebookresearch/segment-anything.git | ||
# download sam weights | ||
- kind: RUN | ||
value: curl -O https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth | ||
# map pip3 and python3 to pip and python | ||
- kind: RUN | ||
value: ln -s /usr/bin/pip3 /usr/local/bin/pip && ln -s /usr/bin/python3 /usr/bin/python | ||
triggers: | ||
myHttpTrigger: | ||
maxWorkers: 2 | ||
kind: 'http' | ||
workerAvailabilityTimeoutMilliseconds: 10000 | ||
attributes: | ||
maxRequestBodySize: 33554432 # 32MB | ||
|
||
platform: | ||
attributes: | ||
restartPolicy: | ||
name: always | ||
maximumRetryCount: 3 | ||
mountMode: volume |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Copyright (C) 2023 CVAT.ai Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
import json | ||
import base64 | ||
from PIL import Image | ||
import io | ||
from model_handler import ModelHandler | ||
|
||
def init_context(context): | ||
context.logger.info("Init context... 0%") | ||
model = ModelHandler() | ||
context.user_data.model = model | ||
context.logger.info("Init context...100%") | ||
|
||
def handler(context, event): | ||
context.logger.info("call handler") | ||
data = event.body | ||
pos_points = data["pos_points"] | ||
neg_points = data["neg_points"] | ||
buf = io.BytesIO(base64.b64decode(data["image"])) | ||
image = Image.open(buf) | ||
image = image.convert("RGB") # to make sure image comes in RGB | ||
mask, polygon = context.user_data.model.handle(image, pos_points, neg_points) | ||
return context.Response(body=json.dumps({ | ||
'points': polygon, | ||
'mask': mask.tolist(), | ||
}), | ||
headers={}, | ||
content_type='application/json', | ||
status_code=200 | ||
) |
68 changes: 68 additions & 0 deletions
68
serverless/pytorch/facebookresearch/sam/nuclio/model_handler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Copyright (C) 2023 CVAT.ai Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
import numpy as np | ||
import cv2 | ||
import torch | ||
from segment_anything import sam_model_registry, SamPredictor | ||
|
||
def convert_mask_to_polygon(mask): | ||
contours = None | ||
if int(cv2.__version__.split('.')[0]) > 3: | ||
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS)[0] | ||
else: | ||
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS)[1] | ||
|
||
contours = max(contours, key=lambda arr: arr.size) | ||
if contours.shape.count(1): | ||
contours = np.squeeze(contours) | ||
if contours.size < 3 * 2: | ||
raise Exception('Less then three point have been detected. Can not build a polygon.') | ||
|
||
polygon = [] | ||
for point in contours: | ||
polygon.append([int(point[0]), int(point[1])]) | ||
|
||
return polygon | ||
|
||
class ModelHandler: | ||
def __init__(self): | ||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
self.sam_checkpoint = "/opt/nuclio/sam/sam_vit_h_4b8939.pth" | ||
self.model_type = "vit_h" | ||
self.latest_image = None | ||
self.latest_low_res_masks = None | ||
sam_model = sam_model_registry[self.model_type](checkpoint=self.sam_checkpoint) | ||
sam_model.to(device=self.device) | ||
self.predictor = SamPredictor(sam_model) | ||
|
||
def handle(self, image, pos_points, neg_points): | ||
# latest image is kept in memory because function is always run-time after startup | ||
# we use to avoid computing emeddings twice for the same image | ||
is_the_same_image = self.latest_image is not None and np.array_equal(np.array(image), self.latest_image) | ||
if not is_the_same_image: | ||
self.latest_low_res_masks = None | ||
numpy_image = np.array(image) | ||
self.predictor.set_image(numpy_image) | ||
self.latest_image = numpy_image | ||
# we assume that pos_points and neg_points are of type: | ||
# np.array[[x, y], [x, y], ...] | ||
input_points = np.array(pos_points) | ||
input_labels = np.array([1] * len(pos_points)) | ||
|
||
if len(neg_points): | ||
input_points = np.concatenate([input_points, neg_points], axis=0) | ||
input_labels = np.concatenate([input_labels, np.array([0] * len(neg_points))], axis=0) | ||
|
||
masks, _, low_res_masks = self.predictor.predict( | ||
point_coords=input_points, | ||
point_labels=input_labels, | ||
mask_input = self.latest_low_res_masks, | ||
multimask_output=False | ||
) | ||
self.latest_low_res_masks = low_res_masks | ||
object_mask = np.array(masks[0], dtype=np.uint8) | ||
cv2.normalize(object_mask, object_mask, 0, 255, cv2.NORM_MINMAX) | ||
polygon = convert_mask_to_polygon(object_mask) | ||
return object_mask, polygon |