This folder contains the training and inference scripts for the DeepLabV3+ model, which segments medical images stored in MedSAM's preprocessed npz format. For details on the data preprocessing pipeline, please refer to MedSAM.
This codebase uses Segmentation Models Pytorch, which can be installed via pip:
pip install segmentation-models-pytorch
To train the DeepLabV3+ model, use the provided train_deeplabv3_res50.py script. To incorporate bounding box prompts into the model, we converted the bounding box into a binary mask and concatenated it with the image as the model input. The bounding box was simulated from the ground truth.
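The box-as-mask input described above can be sketched roughly as follows. This is a minimal illustration, not the code in train_deeplabv3_res50.py: the function names `bbox_from_mask` and `add_box_channel`, the channels-last `(H, W, C)` layout, and the use of a tight, unperturbed box are all assumptions made for the example.

```python
import numpy as np


def bbox_from_mask(gt):
    """Simulate a box prompt as the tight bounding box of a 2D binary mask.

    Returns (x_min, y_min, x_max, y_max) with inclusive pixel coordinates.
    Hypothetical helper; the training script may derive its boxes differently.
    """
    ys, xs = np.nonzero(gt)
    if ys.size == 0:
        raise ValueError("ground-truth mask has no foreground pixels")
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())


def add_box_channel(image, box):
    """Append the box, rendered as a binary mask, as an extra image channel.

    image: array of shape (H, W, C) (channels-last is an assumption here).
    Returns an array of shape (H, W, C + 1).
    """
    h, w = image.shape[:2]
    x_min, y_min, x_max, y_max = box
    box_mask = np.zeros((h, w, 1), dtype=image.dtype)
    # Fill the rectangle, including its max-coordinate edges.
    box_mask[y_min:y_max + 1, x_min:x_max + 1, 0] = 1
    return np.concatenate([image, box_mask], axis=-1)
```

With a 3-channel image this yields a 4-channel input, so a model consuming it would need to accept one more input channel than a plain RGB model. For a PyTorch model the array would typically also be transposed to channels-first before being fed in.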
Below are the required parameters that need to be configured before training:
-i /path/to/input: Path to the input dataset (npy format).
-o /path/to/output: Path to save the trained model.
Example command for training:
python train_deeplabv3_res50.py \
-i /path/to/input \
-o /path/to/output \
-b BATCH_SIZE \
--num_workers 4 \
--max_epochs 500 \
--compile
Here, -b sets the batch size (replace BATCH_SIZE with an integer), --num_workers sets the number of workers for data loading, --max_epochs sets the maximum number of epochs to train, and --compile enables model compilation for acceleration.
The inference scripts assume that the data is in the npz format generated by the MedSAM preprocessing pipeline. To run inference, one can download the model here and use the provided inference scripts.
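Before running inference, it can help to confirm what a given npz file actually contains. The sketch below only lists the stored array names with their shapes and dtypes; the helper name `describe_npz` and the array names `imgs` and `gts` used in the demo are illustrative assumptions, not a statement of the format the inference scripts require.

```python
import os
import tempfile

import numpy as np


def describe_npz(path):
    """Map each array name stored in an .npz file to its (shape, dtype)."""
    with np.load(path) as data:
        return {
            name: (data[name].shape, str(data[name].dtype))
            for name in data.files
        }


if __name__ == "__main__":
    # Demo on a throwaway file; the array names here are hypothetical.
    with tempfile.TemporaryDirectory() as tmp_dir:
        demo_path = os.path.join(tmp_dir, "case.npz")
        np.savez_compressed(
            demo_path,
            imgs=np.zeros((4, 4, 3), dtype=np.uint8),
            gts=np.zeros((4, 4), dtype=np.uint8),
        )
        for name, (shape, dtype) in describe_npz(demo_path).items():
            print(name, shape, dtype)
```

Pointing `describe_npz` at one of the preprocessed files shows whether it holds the arrays and dimensionality (2D versus 3D) expected by the script about to be run.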
To perform inference on 2D images, use the infer_deeplabv3_res50_2D.py script. Below are the parameters that need to be configured:
-checkpoint: Path to the trained model checkpoint.
-data_root: Path to the input images.
-pred_save_dir: Path to save the output segmented images.
--save_overlay: Save the overlay of the segmentation on the original image. (Optional)
-png_save_dir: Path to save the overlay images. (Required if --save_overlay is used)
-num_workers: Number of workers for multiprocessing during inference.
--grey: Save the overlay images in greyscale. (Optional)
python infer_deeplabv3_res50_2D.py \
-checkpoint path/to/checkpoint/deeplabv3plus_best.pt \
-data_root /path/to/input \
-pred_save_dir /path/to/output \
--save_overlay \
-png_save_dir /path/to/saved/overlay \
-num_workers 2 \
--grey
To perform inference on 3D medical images, such as CT or MR scans, use the infer_deeplabv3_res50_3D.py script. Below are the parameters that can be configured:
-checkpoint: Path to the trained model checkpoint.
-data_root: Path to the input 3D images.
-pred_save_dir: Path to save the output segmented 3D images.
-png_save_dir: Path to save the overlay images. (Optional)
-num_workers: Number of workers for multiprocessing during inference.
python infer_deeplabv3_res50_3D.py \
-checkpoint /path/to/checkpoint/deeplabv3plus_best.pt \
-data_root /path/to/input \
-pred_save_dir /path/to/output \
-png_save_dir /path/to/saved/overlay \
-num_workers 2
This codebase uses the Segmentation Models Pytorch repository. We would like to thank the authors and contributors for their great work and for making the code publicly available.