chore: introduce auto formatting before commit
linhandev committed Jul 11, 2023
1 parent 5f5f921 commit ffdc2ba
Showing 28 changed files with 1,706 additions and 925 deletions.
29 changes: 29 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,29 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
# - id: no-commit-to-branch
# args: [--pattern, ^v]
# - id: check-added-large-files
# args: [--maxkb=64]
- id: check-case-conflict
- id: check-yaml
- id: check-xml
- id: check-toml
- id: check-merge-conflict
- id: check-symlinks
- id: destroyed-symlinks
- id: mixed-line-ending
args: [--fix=lf]
- id: end-of-file-fixer
- id: trailing-whitespace
- id: check-json
- id: pretty-format-json
args: [--autofix, --indent=4, --no-ensure-ascii]
- id: detect-private-key
- id: fix-encoding-pragma

- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
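
Note for contributors: these hooks only take effect once pre-commit is set up in the local clone. A minimal sketch of the usual workflow, assuming pre-commit is installed from PyPI (these commands are not part of this commit):

```bash
# Install the pre-commit tool and register the hooks from .pre-commit-config.yaml
# as a git pre-commit hook.
pip install pre-commit
pre-commit install

# Optionally run all hooks (black, whitespace/EOF fixers, JSON formatting, etc.)
# against the entire repository once.
pre-commit run --all-files
```

After that, black and the other checks run automatically on the staged files each time `git commit` is invoked.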
89 changes: 65 additions & 24 deletions MedSAM_Inference.py
@@ -1,49 +1,56 @@
# -*- coding: utf-8 -*-
# %% load environment
import numpy as np
import matplotlib.pyplot as plt
import os

join = os.path.join
import torch
from segment_anything import sam_model_registry
from skimage import io, transform
import torch.nn.functional as F
import argparse


# visualization functions
# source: https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb
# change color to avoid red and green
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([251/255, 252/255, 30/255, 0.6])
color = np.array([251 / 255, 252 / 255, 30 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)



def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0,0,0,0), lw=2))
ax.add_patch(
plt.Rectangle((x0, y0), w, h, edgecolor="blue", facecolor=(0, 0, 0, 0), lw=2)
)


@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
if len(box_torch.shape) == 2:
box_torch = box_torch[:, None, :] # (B, 1, 4)
box_torch = box_torch[:, None, :] # (B, 1, 4)

sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
points=None,
boxes=box_torch,
masks=None,
)
low_res_logits, _ = medsam_model.mask_decoder(
image_embeddings=img_embed, # (B, 256, 64, 64)
image_pe=medsam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
image_embeddings=img_embed, # (B, 256, 64, 64)
image_pe=medsam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
multimask_output=False,
)
)

low_res_pred = torch.sigmoid(low_res_logits) # (1, 1, 256, 256)

@@ -57,17 +64,43 @@ def medsam_inference(medsam_model, img_embed, box_1024, H, W):
medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
return medsam_seg

#%% load model and image
parser = argparse.ArgumentParser(description='run inference on testing set based on MedSAM')
parser.add_argument('-i', '--data_path', type=str, default='assets/img_demo.png', help='path to the data folder')
parser.add_argument('-o', '--seg_path', type=str, default='assets/', help='path to the segmentation folder')
parser.add_argument('--box', type=list, default=[95,255, 190, 350], help='bounding box of the segmentation target')
parser.add_argument('--device', type=str, default='cuda:0', help='device')
parser.add_argument('-chk', '--checkpoint', type=str, default='work_dir/MedSAM/medsam_vit_b.pth', help='path to the trained model')

# %% load model and image
parser = argparse.ArgumentParser(
description="run inference on testing set based on MedSAM"
)
parser.add_argument(
"-i",
"--data_path",
type=str,
default="assets/img_demo.png",
help="path to the data folder",
)
parser.add_argument(
"-o",
"--seg_path",
type=str,
default="assets/",
help="path to the segmentation folder",
)
parser.add_argument(
"--box",
type=list,
default=[95, 255, 190, 350],
help="bounding box of the segmentation target",
)
parser.add_argument("--device", type=str, default="cuda:0", help="device")
parser.add_argument(
"-chk",
"--checkpoint",
type=str,
default="work_dir/MedSAM/medsam_vit_b.pth",
help="path to the trained model",
)
args = parser.parse_args()

device = args.device
medsam_model = sam_model_registry['vit_b'](checkpoint=args.checkpoint)
medsam_model = sam_model_registry["vit_b"](checkpoint=args.checkpoint)
medsam_model = medsam_model.to(device)
medsam_model.eval()

@@ -77,24 +110,32 @@ def medsam_inference(medsam_model, img_embed, box_1024, H, W):
else:
img_3c = img_np
H, W, _ = img_3c.shape
#%% image preprocessing
img_1024 = transform.resize(img_3c, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True).astype(np.uint8)
# %% image preprocessing
img_1024 = transform.resize(
img_3c, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
).astype(np.uint8)
img_1024 = (img_1024 - img_1024.min()) / np.clip(
img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None
) # normalize to [0, 1], (H, W, 3)
# convert the shape to (3, H, W)
img_1024_tensor = torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(device)
img_1024_tensor = (
torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(device)
)

box_np = np.array([args.box])
# transfer box_np to 1024x1024 scale
box_1024 = box_np / np.array([W, H, W, H]) * 1024
with torch.no_grad():
image_embedding = medsam_model.image_encoder(img_1024_tensor) # (1, 256, 64, 64)
image_embedding = medsam_model.image_encoder(img_1024_tensor) # (1, 256, 64, 64)

medsam_seg = medsam_inference(medsam_model, image_embedding, box_1024, H, W)
io.imsave(join(args.seg_path, 'seg_' + os.path.basename(args.data_path)), medsam_seg, check_contrast=False)
io.imsave(
join(args.seg_path, "seg_" + os.path.basename(args.data_path)),
medsam_seg,
check_contrast=False,
)

#%% visualize results
# %% visualize results
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(img_3c)
show_box(box_np[0], ax[0])
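Usage note: the black reformatting above does not change the script's command-line interface, so a typical invocation still looks like the sketch below, using the defaults from the argument parser shown in the diff (paths are illustrative and depend on your local setup):

```bash
# Hypothetical run with the argparse defaults from MedSAM_Inference.py above;
# point -chk at a local MedSAM checkpoint and -i at the image to segment.
python MedSAM_Inference.py \
    -i assets/img_demo.png \
    -o assets/ \
    -chk work_dir/MedSAM/medsam_vit_b.pth \
    --device cuda:0
```

The resulting mask is written into the segmentation folder as `seg_<input filename>`, as in the `io.imsave` call above.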
14 changes: 7 additions & 7 deletions README.md
@@ -1,8 +1,8 @@
# MedSAM
# MedSAM
This is the official repository for MedSAM: Segment Anything in Medical Images.


## Installation
## Installation
1. Create a virtual environment `conda create -n medsam python=3.10 -y` and activate it `conda activate medsam`
2. Install [Pytorch 2.0](https://pytorch.org/get-started/locally/)
3. `git clone https://github.com/bowang-lab/MedSAM`
@@ -43,7 +43,7 @@ python gui.py

Load the image to the GUI and specify segmentation targets by drawing bounding boxes.

![seg_demo](assets/seg_demo.gif)
![seg_demo](assets/seg_demo.gif)


## Model Training
@@ -77,9 +77,9 @@ The model was trained on five A100 nodes and each node has four GPUs (80G) (20 A

```bash
sbatch train_multi_gpus.sh
```
```

When the training process is done, please convert the checkpoint to SAM's format for convenient inference.
When the training process is done, please convert the checkpoint to SAM's format for convenient inference.

```bash
python utils/ckpt_convert.py # Please set the corresponding checkpoint path first
@@ -91,11 +91,11 @@ python utils/ckpt_convert.py # Please set the corresponding checkpoint path firs
python train_one_gpu.py
```

If you only want to train the mask decoder, please check the tutorial on the [0.1 branch](https://github.com/bowang-lab/MedSAM/tree/0.1).
If you only want to train the mask decoder, please check the tutorial on the [0.1 branch](https://github.com/bowang-lab/MedSAM/tree/0.1).


## Acknowledgements
- We highly appreciate all the challenge organizers and dataset owners for providing the public dataset to the community.
- We highly appreciate all the challenge organizers and dataset owners for providing the public dataset to the community.
- We thank Meta AI for making the source code of [segment anything](https://github.com/facebookresearch/segment-anything) publicly available.
- We also thank Alexandre Bonnet for sharing this great [blog](https://encord.com/blog/learn-how-to-fine-tune-the-segment-anything-model-sam/)

4 changes: 3 additions & 1 deletion gui.py
@@ -48,6 +48,7 @@
MEDSAM_IMG_INPUT_SIZE = 1024
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, height, width):
box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
@@ -89,6 +90,7 @@ def medsam_inference(medsam_model, img_embed, box_1024, height, width):

print(f"MedSam loaded, took {time.perf_counter() - tic}")


def np2pixmap(np_img):
height, width, channel = np_img.shape
bytesPerLine = 3 * width
@@ -274,7 +276,7 @@ def mouse_release(self, ev):

H, W, _ = self.img_3c.shape
box_np = np.array([[xmin, ymin, xmax, ymax]])
print('bounding box:', box_np)
print("bounding box:", box_np)
box_1024 = box_np / np.array([W, H, W, H]) * 1024

img_1024 = transform.resize(
2 changes: 1 addition & 1 deletion medsam_inference.ipynb
