Skip to content

Segmentation models with pretrained backbones. Keras.

License

Notifications You must be signed in to change notification settings

datbkhn/segmentation_models

 
 

Repository files navigation

Segmentation Models

Segmentation models is python library with Neural Networks for Image Segmentation based on Keras (Tensorflow) framework.

The main features of this library are:

  • High level API (just two lines to create NN)
  • 4 models architectures for binary and multi class segmentation (including legendary Unet)
  • 25 available backbones for each architecture
  • All backbones have pre-trained weights for faster and better convergence

Table of Contents

Quick start

Since the library is built on the Keras framework, created segmentaion model is just a Keras Model, which can be created as easy as:

from segmentation_models import Unet

model = Unet()

Depending on the task, you can change the network architecture by choosing backbones with fewer or more parameters and use pretrainded weights to initialize it:

model = Unet('resnet34', encoder_weights='imagenet')

Change number of output classes in the model (choose your case):

# binary segmentation (this parameters are default when you call Unet('resnet34')
model = Unet('resnet34', classes=1, activation='sigmoid')
# multiclass segmentation with non overlapping class masks (your classes + background)
model = Unet('resnet34', classes=3, activation='softmax')
# multiclass segmentation with independent overlapping/non-overlapping class masks
model = Unet('resnet34', classes=3, activation='sigmoid')

Change input shape of the model:

# if you set input channels not equal to 3, you have to set encoder_weights=None
# how to handle such case with encoder_weights='imagenet' described in docs
model = Unet('resnet34', input_shape=(None, None, 6), encoder_weights=None)

Simple training pipeline

from segmentation_models import Unet
from segmentation_models.backbones import get_preprocessing
from segmentation_models.losses import bce_jaccard_loss
from segmentation_models.metrics import iou_score

BACKBONE = 'resnet34'
preprocess_input = get_preprocessing(BACKBONE)

# load your data
x_train, y_train, x_val, y_val = load_data(...)

# preprocess input
x_train = preprocess_input(x_train)
x_val = preprocess_input(x_val)

# define model
model = Unet(BACKBONE, encoder_weights='imagenet')
model.compile('Adam', loss=bce_jaccard_loss, metrics=[iou_score])

# fit model
# if you use data generator use model.fit_generator(...) instead of model.fit(...)
# more about `fit_generator` here: https://keras.io/models/sequential/#fit_generator
model.fit(
    x=x_train,
    y=y_train,
    batch_size=16,
    epochs=100,
    validation_data=(x_val, y_val),
)

Same manimulations can be done with Linknet, PSPNet and FPN. For more detailed information about models API and use cases Read the Docs.

Models and Backbones

Models

Unet Linknet
unet_image linknet_image
PSPNet FPN
psp_image fpn_image

Backbones

Type Names
VGG 'vgg16' 'vgg19'
ResNet 'resnet18' 'resnet34' 'resnet50' 'resnet101' 'resnet152'
SE-ResNet 'seresnet18' 'seresnet34' 'seresnet50' 'seresnet101' 'seresnet152'
ResNeXt 'resnext50' 'resnext101'
SE-ResNeXt 'seresnext50' 'seresnext101'
SENet154 'senet154'
DenseNet 'densenet121' 'densenet169' 'densenet201'
Inception 'inceptionv3' 'inceptionresnetv2'
MobileNet 'mobilenet' 'mobilenetv2'
All backbones have weights trained on 2012 ILSVRC ImageNet dataset (encoder_weights='imagenet').

Installation

Requirements

  1. Python 3.5+
  2. Keras >= 2.2.0
  3. Keras Application >= 1.0.7
  4. Image Classifiers == 0.2.0
  5. Tensorflow 1.9 (tested)

Pip package

$ pip install segmentation-models

Latest version

$ pip install git+https://github.com/qubvel/segmentation_models

Documentation

Latest documentation is avaliable on Read the Docs

Change Log

To see important changes between versions look at CHANGELOG.md

Citing

@misc{Yakubovskiy:2019,
  Author = {Pavel Yakubovskiy},
  Title = {Segmentation Models},
  Year = {2019},
  Publisher = {GitHub},
  Journal = {GitHub repository},
  Howpublished = {\url{https://github.com/qubvel/segmentation_models}}
}

License

Project is distributed under MIT Licence.

About

Segmentation models with pretrained backbones. Keras.

Resources

License

Stars

Watchers

Forks

Packages

No packages published

Languages

  • Python 100.0%