Skip to content

Commit

Permalink
Merge pull request qubvel#7 from qubvel/feature-linknet
Browse files Browse the repository at this point in the history
linknet implementation
  • Loading branch information
qubvel authored Sep 9, 2018
2 parents 35ee153 + 2c46e8c commit 9f4ff78
Show file tree
Hide file tree
Showing 5 changed files with 306 additions and 1 deletion.
3 changes: 2 additions & 1 deletion segmentation_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
name = "segmentation_models"

from .unet import Unet
from .fpn import FPN
from .fpn import FPN
from .linknet import Linknet
1 change: 1 addition & 0 deletions segmentation_models/linknet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .model import Linknet
165 changes: 165 additions & 0 deletions segmentation_models/linknet/blocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import keras.backend as K
from keras.layers import Conv2DTranspose as Transpose
from keras.layers import UpSampling2D
from keras.layers import Conv2D
from keras.layers import BatchNormalization
from keras.layers import Activation
from keras.layers import Add


def handle_block_names(stage):
conv_name = 'decoder_stage{}_conv'.format(stage)
bn_name = 'decoder_stage{}_bn'.format(stage)
relu_name = 'decoder_stage{}_relu'.format(stage)
up_name = 'decoder_stage{}_upsample'.format(stage)
return conv_name, bn_name, relu_name, up_name


def ConvRelu(filters,
kernel_size,
use_batchnorm=False,
conv_name='conv',
bn_name='bn',
relu_name='relu'):

def layer(x):

x = Conv2D(filters,
kernel_size,
padding="same",
name=conv_name,
use_bias=not(use_batchnorm))(x)

if use_batchnorm:
x = BatchNormalization(name=bn_name)(x)

x = Activation('relu', name=relu_name)(x)

return x
return layer


def Conv2DUpsample(filters,
upsample_rate,
kernel_size=(3,3),
up_name='up',
conv_name='conv',
**kwargs):

def layer(input_tensor):
x = UpSampling2D(upsample_rate, name=up_name)(input_tensor)
x = Conv2D(filters,
kernel_size,
padding='same',
name=conv_name,
**kwargs)(x)
return x
return layer


def Conv2DTranspose(filters,
upsample_rate,
kernel_size=(4,4),
up_name='up',
**kwargs):

if not tuple(upsample_rate) == (2,2):
raise NotImplementedError(
f'Conv2DTranspose support only upsample_rate=(2, 2), got {upsample_rate}')

def layer(input_tensor):
x = Transpose(filters,
kernel_size=kernel_size,
strides=upsample_rate,
padding='same',
name=up_name)(input_tensor)
return x
return layer


def UpsampleBlock(filters,
upsample_rate,
kernel_size,
use_batchnorm=False,
upsample_layer='upsampling',
conv_name='conv',
bn_name='bn',
relu_name='relu',
up_name='up',
**kwargs):

if upsample_layer == 'upsampling':
UpBlock = Conv2DUpsample

elif upsample_layer == 'transpose':
UpBlock = Conv2DTranspose

else:
raise ValueError(f'Not supported up layer type {upsample_layer}')

def layer(input_tensor):

x = UpBlock(filters,
upsample_rate=upsample_rate,
kernel_size=kernel_size,
use_bias=not(use_batchnorm),
conv_name=conv_name,
up_name=up_name,
**kwargs)(input_tensor)

if use_batchnorm:
x = BatchNormalization(name=bn_name)(x)

x = Activation('relu', name=relu_name)(x)

return x
return layer


def DecoderBlock(stage,
filters=None,
kernel_size=(3,3),
upsample_rate=(2,2),
use_batchnorm=False,
skip=None,
upsample_layer='upsampling'):

def layer(input_tensor):

conv_name, bn_name, relu_name, up_name = handle_block_names(stage)
input_filters = K.int_shape(input_tensor)[-1]

if skip is not None:
output_filters = K.int_shape(skip)[-1]
else:
output_filters = filters

x = ConvRelu(input_filters // 4,
kernel_size=(1, 1),
use_batchnorm=use_batchnorm,
conv_name=conv_name + '1',
bn_name=bn_name + '1',
relu_name=relu_name + '1')(input_tensor)

x = UpsampleBlock(filters=input_filters // 4,
kernel_size=kernel_size,
upsample_layer=upsample_layer,
upsample_rate=upsample_rate,
use_batchnorm=use_batchnorm,
conv_name=conv_name + '2',
bn_name=bn_name + '2',
up_name=up_name + '2',
relu_name=relu_name + '2')(x)

x = ConvRelu(output_filters,
kernel_size=(1, 1),
use_batchnorm=use_batchnorm,
conv_name=conv_name + '3',
bn_name=bn_name + '3',
relu_name=relu_name + '3')(x)

if skip is not None:
x = Add()([x, skip])

