Code for paper: DCOP-Net: A dual-filter cross attention and onion pooling network for few-shot medical image segmentation
Few-shot learning has demonstrated remarkable performance in medical image segmentation. In our manuscript, we propose a dual-filter cross-attention and onion pooling network (DCOP-Net) for few-shot medical image segmentation (FSMIS). Our model consists of two stages: a prototype learning stage and a segmentation stage. In the prototype learning stage, we design a Dual-Filter Cross Attention (DFCA) module and an Onion Pooling (OP) module. In the segmentation stage, we present a Parallel Threshold Perception (PTP) module and a Query Self-Reference Regularization (QSR) strategy. Specifically (illustrative sketches of each component follow the list):
1) The DFCA module uses a prior mask together with an adaptive attention filtering method to suppress background factors in the feature map from two aspects, effectively integrating query foreground features into the support features.
2) The OP module generates multiple masks with the proposed erode pooling and combines them with masked average pooling to extract multiple prototypes, effectively preserving contextual information in the feature map.
3) The PTP module processes the features through parallel maximum-pooling and average-pooling paths. After passing the concatenated results through a fully connected layer, the module obtains robust parameters for thresholding the anomaly score map of the query image.
4) The QSR strategy uses the prediction results and the query image to generate a prototype, then segments the query image again to obtain a new loss. This forms a feedback mechanism for the model and improves the accuracy and consistency of its segmentation.
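As a rough illustration of item 1, the following PyTorch sketch shows one way the two filters could act in a cross-attention block: a prior mask suppresses likely-background query positions, and an adaptive filter keeps only the strongest attention weights. All shapes, the `keep_ratio` parameter, and the residual fusion are assumptions, not the paper's exact module.

```python
import torch
import torch.nn as nn

class DFCASketch(nn.Module):
    """Minimal sketch of dual-filter cross attention (shapes/names assumed)."""
    def __init__(self, dim):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.scale = dim ** -0.5

    def forward(self, supp_feat, query_feat, prior_mask, keep_ratio=0.7):
        # supp_feat, query_feat: (B, N, C); prior_mask: (B, N) in [0, 1],
        # assumed to keep at least one query position per sample.
        q = self.q_proj(supp_feat)                      # support attends to query
        k = self.k_proj(query_feat)
        v = self.v_proj(query_feat)
        attn = (q @ k.transpose(-2, -1)) * self.scale   # (B, N, N)
        # Filter 1: the prior mask removes likely-background query positions.
        attn = attn.masked_fill(prior_mask.unsqueeze(1) < 0.5, float('-inf'))
        attn = attn.softmax(dim=-1)
        # Filter 2: adaptive filtering zeroes all but the strongest weights.
        thresh = torch.quantile(attn, 1.0 - keep_ratio, dim=-1, keepdim=True)
        attn = torch.where(attn >= thresh, attn, torch.zeros_like(attn))
        attn = attn / attn.sum(dim=-1, keepdim=True).clamp_min(1e-6)
        # Fuse query foreground features into the support features.
        return supp_feat + attn @ v
```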
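For item 2, here is a sketch of onion pooling under assumed shapes: erode pooling is approximated by min-pooling (a max-pool of the negated mask), and each "ring" between successive erosions yields one masked-average-pooled prototype. Function and argument names are mine, not the repository's.

```python
import torch
import torch.nn.functional as F

def onion_pooling(feat, mask, n_layers=3):
    """Illustrative onion pooling.
    feat: (B, C, H, W); mask: (B, 1, H, W), binary foreground mask."""
    prototypes = []
    current = mask
    for _ in range(n_layers):
        # Erode pooling: min-pool implemented as -maxpool(-mask).
        eroded = -F.max_pool2d(-current, kernel_size=3, stride=1, padding=1)
        ring = (current - eroded).clamp(min=0)          # shell between erosions
        # Masked average pooling over the ring region.
        area = ring.sum(dim=(2, 3)).clamp_min(1e-6)
        prototypes.append((feat * ring).sum(dim=(2, 3)) / area)   # (B, C)
        current = eroded
    # The innermost region contributes a final prototype as well.
    area = current.sum(dim=(2, 3)).clamp_min(1e-6)
    prototypes.append((feat * current).sum(dim=(2, 3)) / area)
    return torch.stack(prototypes, dim=1)               # (B, n_layers + 1, C)
```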
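For item 3, a sketch of the dual-path threshold head: global max- and average-pooled descriptors are concatenated, a fully connected block predicts a threshold, and the anomaly score map is soft-thresholded. The soft formulation follows ADNet; the steepness factor and layer sizes are assumptions.

```python
import torch
import torch.nn as nn

class PTPSketch(nn.Module):
    """Illustrative parallel threshold perception head (sizes assumed)."""
    def __init__(self, dim, steepness=0.5):
        super().__init__()
        self.steepness = steepness
        # Fully connected block on the concatenated max/avg descriptors.
        self.fc = nn.Sequential(
            nn.Linear(2 * dim, dim), nn.ReLU(inplace=True),
            nn.Linear(dim, 1),
        )

    def forward(self, feat, anomaly_score):
        # feat: (B, C, H, W); anomaly_score: (B, 1, H, W)
        max_desc = feat.amax(dim=(2, 3))    # maximum-pooling path
        avg_desc = feat.mean(dim=(2, 3))    # average-pooling path
        t = self.fc(torch.cat([max_desc, avg_desc], dim=1)).view(-1, 1, 1, 1)
        # Soft thresholding: a high anomaly score maps to background.
        return 1.0 - torch.sigmoid(self.steepness * (anomaly_score - t))
```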
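For item 4, a sketch of the self-reference feedback: a prototype is pooled from the query's own prediction, used to re-segment the query by cosine similarity, and the disagreement between the two predictions gives the extra loss. The sharpening factor and the choice of loss are assumptions.

```python
import torch
import torch.nn.functional as F

def qsr_loss(query_feat, pred_mask):
    """Illustrative query self-reference regularization (shapes assumed).
    query_feat: (B, C, H, W); pred_mask: (B, 1, H, W) soft foreground probs."""
    # Prototype from the query's own predicted foreground.
    area = pred_mask.sum(dim=(2, 3)).clamp_min(1e-6)
    proto = (query_feat * pred_mask).sum(dim=(2, 3)) / area        # (B, C)
    # Re-segment the query with its self-derived prototype.
    sim = F.cosine_similarity(query_feat, proto[:, :, None, None], dim=1)
    self_pred = torch.sigmoid(20.0 * sim).unsqueeze(1)             # (B, 1, H, W)
    # New loss: consistency between re-segmentation and original prediction.
    return F.binary_cross_entropy(
        self_pred.clamp(1e-6, 1 - 1e-6), pred_mask.detach()
    )
```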
Please install the following essential dependencies:
dcm2nii
json5==0.8.5
jupyter==1.0.0
nibabel==2.5.1
numpy==1.22.0
opencv-python==4.5.5.62
Pillow>=8.1.1
sacred==0.8.2
scikit-image==0.18.3
SimpleITK==1.2.3
torch==1.10.2
torchvision==0.11.2
tqdm==4.62.3
IDE: PyCharm 2022.3 Community Edition.
Framework: PyTorch 2.0.1.
Language: Python 3.11.2
CUDA: 12.1
Pre-processing is performed according to Ouyang et al., and we follow the procedure described in their GitHub repository.
The pre-processed data and supervoxels can be downloaded by:
- Pre-processed CHAOS-T2 data and supervoxels
- Pre-processed SABS data and supervoxels
- Pre-processed CMR data and supervoxels
- Compile `./supervoxels/felzenszwalb_3d_cy.pyx` with Cython (`python ./supervoxels/setup.py build_ext --inplace`) and run `./supervoxels/generate_supervoxels.py`.
- Download the pre-trained ResNet-101 weights (vanilla version or DeepLabv3 version), put them in your checkpoints folder, then replace the absolute path in `./models/encoder.py`.
- Run `./script/train.sh` to train.
- Run `./script/test.sh` to test.
The code is based on the following works: RPTNet, SSL-ALPNet, ADNet, and QNet.