Commit

Working Official Implementation
lalonderodney committed May 8, 2018
1 parent 14ff393 commit 0d35294
Showing 21 changed files with 4,484 additions and 2 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -0,0 +1,5 @@
**/.idea/
**.pyc
**/saved_models/
**/*.py~

69 changes: 67 additions & 2 deletions README.md
@@ -1,2 +1,67 @@
# SegCaps
Official Implementation of the Paper "Capsules for Object Segmentation".
# Capsules for Object Segmentation (SegCaps)
### by [Rodney LaLonde](https://rodneylalonde.wixsite.com/personal) and [Ulas Bagci](http://www.cs.ucf.edu/~bagci/)

## This repo is the official implementation of SegCaps

The original paper for SegCaps can be found at https://arxiv.org/abs/1804.04241.

A project page for this work can be found at https://rodneylalonde.wixsite.com/personal/research-blog/capsules-for-object-segmentation.

<img src="imgs/qualitative1.png" width="900px"/>

## Condensed Abstract
Convolutional neural networks (CNNs) have shown remarkable results over the last several years for a wide range of computer vision tasks. A new architecture recently introduced by [Sabour et al., referred to as capsule networks with dynamic routing](https://arxiv.org/abs/1710.09829), has shown great initial results for digit recognition and small image classification. Our work expands the use of capsule networks to the task of object segmentation for the first time in the literature. We extend the idea of convolutional capsules with *locally-connected routing* and propose the concept of *deconvolutional capsules*. Further, we extend the masked reconstruction to reconstruct the positive input class. The proposed convolutional-deconvolutional capsule network, called **SegCaps**, shows strong results for the task of object segmentation with a substantial decrease in parameter space. As an example application, we applied the proposed SegCaps to segment pathological lungs from low-dose CT scans and compared its accuracy and efficiency with other U-Net-based architectures. SegCaps is able to handle large image sizes (512 x 512) as opposed to baseline capsules (typically less than 32 x 32). The proposed SegCaps reduced the number of parameters of the U-Net architecture by **95.4%** while still providing a better segmentation accuracy.

## Baseline Capsule Network for Object Segmentation

<img src="imgs/baselinecaps.png" width="900px"/>

## SegCaps (R3) Network Overview

<img src="imgs/segcaps.png" width="900px"/>

## Quantitative Results on the LUNA16 Dataset

| Method | Parameters | Split-0 (%) | Split-1 (%) | Split-2 (%) | Split-3 (%) | Average (%) |
|:---------------- |:----------:|:-----------:|:-----------:|:-----------:|:-----------:|:-----------:|
| U-Net | 31.0 M | 98.353 | 98.432 | 98.476 | **98.510** | 98.449 |
| Tiramisu | 2.3 M | 98.394 | 98.358 | **98.543** | 98.339 | 98.410 |
| Baseline Caps | 1.7 M | 82.287 | 79.939 | 95.121 | 83.608 | 83.424 |
| SegCaps (R1) | **1.4 M** | 98.471 | 98.444 | 98.401 | 98.362 | 98.419 |
| **SegCaps (R3)** | **1.4 M** | **98.499** | **98.523** | 98.455 | 98.474 | **98.479** |

## Results of Manipulating the Segmentation Capsule Vectors

<img src="imgs/manip_cropped.png" width="900px"/>

## Getting Started Guide

### Install Required Packages
This repository is written for Keras using the TensorFlow backend. Please install all required packages before using this code.
```bash
pip install -r requirements.txt
```

### Dataset Structure

Inside the data root folder (*i.e.* where you have your data stored) you should have two folders: one called *imgs* and one called *masks*. All models, results, etc. are saved to this same root directory.
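
For example (a minimal sketch; the folder names *imgs* and *masks* are required, while the root itself can live anywhere), the layout would look like:
```
<data_root_dir>/
    imgs/     # input images
    masks/    # corresponding ground-truth segmentation masks
```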

### Main File

From the main file (main.py) you can train, test, and manipulate the segmentation capsules of the various networks. Simply set the ```--train```, ```--test```, or ```--manip``` flags to 0 or 1 to turn them off or on, respectively. The argument ```--data_root_dir``` is the only required argument and should be set to the directory containing your *imgs* and *masks* folders. Many more arguments can be set, and these are all explained in main.py. An example invocation is shown below.
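
For example, assuming your data root is ```./data```, a combined training and testing run might look like this (illustrative only; see main.py for the authoritative argument list):
```bash
python main.py --data_root_dir=data --train=1 --test=1 --manip=0
```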

### Citation

If you use significant portions of our code or ideas from our paper in your research, please cite our work:
```
@article{lalonde2018capsules,
  title={Capsules for Object Segmentation},
  author={LaLonde, Rodney and Bagci, Ulas},
  journal={arXiv preprint arXiv:1804.04241},
  year={2018}
}
```

### Questions or Comments

Please direct any questions or comments to me; I am happy to help in any way I can. You can either comment on the [project page](https://rodneylalonde.wixsite.com/personal/research-blog/capsules-for-object-segmentation), or email me directly at [email protected].
280 changes: 280 additions & 0 deletions capsnet.py
@@ -0,0 +1,280 @@
'''
Capsules for Object Segmentation (SegCaps)
Original Paper by Rodney LaLonde and Ulas Bagci (https://arxiv.org/abs/1804.04241)
Code written by: Rodney LaLonde
If you use significant portions of this code or the ideas from our paper, please cite it :)
If you have any questions, please email me at [email protected].
This file contains the network definitions for the various capsule network architectures.
'''

from keras import layers, models
from keras import backend as K
K.set_image_data_format('channels_last')

from capsule_layers import ConvCapsuleLayer, DeconvCapsuleLayer, Mask, Length

def CapsNetR3(input_shape, n_class=2):
    x = layers.Input(shape=input_shape)

    # Layer 1: Just a conventional Conv2D layer
    conv1 = layers.Conv2D(filters=16, kernel_size=5, strides=1, padding='same', activation='relu', name='conv1')(x)

    # Reshape layer to be 1 capsule x [filters] atoms
    _, H, W, C = conv1.get_shape()
    conv1_reshaped = layers.Reshape((H.value, W.value, 1, C.value))(conv1)

    # Layer 1: Primary Capsule: Conv cap with routing 1
    primary_caps = ConvCapsuleLayer(kernel_size=5, num_capsule=2, num_atoms=16, strides=2, padding='same',
                                    routings=1, name='primarycaps')(conv1_reshaped)

    # Layer 2: Convolutional Capsule
    conv_cap_2_1 = ConvCapsuleLayer(kernel_size=5, num_capsule=4, num_atoms=16, strides=1, padding='same',
                                    routings=3, name='conv_cap_2_1')(primary_caps)

    # Layer 2: Convolutional Capsule
    conv_cap_2_2 = ConvCapsuleLayer(kernel_size=5, num_capsule=4, num_atoms=32, strides=2, padding='same',
                                    routings=3, name='conv_cap_2_2')(conv_cap_2_1)

    # Layer 3: Convolutional Capsule
    conv_cap_3_1 = ConvCapsuleLayer(kernel_size=5, num_capsule=8, num_atoms=32, strides=1, padding='same',
                                    routings=3, name='conv_cap_3_1')(conv_cap_2_2)

    # Layer 3: Convolutional Capsule
    conv_cap_3_2 = ConvCapsuleLayer(kernel_size=5, num_capsule=8, num_atoms=64, strides=2, padding='same',
                                    routings=3, name='conv_cap_3_2')(conv_cap_3_1)

    # Layer 4: Convolutional Capsule
    conv_cap_4_1 = ConvCapsuleLayer(kernel_size=5, num_capsule=8, num_atoms=32, strides=1, padding='same',
                                    routings=3, name='conv_cap_4_1')(conv_cap_3_2)

    # Layer 1 Up: Deconvolutional Capsule
    deconv_cap_1_1 = DeconvCapsuleLayer(kernel_size=4, num_capsule=8, num_atoms=32, upsamp_type='deconv',
                                        scaling=2, padding='same', routings=3,
                                        name='deconv_cap_1_1')(conv_cap_4_1)

    # Skip connection
    up_1 = layers.Concatenate(axis=-2, name='up_1')([deconv_cap_1_1, conv_cap_3_1])

    # Layer 1 Up: Deconvolutional Capsule
    deconv_cap_1_2 = ConvCapsuleLayer(kernel_size=5, num_capsule=4, num_atoms=32, strides=1,
                                      padding='same', routings=3, name='deconv_cap_1_2')(up_1)

    # Layer 2 Up: Deconvolutional Capsule
    deconv_cap_2_1 = DeconvCapsuleLayer(kernel_size=4, num_capsule=4, num_atoms=16, upsamp_type='deconv',
                                        scaling=2, padding='same', routings=3,
                                        name='deconv_cap_2_1')(deconv_cap_1_2)

    # Skip connection
    up_2 = layers.Concatenate(axis=-2, name='up_2')([deconv_cap_2_1, conv_cap_2_1])

    # Layer 2 Up: Deconvolutional Capsule
    deconv_cap_2_2 = ConvCapsuleLayer(kernel_size=5, num_capsule=4, num_atoms=16, strides=1,
                                      padding='same', routings=3, name='deconv_cap_2_2')(up_2)

    # Layer 3 Up: Deconvolutional Capsule
    deconv_cap_3_1 = DeconvCapsuleLayer(kernel_size=4, num_capsule=2, num_atoms=16, upsamp_type='deconv',
                                        scaling=2, padding='same', routings=3,
                                        name='deconv_cap_3_1')(deconv_cap_2_2)

    # Skip connection
    up_3 = layers.Concatenate(axis=-2, name='up_3')([deconv_cap_3_1, conv1_reshaped])

    # Layer 4: Convolutional Capsule: 1x1
    seg_caps = ConvCapsuleLayer(kernel_size=1, num_capsule=1, num_atoms=16, strides=1, padding='same',
                                routings=3, name='seg_caps')(up_3)

    # Layer 4: This is an auxiliary layer to replace each capsule with its length. Just to match the true label's shape.
    out_seg = Length(num_classes=n_class, seg=True, name='out_seg')(seg_caps)

    # Decoder network.
    _, H, W, C, A = seg_caps.get_shape()
    y = layers.Input(shape=input_shape[:-1] + (1,))
    masked_by_y = Mask()([seg_caps, y])  # The true label is used to mask the output of capsule layer. For training
    masked = Mask()(seg_caps)  # Mask using the capsule with maximal length. For prediction

    def shared_decoder(mask_layer):
        recon_remove_dim = layers.Reshape((H.value, W.value, A.value))(mask_layer)

        recon_1 = layers.Conv2D(filters=64, kernel_size=1, padding='same', kernel_initializer='he_normal',
                                activation='relu', name='recon_1')(recon_remove_dim)

        recon_2 = layers.Conv2D(filters=128, kernel_size=1, padding='same', kernel_initializer='he_normal',
                                activation='relu', name='recon_2')(recon_1)

        out_recon = layers.Conv2D(filters=1, kernel_size=1, padding='same', kernel_initializer='he_normal',
                                  activation='sigmoid', name='out_recon')(recon_2)

        return out_recon

    # Models for training and evaluation (prediction)
    train_model = models.Model(inputs=[x, y], outputs=[out_seg, shared_decoder(masked_by_y)])
    eval_model = models.Model(inputs=x, outputs=[out_seg, shared_decoder(masked)])

    # manipulate model
    noise = layers.Input(shape=(H.value, W.value, C.value, A.value))
    noised_seg_caps = layers.Add()([seg_caps, noise])
    masked_noised_y = Mask()([noised_seg_caps, y])
    manipulate_model = models.Model(inputs=[x, y, noise], outputs=shared_decoder(masked_noised_y))

    return train_model, eval_model, manipulate_model


def CapsNetR1(input_shape, n_class=2):
    x = layers.Input(shape=input_shape)

    # Layer 1: Just a conventional Conv2D layer
    conv1 = layers.Conv2D(filters=16, kernel_size=5, strides=1, padding='same', activation='relu', name='conv1')(x)

    # Reshape layer to be 1 capsule x [filters] atoms
    _, H, W, C = conv1.get_shape()
    conv1_reshaped = layers.Reshape((H.value, W.value, 1, C.value))(conv1)

    # Layer 1: Primary Capsule: Conv cap with routing 1
    primary_caps = ConvCapsuleLayer(kernel_size=5, num_capsule=2, num_atoms=16, strides=2, padding='same',
                                    routings=1, name='primarycaps')(conv1_reshaped)

    # Layer 2: Convolutional Capsule
    conv_cap_2_1 = ConvCapsuleLayer(kernel_size=5, num_capsule=4, num_atoms=16, strides=1, padding='same',
                                    routings=1, name='conv_cap_2_1')(primary_caps)

    # Layer 2: Convolutional Capsule
    conv_cap_2_2 = ConvCapsuleLayer(kernel_size=5, num_capsule=4, num_atoms=32, strides=2, padding='same',
                                    routings=3, name='conv_cap_2_2')(conv_cap_2_1)

    # Layer 3: Convolutional Capsule
    conv_cap_3_1 = ConvCapsuleLayer(kernel_size=5, num_capsule=8, num_atoms=32, strides=1, padding='same',
                                    routings=1, name='conv_cap_3_1')(conv_cap_2_2)

    # Layer 3: Convolutional Capsule
    conv_cap_3_2 = ConvCapsuleLayer(kernel_size=5, num_capsule=8, num_atoms=64, strides=2, padding='same',
                                    routings=3, name='conv_cap_3_2')(conv_cap_3_1)

    # Layer 4: Convolutional Capsule
    conv_cap_4_1 = ConvCapsuleLayer(kernel_size=5, num_capsule=8, num_atoms=32, strides=1, padding='same',
                                    routings=1, name='conv_cap_4_1')(conv_cap_3_2)

    # Layer 1 Up: Deconvolutional Capsule
    deconv_cap_1_1 = DeconvCapsuleLayer(kernel_size=4, num_capsule=8, num_atoms=32, upsamp_type='deconv',
                                        scaling=2, padding='same', routings=3,
                                        name='deconv_cap_1_1')(conv_cap_4_1)

    # Skip connection
    up_1 = layers.Concatenate(axis=-2, name='up_1')([deconv_cap_1_1, conv_cap_3_1])

    # Layer 1 Up: Deconvolutional Capsule
    deconv_cap_1_2 = ConvCapsuleLayer(kernel_size=5, num_capsule=4, num_atoms=32, strides=1,
                                      padding='same', routings=1, name='deconv_cap_1_2')(up_1)

    # Layer 2 Up: Deconvolutional Capsule
    deconv_cap_2_1 = DeconvCapsuleLayer(kernel_size=4, num_capsule=4, num_atoms=16, upsamp_type='deconv',
                                        scaling=2, padding='same', routings=3,
                                        name='deconv_cap_2_1')(deconv_cap_1_2)

    # Skip connection
    up_2 = layers.Concatenate(axis=-2, name='up_2')([deconv_cap_2_1, conv_cap_2_1])

    # Layer 2 Up: Deconvolutional Capsule
    deconv_cap_2_2 = ConvCapsuleLayer(kernel_size=5, num_capsule=4, num_atoms=16, strides=1,
                                      padding='same', routings=1, name='deconv_cap_2_2')(up_2)

    # Layer 3 Up: Deconvolutional Capsule
    deconv_cap_3_1 = DeconvCapsuleLayer(kernel_size=4, num_capsule=2, num_atoms=16, upsamp_type='deconv',
                                        scaling=2, padding='same', routings=3,
                                        name='deconv_cap_3_1')(deconv_cap_2_2)

    # Skip connection
    up_3 = layers.Concatenate(axis=-2, name='up_3')([deconv_cap_3_1, conv1_reshaped])

    # Layer 4: Convolutional Capsule: 1x1
    seg_caps = ConvCapsuleLayer(kernel_size=1, num_capsule=1, num_atoms=16, strides=1, padding='same',
                                routings=1, name='seg_caps')(up_3)

    # Layer 4: This is an auxiliary layer to replace each capsule with its length. Just to match the true label's shape.
    out_seg = Length(num_classes=n_class, seg=True, name='out_seg')(seg_caps)

    # Decoder network.
    _, H, W, C, A = seg_caps.get_shape()
    y = layers.Input(shape=input_shape[:-1] + (1,))
    masked_by_y = Mask()([seg_caps, y])  # The true label is used to mask the output of capsule layer. For training
    masked = Mask()(seg_caps)  # Mask using the capsule with maximal length. For prediction

    def shared_decoder(mask_layer):
        recon_remove_dim = layers.Reshape((H.value, W.value, A.value))(mask_layer)

        recon_1 = layers.Conv2D(filters=64, kernel_size=1, padding='same', kernel_initializer='he_normal',
                                activation='relu', name='recon_1')(recon_remove_dim)

        recon_2 = layers.Conv2D(filters=128, kernel_size=1, padding='same', kernel_initializer='he_normal',
                                activation='relu', name='recon_2')(recon_1)

        out_recon = layers.Conv2D(filters=1, kernel_size=1, padding='same', kernel_initializer='he_normal',
                                  activation='sigmoid', name='out_recon')(recon_2)

        return out_recon

    # Models for training and evaluation (prediction)
    train_model = models.Model(inputs=[x, y], outputs=[out_seg, shared_decoder(masked_by_y)])
    eval_model = models.Model(inputs=x, outputs=[out_seg, shared_decoder(masked)])

    # manipulate model
    noise = layers.Input(shape=(H.value, W.value, C.value, A.value))
    noised_seg_caps = layers.Add()([seg_caps, noise])
    masked_noised_y = Mask()([noised_seg_caps, y])
    manipulate_model = models.Model(inputs=[x, y, noise], outputs=shared_decoder(masked_noised_y))

    return train_model, eval_model, manipulate_model


def CapsNetBasic(input_shape, n_class=2):
    x = layers.Input(shape=input_shape)

    # Layer 1: Just a conventional Conv2D layer
    conv1 = layers.Conv2D(filters=256, kernel_size=5, strides=1, padding='same', activation='relu', name='conv1')(x)

    # Reshape layer to be 1 capsule x [filters] atoms
    _, H, W, C = conv1.get_shape()
    conv1_reshaped = layers.Reshape((H.value, W.value, 1, C.value))(conv1)

    # Layer 1: Primary Capsule: Conv cap with routing 1
    primary_caps = ConvCapsuleLayer(kernel_size=5, num_capsule=8, num_atoms=32, strides=1, padding='same',
                                    routings=1, name='primarycaps')(conv1_reshaped)

    # Layer 4: Convolutional Capsule: 1x1
    seg_caps = ConvCapsuleLayer(kernel_size=1, num_capsule=1, num_atoms=16, strides=1, padding='same',
                                routings=3, name='seg_caps')(primary_caps)

    # Layer 4: This is an auxiliary layer to replace each capsule with its length. Just to match the true label's shape.
    out_seg = Length(num_classes=n_class, seg=True, name='out_seg')(seg_caps)

    # Decoder network.
    _, H, W, C, A = seg_caps.get_shape()
    y = layers.Input(shape=input_shape[:-1] + (1,))
    masked_by_y = Mask()([seg_caps, y])  # The true label is used to mask the output of capsule layer. For training
    masked = Mask()(seg_caps)  # Mask using the capsule with maximal length. For prediction

    def shared_decoder(mask_layer):
        recon_remove_dim = layers.Reshape((H.value, W.value, A.value))(mask_layer)

        recon_1 = layers.Conv2D(filters=64, kernel_size=1, padding='same', kernel_initializer='he_normal',
                                activation='relu', name='recon_1')(recon_remove_dim)

        recon_2 = layers.Conv2D(filters=128, kernel_size=1, padding='same', kernel_initializer='he_normal',
                                activation='relu', name='recon_2')(recon_1)

        out_recon = layers.Conv2D(filters=1, kernel_size=1, padding='same', kernel_initializer='he_normal',
                                  activation='sigmoid', name='out_recon')(recon_2)

        return out_recon

    # Models for training and evaluation (prediction)
    train_model = models.Model(inputs=[x, y], outputs=[out_seg, shared_decoder(masked_by_y)])
    eval_model = models.Model(inputs=x, outputs=[out_seg, shared_decoder(masked)])

    # manipulate model
    noise = layers.Input(shape=(H.value, W.value, C.value, A.value))
    noised_seg_caps = layers.Add()([seg_caps, noise])
    masked_noised_y = Mask()([noised_seg_caps, y])
    manipulate_model = models.Model(inputs=[x, y, noise], outputs=shared_decoder(masked_noised_y))

    return train_model, eval_model, manipulate_model
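
As a quick orientation to these constructors, here is a minimal usage sketch (illustrative only; main.py drives the real training, testing, and manipulation pipeline, and the single-channel 512 x 512 input shape is an assumption matching the LUNA16 experiments above):
```python
# Illustrative sketch only; main.py drives the real pipeline.
# Assumes single-channel 512 x 512 slices, as in the LUNA16 experiments.
from capsnet import CapsNetR3

train_model, eval_model, manipulate_model = CapsNetR3(input_shape=(512, 512, 1))
train_model.summary()  # SegCaps (R3) has roughly 1.4 M parameters
```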