return x
return layer
49 changes: 49 additions & 0 deletions segmentation_models/linknet/builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from keras.layers import Conv2D
from keras.layers import Activation
from keras.models import Model

from .blocks import DecoderBlock
from ..utils import get_layer_number, to_tuple


def build_linknet(backbone,
classes,
skip_connection_layers,
decoder_filters=(None, None, None, None, 16),
upsample_rates=(2, 2, 2, 2, 2),
n_upsample_blocks=5,
upsample_kernel_size=(3, 3),
upsample_layer='upsampling',
activation='sigmoid',
use_batchnorm=False):

input = backbone.input
x = backbone.output

# convert layer names to indices
skip_connection_idx = ([get_layer_number(backbone, l) if isinstance(l, str) else l
for l in skip_connection_layers])

for i in range(n_upsample_blocks):

# check if there is a skip connection
skip_connection = None
if i < len(skip_connection_idx):
skip_connection = backbone.layers[skip_connection_idx[i]].output

upsample_rate = to_tuple(upsample_rates[i])

x = DecoderBlock(stage=i,
filters=decoder_filters[i],
kernel_size=upsample_kernel_size,
upsample_rate=upsample_rate,
use_batchnorm=use_batchnorm,
upsample_layer=upsample_layer,
skip=skip_connection)(x)

x = Conv2D(classes, (3, 3), padding='same', name='final_conv')(x)
x = Activation(activation, name=activation)(x)

model = Model(input, x)

return model
89 changes: 89 additions & 0 deletions segmentation_models/linknet/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from .builder import build_linknet
from ..utils import freeze_model
from ..backbones import get_backbone


DEFAULT_SKIP_CONNECTIONS = {
'vgg16': ('block5_conv3', 'block4_conv3', 'block3_conv3', 'block2_conv2'),
'vgg19': ('block5_conv4', 'block4_conv4', 'block3_conv4', 'block2_conv2'),
'resnet18': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
'resnet34': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
'resnet50': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
'resnet101': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
'resnet152': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
'resnext50': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
'resnext101': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
'inceptionv3': (228, 86, 16, 9),
'inceptionresnetv2': (594, 260, 16, 9),
'densenet121': (311, 139, 51, 4),
'densenet169': (367, 139, 51, 4),
'densenet201': (479, 139, 51, 4),
}


def Linknet(backbone_name='vgg16',
input_shape=(None, None, 3),
input_tensor=None,
encoder_weights='imagenet',
freeze_encoder=False,
skip_connections='default',
n_upsample_blocks=5,
decoder_filters=(None, None, None, None, 16),
decoder_use_batchnorm=False,
upsample_layer='upsampling',
upsample_kernel_size=(3, 3),
classes=1,
activation='sigmoid'):
"""
Version of Linkent model (https://arxiv.org/pdf/1707.03718.pdf)
This implementation by default has 4 skip connection links (original - 3).
Args:
backbone_name: (str) look at list of available backbones.
input_shape: (tuple) dimensions of input data (H, W, C)
input_tensor: keras tensor
encoder_weights: one of `None` (random initialization), 'imagenet' (pre-training on ImageNet)
freeze_encoder: (bool) Set encoder layers weights as non-trainable. Useful for fine-tuning
skip_connections: if 'default' is used take default skip connections,
decoder_filters: (tuple of int) a number of convolution filters in decoder blocks,
for block with skip connection a number of filters is equal to number of filters in
corresponding encoder block (estimates automatically and can be passed as `None` value).
decoder_use_batchnorm: (bool) if True add batch normalisation layer between `Conv2D` ad `Activation` layers
n_upsample_blocks: (int) a number of upsampling blocks in decoder
upsample_layer: (str) one of 'upsampling' and 'transpose'
upsample_kernel_size: (tuple of int) convolution kernel size in upsampling block
classes: (int) a number of classes for output
activation: (str) one of keras activations
Returns:
model: instance of Keras Model
"""

backbone = get_backbone(backbone_name,
input_shape=input_shape,
input_tensor=input_tensor,
weights=encoder_weights,
include_top=False)

if skip_connections == 'default':
skip_connections = DEFAULT_SKIP_CONNECTIONS[backbone_name]

model = build_linknet(backbone,
classes,
skip_connections,
decoder_filters=decoder_filters,
upsample_layer=upsample_layer,
activation=activation,
n_upsample_blocks=n_upsample_blocks,
upsample_rates=(2, 2, 2, 2, 2),
upsample_kernel_size=upsample_kernel_size,
use_batchnorm=decoder_use_batchnorm)

# lock encoder weights for fine-tuning
if freeze_encoder:
freeze_model(backbone)

model.name = 'link-{}'.format(backbone_name)

return model

0 comments on commit 9f4ff78

Please sign in to comment.