add details
JunMa11 committed Aug 29, 2023
1 parent 677d5ff commit 3e703df
Showing 22 changed files with 2,586 additions and 3 deletions.
1 change: 0 additions & 1 deletion .gitignore
@@ -7,7 +7,6 @@
*.npy
*.npz
*.csv
-*.pdf
*.log
*.png
data/
4 changes: 2 additions & 2 deletions README.md
@@ -31,7 +31,7 @@ Segment other images with the following flags

We provide a step-by-step tutorial on [CoLab](https://colab.research.google.com/drive/19WNtRMbpsxeqimBlmJwtd1dzpaIvK2FZ?usp=sharing)

-You can also run it locally with `MedSAM_Inference.ipynb`.
+You can also run it locally with `tutorial_quickstart.ipynb`.

3. GUI

@@ -109,7 +109,7 @@ If you only want to train the mask decoder, please check the tutorial on the [0.
```
@article{MedSAM,
title={Segment Anything in Medical Images},
-author={Ma, Jun and He, Yuting and Li, Feifei and Han, Lin and You, Chenyu and and Wang, Bo},
+author={Ma, Jun and He, Yuting and Li, Feifei and Han, Lin and You, Chenyu and Wang, Bo},
journal={arXiv preprint arXiv:2304.12306},
year={2023}
}
Binary file added assets/MedSAM_supp.pdf
Binary file not shown.
81 changes: 81 additions & 0 deletions comparisons/DeepLabV3+/README.md
@@ -0,0 +1,81 @@
# DeepLabV3+ Training and Inference Scripts

This folder contains the training and inference scripts of the DeepLabV3+ model for segmentation on medical image data in MedSAM's preprocessed `npz` format. For details regarding the data preprocessing pipeline, please refer to the [MedSAM repository](https://github.com/bowang-lab/MedSAM#data-preprocessing).

## Prerequisites

This codebase uses the [Segmentation Models Pytorch](https://github.com/qubvel/segmentation_models.pytorch) library, which can be installed via pip:

```
pip install segmentation-models-pytorch
```

## Training

To train the DeepLabV3+ model, use the provided `train_deeplabv3_res50.py` script. To incorporate the bounding box prompt into the model, we converted the bounding box into a binary mask and concatenated it with the image as the model input. The bounding box was simulated from the ground truth mask.
The following required parameters must be configured before training:

* `-i /path/to/input`: Path to the input dataset (npy format).
* `-o /path/to/output`: Path to save the trained model.

Example command for training:
```sh
# -b: batch size (replace BATCH_SIZE with an integer)
# --num_workers: number of workers for data loading
# --max_epochs: maximum number of epochs to train
# --compile: compile the model for acceleration
python train_deeplabv3_res50.py \
-i /path/to/input \
-o /path/to/output \
-b BATCH_SIZE \
--num_workers 4 \
--max_epochs 500 \
--compile
```
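
The bounding box prompt encoding described above can be sketched as follows. This is a minimal NumPy-only illustration, not the training script itself: the function name `add_bbox_channel` and the toy inputs are assumptions, while the box padding (`bbox_shift=5`) mirrors the `preprocess_image` function in `infer_deeplabv3_res50_2D.py`.

```python
import numpy as np

def add_bbox_channel(image_chw, mask_hw, bbox_shift=5):
    """Append a binary bounding-box channel to a (3, H, W) image.

    The box is derived from the ground truth mask and padded by
    `bbox_shift` pixels. Returns None when the mask is empty.
    (Hypothetical helper, written for illustration only.)
    """
    ys, xs = np.where(mask_hw > 0)
    if ys.size == 0:
        return None
    H, W = mask_hw.shape
    # Pad the tight box and clip it to the image bounds
    x_min = max(0, xs.min() - bbox_shift)
    x_max = min(W, xs.max() + bbox_shift)
    y_min = max(0, ys.min() - bbox_shift)
    y_max = min(H, ys.max() + bbox_shift)
    # Rasterize the box as a binary mask and stack it as a 4th channel
    box = np.zeros((1, H, W), dtype=image_chw.dtype)
    box[0, y_min:y_max, x_min:x_max] = 1.0
    return np.concatenate([image_chw, box], axis=0)

# Toy inputs: a random image in [0, 1] and a rectangular "ground truth"
image = np.random.rand(3, 224, 224).astype(np.float32)
mask = np.zeros((224, 224), dtype=np.uint8)
mask[100:120, 50:90] = 1
x = add_bbox_channel(image, mask)
print(x.shape)
```

The resulting 4-channel array matches the `in_channels=4` setting used when the model is constructed in the inference script.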



## Inference
The inference scripts assume that the data is in the `npz` format generated by the MedSAM preprocessing pipeline. To run inference, download the model [here](https://drive.google.com/drive/folders/1xYUgdjIsmBkobiBKXNb1uyqN-kGHW2p_?usp=sharing) and use the provided inference scripts.


### Inference for 2D images

To perform inference on 2D images, use the `infer_deeplabv3_res50_2D.py` script. Below are the parameters that need to be configured:

* `-checkpoint`: Path to the trained model checkpoint.
* `-data_root`: Path to the input images.
* `-pred_save_dir`: Path to save the output segmented images.
* `--save_overlay`: Save the overlay of the segmentation on the original image. (Optional)
* `-png_save_dir`: Path to save the overlay images. (Required if `--save_overlay` is used)
* `-num_workers`: Number of workers for multiprocessing during inference.
* `--grey`: Save the overlay images in greyscale. (Optional)

```sh
python infer_deeplabv3_res50_2D.py \
-checkpoint path/to/checkpoint/deeplabv3plus_best.pt \
-data_root /path/to/input \
-pred_save_dir /path/to/output \
--save_overlay \
-png_save_dir /path/to/saved/overlay \
-num_workers 2 \
--grey
```
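
The 2D script saves each prediction as an `npz` file with the keys `segs` and `gts`. A per-label Dice score could be computed from such a file roughly as sketched below. The helper name `dice_per_label`, the file name, and the toy arrays are assumptions made for illustration; the smoothing term follows the `dice_coefficient` function in the script.

```python
import os
import tempfile
import numpy as np

def dice_per_label(segs, gts, smooth=1.0):
    """Smoothed Dice for every non-zero label in `gts`.

    (Hypothetical helper, written for illustration only.)
    """
    scores = {}
    for label_id in np.unique(gts):
        if label_id == 0:
            continue  # skip background
        pred = (segs == label_id).astype(np.float64)
        target = (gts == label_id).astype(np.float64)
        intersection = (pred * target).sum()
        dice = (2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth)
        scores[int(label_id)] = float(dice)
    return scores

# Toy stand-in for a file written to the -pred_save_dir directory
gts = np.zeros((8, 8), dtype=np.uint8)
gts[2:6, 2:6] = 1
segs = np.zeros_like(gts)
segs[2:6, 2:4] = 1

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_case.npz")
    np.savez_compressed(path, segs=segs, gts=gts)
    with np.load(path) as npz:
        scores = dice_per_label(npz["segs"], npz["gts"])

print(scores)
```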

### Inference for 3D images

To perform inference on 3D medical images, such as those of CT or MR modality, the `infer_deeplabv3_res50_3D.py` script can be used. Below are the parameters that one can configure:

* `-checkpoint`: Path to the trained model checkpoint.
* `-data_root`: Path to the input 3D images.
* `-pred_save_dir`: Path to save the output segmented 3D images.
* `-png_save_dir`: Path to save the overlay images. (Optional)
* `-num_workers`: Number of workers for multiprocessing during inference.

```sh
python infer_deeplabv3_res50_3D.py \
-checkpoint /path/to/checkpoint/deeplabv3plus_best.pt \
-data_root /path/to/input \
-pred_save_dir /path/to/output \
-png_save_dir /path/to/saved/overlay \
-num_workers 2
```

## Acknowledgement
This codebase uses the [Segmentation Models Pytorch](https://github.com/qubvel/segmentation_models.pytorch) repository. We would like to thank the authors and the contributors for their great work and for making the code publicly available.
259 changes: 259 additions & 0 deletions comparisons/DeepLabV3+/infer_deeplabv3_res50_2D.py
@@ -0,0 +1,259 @@
# %%
import os
import glob
import random
from os import listdir, makedirs
from os.path import join, isdir, basename, dirname, isfile
from tqdm import tqdm
import numpy as np
import torch
from torch._dynamo import OptimizedModule
from torch import multiprocessing as mp
import cv2
import torch.nn.functional as F
from matplotlib import pyplot as plt
import segmentation_models_pytorch as smp
import argparse

torch.cuda.empty_cache()
os.environ['PYTHONHASHSEED']=str(2023)
random.seed(2023)
np.random.seed(2023)
torch.manual_seed(2023)
torch.cuda.manual_seed(2023)

parser = argparse.ArgumentParser()
parser.add_argument(
'-checkpoint',
type=str,
default='',
help='Path to the model checkpoint',
required=True
)
parser.add_argument('-device', type=str, default='cuda:0')
parser.add_argument(
'-data_root',
type=str,
default='',
help='Path to the validation data directory',
required=True
)
parser.add_argument(
'-pred_save_dir',
type=str,
default='segs',
help='Path to the directory where the segmentation results will be saved in npz format'
)
parser.add_argument('--save_overlay', action='store_true', default=False, help="Whether to save segmentation overlay")
parser.add_argument(
'-png_save_dir',
type=str,
default='png',
help='Path to the directory where the segmentation overlay will be saved in png format'
)
parser.add_argument(
'--grey',
action='store_true',
default=False,
help="Whether to save segmentation overlay in grey scale"
)
parser.add_argument('-num_workers', type=int, default=1, help='number of workers for dataloader')

args = parser.parse_args()
checkpoint = args.checkpoint
device = args.device
data_root = args.data_root
pred_save_dir = args.pred_save_dir
png_save_dir = args.png_save_dir
makedirs(pred_save_dir, exist_ok=True)
save_overlay = args.save_overlay
if save_overlay:
    makedirs(png_save_dir, exist_ok=True)
num_workers = args.num_workers
data_root_files = listdir(data_root)
# If the first entry is a directory, assume one sub-folder per task
has_task = isdir(join(data_root, data_root_files[0]))
if has_task:
    gt_path_files = sorted(glob.glob(join(data_root, '**/*.npz'), recursive=True))
else:
    gt_path_files = sorted(glob.glob(join(data_root, '*.npz'), recursive=True))
image_size = 224
bbox_shift = 5

def show_mask(mask, ax, random_color=False):
    # Overlay a semi-transparent mask on the given matplotlib axis
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.45])], axis=0)
    else:
        color = np.array([251/255, 252/255, 30/255, 0.45])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_box(box, ax):
    # box is (x_min, y_min, x_max, y_max)
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0,0,0,0), lw=2))

def dice_coefficient(preds, targets):
    # Smoothed Dice score between two binary arrays of the same shape
    smooth = 1.0
    assert preds.shape == targets.shape

    intersection = (preds * targets).sum()
    dice = (2.0 * intersection + smooth) / (preds.sum() + targets.sum() + smooth)
    return dice

# %%
model = smp.DeepLabV3Plus(
encoder_name="resnet50", # encoder model type
encoder_weights="imagenet", # use `imagenet` pretrained weights for encoder initialization
in_channels=4, # Additional channel for bounding box prompt
classes=1, # model output channels (number of classes in your dataset)
activation=None # Output logits
)
checkpoint = torch.load(checkpoint)
model.load_state_dict(checkpoint['model'])
model.to(device)
model.eval()

def preprocess_image(img_3c, gt_2D, image_size=224, bbox_shift=5):
    """
    Append bounding box prompt channel to image.
    Returns a (1, 4, image_size, image_size) float tensor, or None if the mask is empty.
    """
    resize_img_cv2 = cv2.resize(
        img_3c,
        (image_size, image_size),
        interpolation=cv2.INTER_AREA
    )
    # normalize to [0, 1], (H, W, 3)
    resize_img_cv2_01 = (resize_img_cv2 - resize_img_cv2.min()) / np.clip(
        resize_img_cv2.max() - resize_img_cv2.min(), a_min=1e-8, a_max=None
    )
    # convert the shape to (3, H, W)
    resize_img = np.transpose(resize_img_cv2_01, (2, 0, 1))
    assert np.max(resize_img)<=1.0 and np.min(resize_img)>=0.0, 'image should be normalized to [0, 1]'
    if gt_2D.shape[0] != image_size or gt_2D.shape[1] != image_size:
        gt_2D = cv2.resize(
            gt_2D, (image_size, image_size),
            interpolation=cv2.INTER_NEAREST
        )
        gt_2D = np.uint8(gt_2D)
    else:
        gt_2D = gt_2D.astype(np.uint8)
    if np.max(gt_2D) == 0:
        # Empty mask (e.g. the object vanished after resizing): nothing to prompt
        return None
    assert np.max(gt_2D) == 1 and np.min(gt_2D) == 0, 'ground truth should be 0, 1, got: ' + str(np.unique(gt_2D))

    # Simulate the bounding box prompt from the ground truth mask
    y_indices, x_indices = np.where(gt_2D > 0)
    x_min, x_max = np.min(x_indices), np.max(x_indices)
    y_min, y_max = np.min(y_indices), np.max(y_indices)
    H, W = gt_2D.shape
    x_min = max(0, x_min - bbox_shift)
    x_max = min(W, x_max + bbox_shift)
    y_min = max(0, y_min - bbox_shift)
    y_max = min(H, y_max + bbox_shift)

    ## Append bbox prompt channel
    resize_img_bbox = np.concatenate([resize_img, np.zeros((1, image_size, image_size))], axis=0)
    resize_img_bbox[-1, y_min:y_max, x_min:x_max] = 1.0
    resize_img_bbox = resize_img_bbox[None, ...]

    return torch.tensor(resize_img_bbox).float()


def deeplabv3plus_infer_npz(gt_path_file):
    npz_name = basename(gt_path_file)
    if has_task:
        # The parent folder name identifies the task
        task_folder = basename(dirname(gt_path_file))
        pred_save_dir_task = join(pred_save_dir, task_folder)
        png_save_dir_task = join(png_save_dir, task_folder)
        makedirs(pred_save_dir_task, exist_ok=True)
        makedirs(png_save_dir_task, exist_ok=True)
    else:
        pred_save_dir_task = pred_save_dir
        png_save_dir_task = png_save_dir
    # Skip cases that already have a saved prediction
    if isfile(join(pred_save_dir_task, npz_name)):
        return
    npz = np.load(gt_path_file, 'r', allow_pickle=True)
    img_3c = npz['imgs'] # (H, W, 3)
    gts = npz['gts'] # 2D label map, resized to the image size below if needed
    segs = np.zeros_like(img_3c[..., 0], dtype=np.uint8)

    label_ids = np.unique(gts)[1:]

    for label_id in label_ids:
        gt_2D = np.uint8(gts == label_id) # only one label
        img_4c = preprocess_image(
            img_3c,
            gt_2D,
            image_size=image_size,
            bbox_shift=bbox_shift
        )
        if img_4c is None:
            continue
        img_4c = img_4c.to(device)
        with torch.no_grad():
            seg_logits = model(img_4c)
            # Upsample the logits back to the original image size
            seg_logits = F.interpolate(
                seg_logits,
                size=img_3c.shape[:2],
                mode='bilinear',
                align_corners=False
            )
            seg_probs = torch.sigmoid(seg_logits)
        seg_probs_np = seg_probs.detach().cpu().numpy().squeeze()
        torch.cuda.empty_cache()
        seg_2D = np.uint8(seg_probs_np > 0.5)
        segs[seg_2D > 0] = label_id

    if gts.shape[0] != img_3c.shape[0] or gts.shape[1] != img_3c.shape[1]:
        gts = cv2.resize(
            gts,
            (img_3c.shape[1], img_3c.shape[0]),
            interpolation=cv2.INTER_NEAREST
        )

    np.savez_compressed(
        join(pred_save_dir_task, npz_name),
        segs=segs,
        gts=gts
    )

    if save_overlay:
        # Three panels: image, ground truth overlay, segmentation overlay
        fig, ax = plt.subplots(1, 3, figsize=(15, 5))
        if args.grey:
            ax[0].imshow(img_3c, cmap='gray')
        else:
            ax[0].imshow(img_3c)
        ax[0].set_title("Image")
        ax[0].axis('off')
        if args.grey:
            ax[1].imshow(img_3c, cmap='gray')
        else:
            ax[1].imshow(img_3c)
        show_mask(gts, ax[1])
        ax[1].axis('off')
        #show_box(boxes_np, ax[1])
        ax[1].set_title("Ground Truth")
        if args.grey:
            ax[2].imshow(img_3c, cmap='gray')
        else:
            ax[2].imshow(img_3c)
        show_mask(segs, ax[2])
        ax[2].set_title("Segmentation")
        ax[2].axis('off')
        plt.savefig(
            join(png_save_dir_task, npz_name.split(".")[0] + '.png'),
            dpi=300
        )
        plt.close()

if __name__ == '__main__':
    try:
        mp.set_start_method('spawn', force=True)
        print("spawned")
    except RuntimeError:
        pass
    with mp.Pool(processes=num_workers) as pool:
        # A single progress bar, advanced once per finished file
        with tqdm(total=len(gt_path_files)) as pbar:
            for _ in pool.imap_unordered(deeplabv3plus_infer_npz, gt_path_files):
                pbar.update()