DeepLabV3+ Training and Inference Scripts

This folder contains the training and inference scripts for the DeepLabV3+ model, used for segmentation of medical image data in MedSAM's preprocessed npz format. For details regarding the data preprocessing pipeline, please refer to the MedSAM repository.

Prerequisites

This codebase uses the Segmentation Models Pytorch library, which can be installed via pip:

pip install segmentation-models-pytorch

Training

To train the DeepLabV3+ model, one can use the provided train_deeplabv3_res50.py script. To incorporate bounding box prompts into the model, we converted each bounding box into a binary mask and concatenated it with the image as the model input. The bounding boxes were simulated from the ground truth masks. Below are the required parameters that need to be configured before training:

  • -i /path/to/input: Path to the input dataset (npy format).
  • -o /path/to/output: Path to save the trained model.
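The bounding box prompt encoding described above can be sketched as follows. This is a minimal NumPy illustration, not the code in train_deeplabv3_res50.py: the helper names `box_from_mask` and `add_box_channel`, the random enlargement controlled by `max_shift`, and the channel-last layout are all assumptions made for the example. It also assumes the ground truth mask contains at least one foreground pixel.

```python
import numpy as np


def box_from_mask(gt, max_shift=5, rng=None):
    """Simulate a box prompt from a binary ground-truth mask (hypothetical helper).

    Returns (x_min, y_min, x_max, y_max), inclusive. Each side of the tight box
    is pushed outward by a random 0..max_shift pixels (an assumed perturbation)
    and clipped to the image bounds.
    """
    rng = rng or np.random.default_rng()
    ys, xs = np.nonzero(gt)  # coordinates of foreground pixels
    h, w = gt.shape
    shifts = rng.integers(0, max_shift + 1, size=4)
    x_min = max(0, int(xs.min() - shifts[0]))
    y_min = max(0, int(ys.min() - shifts[1]))
    x_max = min(w - 1, int(xs.max() + shifts[2]))
    y_max = min(h - 1, int(ys.max() + shifts[3]))
    return x_min, y_min, x_max, y_max


def add_box_channel(image, box):
    """Append the box as a binary mask channel to an (H, W, C) image."""
    h, w = image.shape[:2]
    x_min, y_min, x_max, y_max = box
    box_mask = np.zeros((h, w, 1), dtype=image.dtype)
    box_mask[y_min:y_max + 1, x_min:x_max + 1] = 1  # fill the box region
    return np.concatenate([image, box_mask], axis=-1)
```

With a three-channel image, the result has four channels, so the network's first layer would need to accept one extra input channel.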

Example command for training:

# -b: batch size (replace BATCH_SIZE with an integer)
# --num_workers: number of workers for data loading
# --max_epochs: maximum number of epochs to train
# --compile: compile the model for acceleration
python train_deeplabv3_res50.py \
    -i /path/to/input \
    -o /path/to/output \
    -b BATCH_SIZE \
    --num_workers 4 \
    --max_epochs 500 \
    --compile

Inference

The inference scripts assume that the data is in the npz format generated by the MedSAM preprocessing pipeline. To run inference, one can download the model here and use the provided inference scripts.
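Before running inference, it can help to check what a given npz file actually contains. The sketch below lists the arrays stored in a file using NumPy only. The helper name `describe_npz` is hypothetical, and the array names expected by the inference scripts are not stated here, so the output should be compared against the MedSAM preprocessing documentation.

```python
import numpy as np


def describe_npz(path):
    """Return {array_name: (shape, dtype)} for every array stored in an .npz file."""
    with np.load(path) as data:
        summary = {}
        for name in data.files:
            arr = data[name]  # arrays are loaded lazily, one at a time
            summary[name] = (arr.shape, str(arr.dtype))
        return summary
```

For example, `print(describe_npz("/path/to/input/case.npz"))` prints the name, shape, and dtype of each stored array; the file path here is a placeholder.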

Inference for 2D images

To perform inference on 2D images, one can use the infer_deeplabv3_res50_2D.py script. Below are the parameters that need to be configured:

  • -checkpoint: Path to the trained model checkpoint.
  • -data_root: Path to the input images.
  • -pred_save_dir: Path to save the output segmented images.
  • --save_overlay: Save the overlay of the segmentation on the original image. (Optional)
  • -png_save_dir: Path to save the overlay images. (Required if --save_overlay is used)
  • -num_workers: Number of workers for multiprocessing during inference.
  • --grey: Save the overlay images in greyscale. (Optional)

python infer_deeplabv3_res50_2D.py \
    -checkpoint /path/to/checkpoint/deeplabv3plus_best.pt \
    -data_root /path/to/input \
    -pred_save_dir /path/to/output \
    --save_overlay \
    -png_save_dir /path/to/saved/overlay \
    -num_workers 2 \
    --grey
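For readers unfamiliar with segmentation overlays, the idea behind --save_overlay can be illustrated with simple alpha blending. This is a sketch under assumptions, not the rendering used by infer_deeplabv3_res50_2D.py: the function name `overlay_mask`, the red default color, and the blending weight `alpha` are all chosen for illustration.

```python
import numpy as np


def overlay_mask(image, mask, color=(255, 0, 0), alpha=0.5):
    """Blend a binary mask onto a uint8 image and return an (H, W, 3) uint8 array.

    image: (H, W) greyscale or (H, W, 3) RGB, dtype uint8.
    mask:  (H, W), nonzero where the segmentation is foreground.
    """
    if image.ndim == 2:
        # Replicate the grey channel so a colored overlay can be drawn.
        image = np.stack([image] * 3, axis=-1)
    out = image.astype(np.float32)
    fg = mask.astype(bool)
    # Mix the overlay color into foreground pixels only.
    out[fg] = (1 - alpha) * out[fg] + alpha * np.asarray(color, dtype=np.float32)
    return out.round().astype(np.uint8)
```

Pixels outside the mask are returned unchanged, so the original image remains visible around the segmented region.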

Inference for 3D images

To perform inference on 3D medical images, such as those of CT or MR modality, the infer_deeplabv3_res50_3D.py script can be used. Below are the parameters that one can configure:

  • -checkpoint: Path to the trained model checkpoint.
  • -data_root: Path to the input 3D images.
  • -pred_save_dir: Path to save the output segmented 3D images.
  • -png_save_dir: Path to save the overlay images. (Optional)
  • -num_workers: Number of workers for multiprocessing during inference.

python infer_deeplabv3_res50_3D.py \
    -checkpoint /path/to/checkpoint/deeplabv3plus_best.pt \
    -data_root /path/to/input \
    -pred_save_dir /path/to/output \
    -png_save_dir /path/to/saved/overlay \
    -num_workers 2

Acknowledgement

This codebase uses the Segmentation Models Pytorch repository. We would like to thank the authors and contributors for their great work and for making the code publicly available.