This repository contains several multi-task learning [2] extensions of a U-Net-style model, QuickNAT [1], that aim to improve segmentation results on a small ultrasound nerve dataset. Our work was guided by TUM's Chair for Computer Aided Medical Procedures. We implemented hard parameter sharing with an FCN classifier at the U-Net bottleneck, soft parameter sharing with cross-stitch networks [3], and a ResNet-18 benchmark classifier. The classifiers are trained with cross-entropy loss and the segmenters with Dice loss. We experimented with several ways of combining the classification and segmentation losses: fixed linear weighting, uncertainty weighting [4] and loss scheduling [5].
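The loss terms named above can be sketched in plain Python. This is a minimal, framework-free illustration under our own assumptions, not the repository's PyTorch code: `dice_loss` uses the common soft Dice formulation on a flattened probability mask, `cross_entropy` works on raw logits for a single sample, `w_cls` is a hypothetical fixed weight for linear weighting, and uncertainty weighting [4] is written in the widely used simplified form `exp(-s_i) * L_i + s_i`, with one learnable `s_i = log(sigma_i^2)` per task.

```python
import math


def dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss for one binary mask.

    pred:   flat list of predicted foreground probabilities in [0, 1]
    target: flat list of ground-truth labels (0 or 1)
    eps:    small constant that avoids division by zero on empty masks
    """
    intersection = sum(p * t for p, t in zip(pred, target))
    denominator = sum(pred) + sum(target)
    return 1.0 - (2.0 * intersection + eps) / (denominator + eps)


def cross_entropy(logits, label):
    """Cross-entropy for one sample, given raw class scores and the true class index."""
    m = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def linear_weighting(loss_cls, loss_seg, w_cls=0.5):
    """Fixed convex combination of the classification and segmentation losses."""
    return w_cls * loss_cls + (1.0 - w_cls) * loss_seg


def uncertainty_weighting(losses, log_vars):
    """Simplified uncertainty weighting: sum_i exp(-s_i) * L_i + s_i.

    log_vars holds one s_i = log(sigma_i^2) per task; a large s_i down-weights
    that task's loss, while the + s_i term keeps s_i from growing without bound.
    """
    return sum(math.exp(-s) * loss + s for loss, s in zip(losses, log_vars))


if __name__ == "__main__":
    loss_seg = dice_loss([0.9, 0.8, 0.1, 0.0], [1, 1, 0, 0])
    loss_cls = cross_entropy([2.0, 0.5], label=0)
    print(linear_weighting(loss_cls, loss_seg, w_cls=0.3))
    print(uncertainty_weighting([loss_cls, loss_seg], log_vars=[0.0, 0.0]))
```

In a PyTorch training loop, the `log_vars` would typically be `torch.nn.Parameter` values optimized together with the network weights, so the model learns how strongly to weight each task instead of relying on a hand-tuned `w_cls`.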
model/quicknat.py
- Vanilla QuickNAT architecture [1], which is very similar to U-Net but adapted for fast brain MRI segmentation. This network serves as the benchmark for segmentation results.
model/resultnet.py
- PyTorch ResNet-18 adapted for nerve image classification. Like QuickNAT for segmentation, it serves as the benchmark for classification results. Note that our multi-task networks classify with an encoder followed by an FCN classifier, so we do not expect their accuracy to match that of ResNet-18.
model/quickfcn.py
- Hard parameter sharing model [2]: QuickNAT nerve segmentation network with a fully connected layer for nerve classification attached to the bottleneck.
model/softquickfcn.py
- Soft parameter sharing model [2]: QuickNAT nerve segmentation network and a separate nerve classification network with an identical encoder. Both networks are first pretrained independently on their own task. Their encoders are then joined with cross-stitch networks [3] for a second round of training.
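The classification branch of the hard parameter sharing model (`model/quickfcn.py`) can be pictured with the plain-Python sketch below. It is not the repository's implementation: `encoder` and `decoder` are hypothetical callables standing in for the QuickNAT encoder and decoder, and pooling the bottleneck with global average pooling before the fully connected layer is our assumption (the layer might instead act on the flattened bottleneck).

```python
def global_avg_pool(feature_map):
    """Average each channel of a (channels x height x width) nested list to one value."""
    return [
        sum(sum(row) for row in channel) / (len(channel) * len(channel[0]))
        for channel in feature_map
    ]


def fc_classifier(features, weights, biases):
    """Fully connected layer: one weight row and one bias per class, returns logits."""
    return [sum(w * x for w, x in zip(row, features)) + b
            for row, b in zip(weights, biases)]


def hard_sharing_forward(image, encoder, decoder, cls_weights, cls_biases):
    """Hard parameter sharing: a single encoder pass feeds both task heads.

    encoder(image) -> (bottleneck, skips) and decoder(bottleneck, skips) -> mask
    are hypothetical stand-ins for the QuickNAT encoder and decoder.
    """
    bottleneck, skips = encoder(image)
    seg_mask = decoder(bottleneck, skips)                    # segmentation head
    pooled = global_avg_pool(bottleneck)                     # assumed pooling step
    logits = fc_classifier(pooled, cls_weights, cls_biases)  # classification head
    return seg_mask, logits


if __name__ == "__main__":
    def toy_encoder(img):
        # A 2-channel, 2x2 bottleneck and no skip connections.
        return [[[1.0, 2.0], [3.0, 4.0]], [[0.0, 0.0], [0.0, 0.0]]], []

    def toy_decoder(bottleneck, skips):
        return bottleneck[0]  # placeholder "mask"

    mask, logits = hard_sharing_forward(None, toy_encoder, toy_decoder,
                                        cls_weights=[[1.0, 0.0], [0.0, 1.0]],
                                        cls_biases=[0.0, 0.0])
    print(logits)  # [2.5, 0.0]
```

Because both heads read the same `bottleneck`, every gradient from the classification loss also reshapes the features the decoder depends on.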
We compare two multi-task learning extensions of QuickNAT. Hard parameter sharing uses the same encoder and bottleneck weights for both classification and segmentation, while soft parameter sharing keeps two independent encoders and learns cross-stitch weights that control how much the encoders share.
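A cross-stitch unit [3] replaces each pair of corresponding activations with a learned linear combination of both. The sketch below assumes flattened activation lists of equal length and a single 2x2 mixing matrix `alpha` per unit; the function name and this layout are illustrative and not taken from `model/softquickfcn.py`.

```python
def cross_stitch(x_a, x_b, alpha):
    """Cross-stitch unit: mix same-shaped activations of two networks.

    x_a, x_b: flattened activations from corresponding layers of network A
              (e.g. the segmentation encoder) and network B (the classifier encoder)
    alpha:    learned 2x2 mixing matrix [[a_aa, a_ab], [a_ba, a_bb]]
    Returns the new activations passed on to the next layer of each network.
    """
    (a_aa, a_ab), (a_ba, a_bb) = alpha
    out_a = [a_aa * xa + a_ab * xb for xa, xb in zip(x_a, x_b)]
    out_b = [a_ba * xa + a_bb * xb for xa, xb in zip(x_a, x_b)]
    return out_a, out_b


if __name__ == "__main__":
    seg_act = [1.0, 2.0]
    cls_act = [3.0, 4.0]
    # Identity matrix: no sharing, the two encoders behave independently.
    print(cross_stitch(seg_act, cls_act, [[1.0, 0.0], [0.0, 1.0]]))
    # Near-identity matrix: each encoder borrows a little from the other.
    print(cross_stitch(seg_act, cls_act, [[0.9, 0.1], [0.1, 0.9]]))
```

With `alpha` equal to the identity the encoders stay fully independent; training the off-diagonal entries lets the network decide how much to share. A near-identity initialization such as `[[0.9, 0.1], [0.1, 0.9]]` keeps the pretrained encoders close to their original behavior at the start of the second training round.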
Below, both networks segment nerves well when the nerve class is predicted correctly (ground truth in red, prediction in blue). However, when a nerve is misclassified, hard parameter sharing fails to produce an accurate segmentation. We hypothesize that this is because segmentation and classification are quite distinct tasks: since both heads rely on the same encoder, the more the shared features specialize for classification, the less useful they become for segmentation. In contrast, soft parameter sharing segments nerves correctly even when they are misclassified. This is expected, because the classification and segmentation networks keep their own encoder weights and remain free to learn separately.
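[Figure: example segmentations from the hard and soft parameter sharing models; ground truth in red, prediction in blue]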
- PyTorch - Python deep learning library
- TensorBoard - Visualization of losses, metrics, segmentation results and confusion matrices.
- Polyaxon - GPU cluster scheduling
- [1] Abhijit Guha Roy, Sailesh Conjeti, Nassir Navab, Christian Wachinger (2018). QuickNAT: Segmenting MRI Neuroanatomy in 20 Seconds. CoRR.
- [2] Sebastian Ruder (2017). An Overview of Multi-Task Learning in Deep Neural Networks. CoRR.
- [3] Ishan Misra, Abhinav Shrivastava, Abhinav Gupta, Martial Hebert (2016). Cross-stitch Networks for Multi-task Learning. CVPR.
- [4] Alex Kendall, Yarin Gal, Roberto Cipolla (2017). Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics. CoRR.
- [5] Sailesh Conjeti, Magdalini Paschali, Amin Katouzian, Nassir Navab (2017). Learning Robust Hash Codes for Multiple Instance Image Retrieval. MICCAI.