Fix a bug: do mask post-processing outside the model to avoid inputting fixed dimensions

The prompt/mask decoder graph no longer takes orig_im_size as an input; its low-resolution masks are resized back to the original image size on the host by the new mask_postprocessing helper in utils.
BooHwang committed Jun 1, 2023
1 parent fde186b commit 242e0fe
Showing 4 changed files with 211 additions and 21 deletions.
40 changes: 26 additions & 14 deletions sam_trt_inference.py
@@ -17,19 +17,21 @@
from utils.common import TrtModel
import os
import argparse
from utils import apply_coords, pre_processing
import matplotlib.pyplot as plt
from utils import apply_coords, pre_processing, mask_postprocessing, show_mask, show_points


if __name__ == "__main__":
parser = argparse.ArgumentParser("use tensorrt to run inference with the segment anything model")
parser.add_argument("--img_path", type=str, default="images/truck.jpg", help="image you want to segment")
parser.add_argument("--sam_engine_file", type=str, default="weights/sam_vit_h_4b8939.engine")
parser.add_argument("--sam_engine_file", type=str, default="weights/sam_default_prompt_mask.engine")
parser.add_argument("--embedding_engine_file", type=str, default="embedding_onnx/sam_default_embedding.engine")
parser.add_argument("--gpu_id", type=int, default=0, help="use which gpu to inference")
parser.add_argument("--batch_size", type=int, default=1, help="use batch size img to inference")
args = parser.parse_args()

image = cv2.imread(args.img_path)
orig_im_size = image.shape[:2]
img_inputs = pre_processing(image)
print(f'img input: {img_inputs.shape}')

@@ -49,35 +51,45 @@
onnx_coord = apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)
onnx_orig_im_size = np.array(image.shape[:2], dtype=np.int32)

# print(image_embedding.shape)
# print(onnx_coord.shape)
# print(onnx_label.shape)
# print(onnx_mask_input.shape)
# print(onnx_has_mask_input.shape)
# print(onnx_orig_im_size.shape)

input = [image_embedding, onnx_coord, onnx_label, onnx_mask_input, onnx_has_mask_input, onnx_orig_im_size]
input = [image_embedding, onnx_coord, onnx_label, onnx_mask_input, onnx_has_mask_input]
shape_map = {'image_embeddings': image_embedding.shape,
'point_coords': onnx_coord.shape,
'point_labels': onnx_label.shape,
'mask_input': onnx_mask_input.shape,
'has_mask_input': onnx_has_mask_input.shape,
'orig_im_size': onnx_orig_im_size.shape}
'has_mask_input': onnx_has_mask_input.shape}
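# orig_im_size is no longer an engine binding; the low-res masks are resized back to the original resolution on the host via mask_postprocessing below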

output = sam_inference(input, binding_shape_map=shape_map)

# print(output[0].shape)
# print(output[1].shape)
# print(output[2].shape)

low_res_logits = output[0].reshape(args.batch_size, -1).reshape(4, 256, 256)
scores = output[1].reshape(args.batch_size, -1)
masks = output[2].reshape(4, 1200, 1800)
scores = output[1].reshape(args.batch_size, -1).squeeze(0)

masks = mask_postprocessing(low_res_logits, orig_im_size, img_inputs.shape[2])
masks = masks.numpy().squeeze(0)
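# masks is now (4, H, W) at the original image resolution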
os.makedirs("results", exist_ok=True)
for i in range(masks.shape[0]):
# mask_image = show_mask(masks[i]*255)
cv2.imwrite(f"results/trt_mask{i}.png", masks[i]*255)
print(f"Generate results/trt_mask{i}.png")
# for i in range(masks.shape[0]):
# # mask_image = show_mask(masks[i]*255)
# cv2.imwrite(f"results/trt_mask{i}.png", masks[i]*255)
# print(f"Generate results/trt_mask{i}.png")

for i, (mask, score) in enumerate(zip(masks, scores)):
mask = mask > 0.0
plt.figure(figsize=(10,10))
plt.imshow(image[:, :, ::-1])
show_mask(mask, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
plt.axis('off')
plt.savefig(f"results/trt_mask{i}.png", bbox_inches='tight', pad_inches=0)
print(f"generate: results/trt_mask{i}.png")
# plt.show()
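
For reference, a minimal sketch of the new host-side post-processing path, exercised with a dummy decoder output instead of a real TensorRT run (shapes mirror the script above; 1024 is the size pre_processing pads to, and the 1200x1800 image size is just an example):

import numpy as np
from utils import mask_postprocessing

low_res_logits = np.random.randn(4, 256, 256).astype(np.float32)  # stand-in for output[0]
masks = mask_postprocessing(low_res_logits, (1200, 1800), 1024)   # resize outside the engine
print(masks.shape)  # torch.Size([1, 4, 1200, 1800])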

160 changes: 155 additions & 5 deletions scripts/onnx2trt.py
@@ -23,6 +23,7 @@
import tensorrt as trt
import argparse
from segment_anything import sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel

def show_mask(mask, ax, random_color=False):
if random_color:
@@ -70,7 +71,7 @@ def pre_processing(image: np.ndarray, target_length: int, device,pixel_mean,pixe
input_image_torch = F.pad(input_image_torch, (0, padw, 0, padh))
return input_image_torch

def export_embedding_model(gpu_id, model_type, sam_checkpoint):
def export_embedding_model(gpu_id, model_type, sam_checkpoint, opset):
device = f"cuda:{gpu_id}"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
@@ -98,13 +99,155 @@ def export_embedding_model(gpu_id, model_type, sam_checkpoint):
onnx_model_path,
export_params=True,
verbose=False,
opset_version=17,
opset_version=opset,
do_constant_folding=True,
input_names=list(dummy_inputs.keys()),
output_names=output_names,
)
print(f"Generate image onnx model, and save in: {onnx_model_path}")


def export_prompt_masks_model(model_type: str, checkpoint: str, opset: int):
print("Loading model...")
sam = sam_model_registry[model_type](checkpoint=checkpoint)

onnx_model = SamOnnxModel(
model=sam,
return_single_mask=False,
use_stability_score=False,
return_extra_metrics=False,
)
onnx_model_path = os.path.join("weights", "sam_" + model_type+"_"+"prompt_mask.onnx")

dynamic_axes = {
"point_coords": {1: "num_points"},
"point_labels": {1: "num_points"},
}

embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
mask_input_size = [4 * x for x in embed_size]
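# e.g. for the default ViT-H model, embed_size is (64, 64), so mask_input is (1, 1, 256, 256)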
dummy_inputs = {
"image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
"point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
"point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
"mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
"has_mask_input": torch.tensor([1], dtype=torch.float),
# "orig_im_size": torch.tensor([1500, 2250], dtype=torch.int32),
}

_ = onnx_model(**dummy_inputs)

output_names = ["low_res_masks", "iou_predictions"]

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
warnings.filterwarnings("ignore", category=UserWarning)
with open(onnx_model_path, "wb") as f:
print(f"Exporting onnx model to {onnx_model_path}...")
torch.onnx.export(
onnx_model,
tuple(dummy_inputs.values()),
f,
export_params=True,
verbose=False,
opset_version=opset,
do_constant_folding=True,
input_names=list(dummy_inputs.keys()),
output_names=output_names,
dynamic_axes=dynamic_axes,
)
print(f"Generate prompt and masks onnx model, and save in: {onnx_model_path}")

def export_prompt_model(gpu_id=1, model_type="default", sam_checkpoint="weights/sam_vit_h_4b8939.pth"):
device = f"cuda:{gpu_id}"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

os.makedirs("prompt_onnx", exist_ok=True)

embed_size = sam.prompt_encoder.image_embedding_size
mask_input_size = [4 * x for x in embed_size]
onnx_model_path = os.path.join("prompt_onnx", "sam_" + model_type+"_"+"prompt.onnx")
dynamic_axes = {
"point_coords": {0: "num_points"},
"point_labels": {0: "num_points"},
"boxes": {0: "num_boxes"},
}
points_coord = torch.randint(low=0, high=1024, size=(1, 1, 2), dtype=torch.float).to(device)
points_label = torch.randint(low=0, high=4, size=(1, 1), dtype=torch.float).to(device)
points = (points_coord, points_label)
boxes = torch.randint(low=0, high=1024, size=(1, 1, 4), dtype=torch.float).to(device)  # float, not int32: box coordinates go through float positional-encoding ops

dummy_inputs = {
"points": points,
"boxes": boxes,
"masks": torch.randn(1, 1, *mask_input_size, dtype=torch.float).to(device),
}
input_names = ["point_coords", "point_labels", "boxes", "mask_input"]

output_names = ["sparse_embeddings", "dense_embeddings"]
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# with open(onnx_model_path, "wb") as f:
torch.onnx.export(
sam.prompt_encoder,
tuple(dummy_inputs.values()),
onnx_model_path,
export_params=True,
verbose=False,
opset_version=17,
do_constant_folding=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
)
print(f"Generate image onnx model, and save in: {onnx_model_path}")

def export_masks_model(gpu_id=2, model_type="default", sam_checkpoint="weights/sam_vit_h_4b8939.pth"):
device = f"cuda:{gpu_id}"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

os.makedirs("masks_onnx", exist_ok=True)

onnx_model_path = os.path.join("masks_onnx", "sam_" + model_type+"_"+"masks.onnx")
dynamic_axes = {
"sparse_embeddings": {1: "num_embedding"},
}
sparse_embeddings = torch.randint(low=0, high=1024, size=(1, 2, 256), dtype=torch.float).to(device)
dense_embeddings = torch.randint(low=0, high=1024, size=(1, 256, 64, 64), dtype=torch.float).to(device)
image_embeddings = torch.randint(low=0, high=1024, size=(1, 256, 64, 64), dtype=torch.float).to(device)  # float, not int32: the mask decoder's layers expect float tensors
image_pe = torch.randint(low=0, high=1024, size=(1, 256, 64, 64), dtype=torch.float).to(device)
multimask_output = torch.tensor([0], dtype=torch.float).to(device)

dummy_inputs = {
"image_embeddings": image_embeddings,
"image_pe": image_pe,
"sparse_embeddings": sparse_embeddings,
"dense_embeddings": dense_embeddings,
"multimask_output": multimask_output,
}

output_names = ["low_res_masks", "iou_predictions"]
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# with open(onnx_model_path, "wb") as f:
torch.onnx.export(
sam.mask_decoder,
tuple(dummy_inputs.values()),
onnx_model_path,
export_params=True,
verbose=False,
opset_version=17,
do_constant_folding=True,
input_names=list(dummy_inputs.keys()),
output_names=output_names,
dynamic_axes=dynamic_axes,
)
print(f"Generate image onnx model, and save in: {onnx_model_path}")

def export_engine_image_encoder(f='vit_l_embedding.onnx', half=True):
file = Path(f)
f = file.with_suffix('.engine') # TensorRT engine file
@@ -183,18 +326,25 @@ def export_engine_prompt_encoder_and_mask_decoder(f='sam_onnx_example.onnx', hal
parser.add_argument("--img_pt2onnx", action="store_true", help="transform image embedding pth from sam model to onnx")
parser.add_argument("--sam_checkpoint", type=str, default="weights/sam_vit_h_4b8939.pth")
parser.add_argument("--model_type", type=str, default="default")
parser.add_argument("--prompt_masks_pt2onnx", action="store_true", help="whether export prompt encoder and masks decoder module")
parser.add_argument("--img_onnx2trt", action="store_true", help="only transform image embedding onnx model to tensorrt engine")
parser.add_argument("--img_onnx_model_path", type=str, default="embedding_onnx/sam_default_embedding.onnx")
parser.add_argument("--sam_onnx2trt", action="store_true", help="only transform sam prompt and mask decoder onnx model to tensorrt engine")
parser.add_argument("--sam_onnx_path", type=str, default="./weights/sam_vit_h_4b8939.onnx")
parser.add_argument("--gpu_id", type=int, default=0, help="use which gpu to transform model")
parser.add_argument("--opset", type=int, default=17, help="onnx opset version")
args = parser.parse_args()

with torch.no_grad():
if args.img_pt2onnx:
export_embedding_model(args.gpu_id, args.model_type, args.sam_checkpoint)
export_embedding_model(args.gpu_id, args.model_type, args.sam_checkpoint, args.opset)
if args.prompt_masks_pt2onnx:
export_prompt_masks_model(args.model_type, args.sam_checkpoint, args.opset)
if args.img_onnx2trt:
export_engine_image_encoder(args.img_onnx_model_path, False)
if args.sam_onnx2trt:
export_engine_prompt_encoder_and_mask_decoder(args.sam_onnx_path)


# just test split prompt encoder and masks decoder module
# export_prompt_model()
# export_masks_model()
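
As a usage note, the new export path can be driven with the flags defined above. A minimal sketch, assuming the default checkpoint exists at weights/sam_vit_h_4b8939.pth:

# shell: python scripts/onnx2trt.py --img_pt2onnx --prompt_masks_pt2onnx --opset 17
# or, equivalently, from Python:
with torch.no_grad():
    export_embedding_model(0, "default", "weights/sam_vit_h_4b8939.pth", 17)
    export_prompt_masks_model("default", "weights/sam_vit_h_4b8939.pth", 17)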
2 changes: 1 addition & 1 deletion utils/__init__.py
@@ -8,4 +8,4 @@
@Desc : None
'''

from .sam_function import pre_processing, apply_coords, show_mask, show_points, show_box
from .sam_function import pre_processing, apply_coords, show_mask, show_points, show_box, mask_postprocessing
30 changes: 29 additions & 1 deletion utils/sam_function.py
@@ -83,4 +83,32 @@ def pre_processing(image: np.ndarray,
padh = img_size - h
padw = img_size - w
input_image_torch = F.pad(input_image_torch, (0, padw, 0, padh))
return input_image_torch.numpy()

def resize_longest_image_size(
input_image_size: torch.Tensor, longest_side: int
) -> torch.Tensor:
input_image_size = input_image_size.to(torch.float32)
scale = longest_side / torch.max(input_image_size)
transformed_size = scale * input_image_size
transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
return transformed_size

def mask_postprocessing(masks: np.ndarray, orig_im_size: Tuple[int, int], img_size: int) -> torch.Tensor:
masks = torch.from_numpy(masks[None, :, :, :]) # (4, 256, 256) -> (1, 4, 256, 256)
orig_im_size = torch.tensor([orig_im_size[0], orig_im_size[1]], dtype=torch.int32)

masks = F.interpolate(
masks,
size=(img_size, img_size),
mode="bilinear",
align_corners=False,
)

prepadded_size = resize_longest_image_size(orig_im_size, img_size).to(torch.int64)
masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore

orig_im_size = orig_im_size.to(torch.int64)
h, w = orig_im_size[0], orig_im_size[1]
masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
return masks
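
As a worked example of the two-stage resize: for a 1200x1800 image with img_size 1024, scale = 1024/1800 ≈ 0.5689, so resize_longest_image_size returns floor((682.67, 1024.0) + 0.5) = (683, 1024); the masks are first upsampled to 1024x1024, cropped to 683x1024 to strip the padding added in pre_processing, then interpolated to the original 1200x1800.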
