forked from qubvel/segmentation_models
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request qubvel#7 from qubvel/feature-linknet
linknet implementation
- Loading branch information
Showing
5 changed files
with
306 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
name = "segmentation_models" | ||
|
||
from .unet import Unet | ||
from .fpn import FPN | ||
from .fpn import FPN | ||
from .linknet import Linknet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .model import Linknet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
import keras.backend as K | ||
from keras.layers import Conv2DTranspose as Transpose | ||
from keras.layers import UpSampling2D | ||
from keras.layers import Conv2D | ||
from keras.layers import BatchNormalization | ||
from keras.layers import Activation | ||
from keras.layers import Add | ||
|
||
|
||
def handle_block_names(stage): | ||
conv_name = 'decoder_stage{}_conv'.format(stage) | ||
bn_name = 'decoder_stage{}_bn'.format(stage) | ||
relu_name = 'decoder_stage{}_relu'.format(stage) | ||
up_name = 'decoder_stage{}_upsample'.format(stage) | ||
return conv_name, bn_name, relu_name, up_name | ||
|
||
|
||
def ConvRelu(filters, | ||
kernel_size, | ||
use_batchnorm=False, | ||
conv_name='conv', | ||
bn_name='bn', | ||
relu_name='relu'): | ||
|
||
def layer(x): | ||
|
||
x = Conv2D(filters, | ||
kernel_size, | ||
padding="same", | ||
name=conv_name, | ||
use_bias=not(use_batchnorm))(x) | ||
|
||
if use_batchnorm: | ||
x = BatchNormalization(name=bn_name)(x) | ||
|
||
x = Activation('relu', name=relu_name)(x) | ||
|
||
return x | ||
return layer | ||
|
||
|
||
def Conv2DUpsample(filters, | ||
upsample_rate, | ||
kernel_size=(3,3), | ||
up_name='up', | ||
conv_name='conv', | ||
**kwargs): | ||
|
||
def layer(input_tensor): | ||
x = UpSampling2D(upsample_rate, name=up_name)(input_tensor) | ||
x = Conv2D(filters, | ||
kernel_size, | ||
padding='same', | ||
name=conv_name, | ||
**kwargs)(x) | ||
return x | ||
return layer | ||
|
||
|
||
def Conv2DTranspose(filters, | ||
upsample_rate, | ||
kernel_size=(4,4), | ||
up_name='up', | ||
**kwargs): | ||
|
||
if not tuple(upsample_rate) == (2,2): | ||
raise NotImplementedError( | ||
f'Conv2DTranspose support only upsample_rate=(2, 2), got {upsample_rate}') | ||
|
||
def layer(input_tensor): | ||
x = Transpose(filters, | ||
kernel_size=kernel_size, | ||
strides=upsample_rate, | ||
padding='same', | ||
name=up_name)(input_tensor) | ||
return x | ||
return layer | ||
|
||
|
||
def UpsampleBlock(filters, | ||
upsample_rate, | ||
kernel_size, | ||
use_batchnorm=False, | ||
upsample_layer='upsampling', | ||
conv_name='conv', | ||
bn_name='bn', | ||
relu_name='relu', | ||
up_name='up', | ||
**kwargs): | ||
|
||
if upsample_layer == 'upsampling': | ||
UpBlock = Conv2DUpsample | ||
|
||
elif upsample_layer == 'transpose': | ||
UpBlock = Conv2DTranspose | ||
|
||
else: | ||
raise ValueError(f'Not supported up layer type {upsample_layer}') | ||
|
||
def layer(input_tensor): | ||
|
||
x = UpBlock(filters, | ||
upsample_rate=upsample_rate, | ||
kernel_size=kernel_size, | ||
use_bias=not(use_batchnorm), | ||
conv_name=conv_name, | ||
up_name=up_name, | ||
**kwargs)(input_tensor) | ||
|
||
if use_batchnorm: | ||
x = BatchNormalization(name=bn_name)(x) | ||
|
||
x = Activation('relu', name=relu_name)(x) | ||
|
||
return x | ||
return layer | ||
|
||
|
||
def DecoderBlock(stage, | ||
filters=None, | ||
kernel_size=(3,3), | ||
upsample_rate=(2,2), | ||
use_batchnorm=False, | ||
skip=None, | ||
upsample_layer='upsampling'): | ||
|
||
def layer(input_tensor): | ||
|
||
conv_name, bn_name, relu_name, up_name = handle_block_names(stage) | ||
input_filters = K.int_shape(input_tensor)[-1] | ||
|
||
if skip is not None: | ||
output_filters = K.int_shape(skip)[-1] | ||
else: | ||
output_filters = filters | ||
|
||
x = ConvRelu(input_filters // 4, | ||
kernel_size=(1, 1), | ||
use_batchnorm=use_batchnorm, | ||
conv_name=conv_name + '1', | ||
bn_name=bn_name + '1', | ||
relu_name=relu_name + '1')(input_tensor) | ||
|
||
x = UpsampleBlock(filters=input_filters // 4, | ||
kernel_size=kernel_size, | ||
upsample_layer=upsample_layer, | ||
upsample_rate=upsample_rate, | ||
use_batchnorm=use_batchnorm, | ||
conv_name=conv_name + '2', | ||
bn_name=bn_name + '2', | ||
up_name=up_name + '2', | ||
relu_name=relu_name + '2')(x) | ||
|
||
x = ConvRelu(output_filters, | ||
kernel_size=(1, 1), | ||
use_batchnorm=use_batchnorm, | ||
conv_name=conv_name + '3', | ||
bn_name=bn_name + '3', | ||
relu_name=relu_name + '3')(x) | ||
|
||
if skip is not None: | ||
x = Add()([x, skip]) | ||
|
||
return x | ||
return layer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from keras.layers import Conv2D | ||
from keras.layers import Activation | ||
from keras.models import Model | ||
|
||
from .blocks import DecoderBlock | ||
from ..utils import get_layer_number, to_tuple | ||
|
||
|
||
def build_linknet(backbone, | ||
classes, | ||
skip_connection_layers, | ||
decoder_filters=(None, None, None, None, 16), | ||
upsample_rates=(2, 2, 2, 2, 2), | ||
n_upsample_blocks=5, | ||
upsample_kernel_size=(3, 3), | ||
upsample_layer='upsampling', | ||
activation='sigmoid', | ||
use_batchnorm=False): | ||
|
||
input = backbone.input | ||
x = backbone.output | ||
|
||
# convert layer names to indices | ||
skip_connection_idx = ([get_layer_number(backbone, l) if isinstance(l, str) else l | ||
for l in skip_connection_layers]) | ||
|
||
for i in range(n_upsample_blocks): | ||
|
||
# check if there is a skip connection | ||
skip_connection = None | ||
if i < len(skip_connection_idx): | ||
skip_connection = backbone.layers[skip_connection_idx[i]].output | ||
|
||
upsample_rate = to_tuple(upsample_rates[i]) | ||
|
||
x = DecoderBlock(stage=i, | ||
filters=decoder_filters[i], | ||
kernel_size=upsample_kernel_size, | ||
upsample_rate=upsample_rate, | ||
use_batchnorm=use_batchnorm, | ||
upsample_layer=upsample_layer, | ||
skip=skip_connection)(x) | ||
|
||
x = Conv2D(classes, (3, 3), padding='same', name='final_conv')(x) | ||
x = Activation(activation, name=activation)(x) | ||
|
||
model = Model(input, x) | ||
|
||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
from .builder import build_linknet | ||
from ..utils import freeze_model | ||
from ..backbones import get_backbone | ||
|
||
|
||
DEFAULT_SKIP_CONNECTIONS = { | ||
'vgg16': ('block5_conv3', 'block4_conv3', 'block3_conv3', 'block2_conv2'), | ||
'vgg19': ('block5_conv4', 'block4_conv4', 'block3_conv4', 'block2_conv2'), | ||
'resnet18': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), | ||
'resnet34': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), | ||
'resnet50': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), | ||
'resnet101': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), | ||
'resnet152': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), | ||
'resnext50': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), | ||
'resnext101': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), | ||
'inceptionv3': (228, 86, 16, 9), | ||
'inceptionresnetv2': (594, 260, 16, 9), | ||
'densenet121': (311, 139, 51, 4), | ||
'densenet169': (367, 139, 51, 4), | ||
'densenet201': (479, 139, 51, 4), | ||
} | ||
|
||
|
||
def Linknet(backbone_name='vgg16', | ||
input_shape=(None, None, 3), | ||
input_tensor=None, | ||
encoder_weights='imagenet', | ||
freeze_encoder=False, | ||
skip_connections='default', | ||
n_upsample_blocks=5, | ||
decoder_filters=(None, None, None, None, 16), | ||
decoder_use_batchnorm=False, | ||
upsample_layer='upsampling', | ||
upsample_kernel_size=(3, 3), | ||
classes=1, | ||
activation='sigmoid'): | ||
""" | ||
Version of Linkent model (https://arxiv.org/pdf/1707.03718.pdf) | ||
This implementation by default has 4 skip connection links (original - 3). | ||
Args: | ||
backbone_name: (str) look at list of available backbones. | ||
input_shape: (tuple) dimensions of input data (H, W, C) | ||
input_tensor: keras tensor | ||
encoder_weights: one of `None` (random initialization), 'imagenet' (pre-training on ImageNet) | ||
freeze_encoder: (bool) Set encoder layers weights as non-trainable. Useful for fine-tuning | ||
skip_connections: if 'default' is used take default skip connections, | ||
decoder_filters: (tuple of int) a number of convolution filters in decoder blocks, | ||
for block with skip connection a number of filters is equal to number of filters in | ||
corresponding encoder block (estimates automatically and can be passed as `None` value). | ||
decoder_use_batchnorm: (bool) if True add batch normalisation layer between `Conv2D` ad `Activation` layers | ||
n_upsample_blocks: (int) a number of upsampling blocks in decoder | ||
upsample_layer: (str) one of 'upsampling' and 'transpose' | ||
upsample_kernel_size: (tuple of int) convolution kernel size in upsampling block | ||
classes: (int) a number of classes for output | ||
activation: (str) one of keras activations | ||
Returns: | ||
model: instance of Keras Model | ||
""" | ||
|
||
backbone = get_backbone(backbone_name, | ||
input_shape=input_shape, | ||
input_tensor=input_tensor, | ||
weights=encoder_weights, | ||
include_top=False) | ||
|
||
if skip_connections == 'default': | ||
skip_connections = DEFAULT_SKIP_CONNECTIONS[backbone_name] | ||
|
||
model = build_linknet(backbone, | ||
classes, | ||
skip_connections, | ||
decoder_filters=decoder_filters, | ||
upsample_layer=upsample_layer, | ||
activation=activation, | ||
n_upsample_blocks=n_upsample_blocks, | ||
upsample_rates=(2, 2, 2, 2, 2), | ||
upsample_kernel_size=upsample_kernel_size, | ||
use_batchnorm=decoder_use_batchnorm) | ||
|
||
# lock encoder weights for fine-tuning | ||
if freeze_encoder: | ||
freeze_model(backbone) | ||
|
||
model.name = 'link-{}'.format(backbone_name) | ||
|
||
return model |