fast-neural-style example (pytorch#129)
* Add fast-neural-style implementation

* Rename directory

* Add option for stylizing using GPU

* Use state_dict for saving and loading model

* Update vgg-model download link

* Add script for downloading models, remove saved-models folder

* Use pytorch's pretrained vgg

* Remove cloning of intermediate outputs

* Add pytorch vgg results, update README.md

* Update README.md

* Change default learning rate

* Update README.md

* Add content scaling in stylize, edit docstring

* Refactor code

* Use inbuilt Instance-Normalization, refactor code

* Fix typo in README.md

* Update README.md

* Update models, photos, README.md

* Refactor

* Change affine and momentum parameters for InstanceNorm

* Change mode back to training, refactor transformer_net

After checkpointing, the model remained in evaluation mode and hence no further updates were made; add code to put the model back in training mode after checkpointing. Also use a volatile Variable when stylizing images during testing.

* Refactor

* Update stylized images

* Update candy style image
abhiskk authored and soumith committed Jun 6, 2017
1 parent 5c41070 commit dc10cd8
Showing 17 changed files with 467 additions and 0 deletions.
57 changes: 57 additions & 0 deletions fast_neural_style/README.md
@@ -0,0 +1,57 @@
# fast-neural-style :city_sunrise: :rocket:
This repository contains a pytorch implementation of an algorithm for artistic style transfer. The algorithm can be used to mix the content of an image with the style of another image. For example, here is a photograph of a door arch rendered in the style of a stained glass painting.

The model uses the method described in [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155) along with [Instance Normalization](https://arxiv.org/pdf/1607.08022.pdf). The saved-models for examples shown in the README can be downloaded from [here](https://www.dropbox.com/s/lrvwfehqdcxoza8/saved_models.zip?dl=0).

<p align="center">
<img src="images/style-images/mosaic.jpg" height="200px">
<img src="images/content-images/amber.jpg" height="200px">
<img src="images/output-images/amber-mosaic.jpg" height="440px">
</p>

## Requirements
The program is written in Python and uses [pytorch](http://pytorch.org/) and [scipy](https://www.scipy.org). A GPU is not necessary, but it can provide a significant speed-up, especially when training a new model. Regular-sized images can be stylized on a laptop or desktop using the saved models.

## Usage
Stylize an image
```bash
python neural_style/neural_style.py eval --content-image </path/to/content/image> --model </path/to/saved/model> --output-image </path/to/output/image> --cuda 0
```
* `--content-image`: path to content image you want to stylize.
* `--model`: saved model to be used for stylizing the image (e.g. `mosaic.pth`; a complete example follows this list)
* `--output-image`: path for saving the output image.
* `--content-scale`: factor for scaling down the content image if memory is an issue (e.g. a value of 2 will halve the height and width of the content image)
* `--cuda`: set it to 1 for running on GPU, 0 for CPU.
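
For example, to stylize `images/content-images/amber.jpg` with the mosaic model on the CPU (assuming the downloaded models were unpacked into a `saved_models` folder):
```bash
python neural_style/neural_style.py eval --content-image images/content-images/amber.jpg --model saved_models/mosaic.pth --output-image amber-mosaic.jpg --cuda 0
```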

Train model
```bash
python neural_style/neural_style.py train --dataset </path/to/train-dataset> --style-image </path/to/style/image> --save-model-dir </path/to/save-model/folder> --epochs 2 --cuda 1
```

There are several command-line arguments; the important ones are listed below.
* `--dataset`: path to the training dataset; the path should point to a folder containing another folder with all the training images. I used the COCO 2014 Training images dataset [80K/13GB] [(download)](http://mscoco.org/dataset/#download).
* `--style-image`: path to style-image.
* `--save-model-dir`: path to folder where trained model will be saved.
* `--cuda`: set it to 1 for running on GPU, 0 for CPU.

Refer to ``neural_style/neural_style.py`` for the other command-line arguments. When training new models you might have to tune `--content-weight` and `--style-weight`. The mosaic style model shown above was trained with `--content-weight 1e5` and `--style-weight 1e10`. The remaining three models were trained with weights of the same order of magnitude, with slight variation in `--style-weight` (`5e10` or `1e11`).
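
For example, a training run equivalent to the mosaic model above would look like the following (the dataset path is a placeholder for wherever the COCO images were extracted):
```bash
python neural_style/neural_style.py train --dataset </path/to/coco-folder> --style-image images/style-images/mosaic.jpg --save-model-dir </path/to/save-model/folder> --content-weight 1e5 --style-weight 1e10 --epochs 2 --cuda 1
```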

## Models

Models for the examples shown below can be downloaded from [here](https://www.dropbox.com/s/lrvwfehqdcxoza8/saved_models.zip?dl=0) or by running the script ``download_saved_models.sh``.

<div align='center'>
<img src='images/content-images/amber.jpg' height="174px">
</div>

<div align='center'>
<img src='images/style-images/mosaic.jpg' height="174px">
<img src='images/output-images/amber-mosaic.jpg' height="174px">
<img src='images/output-images/amber-candy.jpg' height="174px">
<img src='images/style-images/candy.jpg' height="174px">
<br>
<img src='images/style-images/rain-princess-cropped.jpg' height="174px">
<img src='images/output-images/amber-rain-princess.jpg' height="174px">
<img src='images/output-images/amber-udnie.jpg' height="174px">
<img src='images/style-images/udnie.jpg' height="174px">
</div>
2 changes: 2 additions & 0 deletions fast_neural_style/download_saved_models.sh
@@ -0,0 +1,2 @@
wget https://www.dropbox.com/s/lrvwfehqdcxoza8/saved_models.zip?dl=1
unzip saved_models.zip?dl=1
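
A minimal way to use the script, assuming it is run from the `fast_neural_style` directory and the archive unpacks into a `saved_models` folder:
```bash
cd fast_neural_style
bash download_saved_models.sh
ls saved_models  # expected to contain the style models, e.g. mosaic.pth
```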
Binary file added fast_neural_style/images/style-images/candy.jpg
Binary file added fast_neural_style/images/style-images/mosaic.jpg
Binary file added fast_neural_style/images/style-images/udnie.jpg
Empty file.
226 changes: 226 additions & 0 deletions fast_neural_style/neural_style/neural_style.py
@@ -0,0 +1,226 @@
import argparse
import os
import sys
import time

import numpy as np
import torch
from torch.autograd import Variable
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

import utils
from transformer_net import TransformerNet
from vgg import Vgg16


def check_paths(args):
try:
if not os.path.exists(args.save_model_dir):
os.makedirs(args.save_model_dir)
if args.checkpoint_model_dir is not None and not (os.path.exists(args.checkpoint_model_dir)):
os.makedirs(args.checkpoint_model_dir)
except OSError as e:
print(e)
sys.exit(1)


def train(args):
np.random.seed(args.seed)
torch.manual_seed(args.seed)

if args.cuda:
torch.cuda.manual_seed(args.seed)

transform = transforms.Compose([
transforms.Scale(args.image_size),
transforms.CenterCrop(args.image_size),
transforms.ToTensor(),
transforms.Lambda(lambda x: x.mul(255))
])
train_dataset = datasets.ImageFolder(args.dataset, transform)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

transformer = TransformerNet()
optimizer = Adam(transformer.parameters(), args.lr)
mse_loss = torch.nn.MSELoss()

vgg = Vgg16(requires_grad=False)
style_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: x.mul(255))
])
style = utils.load_image(args.style_image, size=args.style_size)
style = style_transform(style)
style = style.repeat(args.batch_size, 1, 1, 1)

if args.cuda:
transformer.cuda()
vgg.cuda()
style = style.cuda()

style_v = Variable(style)
style_v = utils.normalize_batch(style_v)
features_style = vgg(style_v)
gram_style = [utils.gram_matrix(y) for y in features_style]

for e in range(args.epochs):
transformer.train()
agg_content_loss = 0.
agg_style_loss = 0.
count = 0
for batch_id, (x, _) in enumerate(train_loader):
n_batch = len(x)
count += n_batch
optimizer.zero_grad()
x = Variable(x)
if args.cuda:
x = x.cuda()

y = transformer(x)

y = utils.normalize_batch(y)
x = utils.normalize_batch(x)

features_y = vgg(y)
features_x = vgg(x)

content_loss = args.content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

style_loss = 0.
for ft_y, gm_s in zip(features_y, gram_style):
gm_y = utils.gram_matrix(ft_y)
style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
style_loss *= args.style_weight

total_loss = content_loss + style_loss
total_loss.backward()
optimizer.step()

agg_content_loss += content_loss.data[0]
agg_style_loss += style_loss.data[0]

if (batch_id + 1) % args.log_interval == 0:
mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
time.ctime(), e + 1, count, len(train_dataset),
agg_content_loss / (batch_id + 1),
agg_style_loss / (batch_id + 1),
(agg_content_loss + agg_style_loss) / (batch_id + 1)
)
print(mesg)

if args.checkpoint_model_dir is not None and (batch_id + 1) % args.checkpoint_interval == 0:
transformer.eval()
if args.cuda:
transformer.cpu()
ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
torch.save(transformer.state_dict(), ckpt_model_path)
if args.cuda:
transformer.cuda()
transformer.train()

# save model
transformer.eval()
if args.cuda:
transformer.cpu()
save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
args.content_weight) + "_" + str(args.style_weight) + ".model"
save_model_path = os.path.join(args.save_model_dir, save_model_filename)
torch.save(transformer.state_dict(), save_model_path)

print("\nDone, trained model saved at", save_model_path)


def stylize(args):
content_image = utils.load_image(args.content_image, scale=args.content_scale)
content_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: x.mul(255))
])
content_image = content_transform(content_image)
content_image = content_image.unsqueeze(0)
if args.cuda:
content_image = content_image.cuda()
content_image = Variable(content_image, volatile=True)

style_model = TransformerNet()
style_model.load_state_dict(torch.load(args.model))
if args.cuda:
style_model.cuda()
output = style_model(content_image)
if args.cuda:
output = output.cpu()
output_data = output.data[0]
utils.save_image(args.output_image, output_data)


def main():
main_arg_parser = argparse.ArgumentParser(description="parser for fast-neural-style")
subparsers = main_arg_parser.add_subparsers(title="subcommands", dest="subcommand")

train_arg_parser = subparsers.add_parser("train", help="parser for training arguments")
train_arg_parser.add_argument("--epochs", type=int, default=2,
help="number of training epochs, default is 2")
train_arg_parser.add_argument("--batch-size", type=int, default=4,
help="batch size for training, default is 4")
train_arg_parser.add_argument("--dataset", type=str, required=True,
help="path to training dataset, the path should point to a folder "
"containing another folder with all the training images")
train_arg_parser.add_argument("--style-image", type=str, default="images/style-images/mosaic.jpg",
help="path to style-image")
train_arg_parser.add_argument("--save-model-dir", type=str, required=True,
help="path to folder where trained model will be saved.")
train_arg_parser.add_argument("--checkpoint-model-dir", type=str, default=None,
help="path to folder where checkpoints of trained models will be saved")
train_arg_parser.add_argument("--image-size", type=int, default=256,
help="size of training images, default is 256 X 256")
train_arg_parser.add_argument("--style-size", type=int, default=None,
help="size of style-image, default is the original size of style image")
train_arg_parser.add_argument("--cuda", type=int, required=True,
help="set it to 1 for running on GPU, 0 for CPU")
train_arg_parser.add_argument("--seed", type=int, default=42,
help="random seed for training")
train_arg_parser.add_argument("--content-weight", type=float, default=1e5,
help="weight for content-loss, default is 1e5")
train_arg_parser.add_argument("--style-weight", type=float, default=1e10,
help="weight for style-loss, default is 1e10")
train_arg_parser.add_argument("--lr", type=float, default=1e-3,
help="learning rate, default is 1e-3")
train_arg_parser.add_argument("--log-interval", type=int, default=500,
help="number of images after which the training loss is logged, default is 500")
train_arg_parser.add_argument("--checkpoint-interval", type=int, default=2000,
help="number of batches after which a checkpoint of the trained model will be created")

eval_arg_parser = subparsers.add_parser("eval", help="parser for evaluation/stylizing arguments")
eval_arg_parser.add_argument("--content-image", type=str, required=True,
help="path to content image you want to stylize")
eval_arg_parser.add_argument("--content-scale", type=float, default=None,
help="factor for scaling down the content image")
eval_arg_parser.add_argument("--output-image", type=str, required=True,
help="path for saving the output image")
eval_arg_parser.add_argument("--model", type=str, required=True,
help="saved model to be used for stylizing the image")
eval_arg_parser.add_argument("--cuda", type=int, required=True,
help="set it to 1 for running on GPU, 0 for CPU")

args = main_arg_parser.parse_args()

if args.subcommand is None:
print("ERROR: specify either train or eval")
sys.exit(1)
if args.cuda and not torch.cuda.is_available():
print("ERROR: cuda is not available, try running on CPU")
sys.exit(1)

if args.subcommand == "train":
check_paths(args)
train(args)
else:
stylize(args)


if __name__ == "__main__":
main()
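
``utils.py`` is among the 17 changed files but is not rendered in this page. For reference, here is a minimal, hypothetical sketch of the two helpers the training loop above leans on, ``gram_matrix`` (the style representation) and ``normalize_batch`` (ImageNet normalization for the pretrained VGG), written against the same PyTorch 0.x-era API; the actual implementation in the commit may differ.

```python
# Hypothetical sketch of two helpers from utils.py (part of this commit but not
# rendered in this diff); the real implementation may differ.
from torch.autograd import Variable


def gram_matrix(y):
    # Style representation: batched Gram matrix of the feature maps,
    # (b, ch, h, w) -> (b, ch, ch), normalized by the elements per channel map.
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, w * h)
    gram = features.bmm(features.transpose(1, 2)) / (ch * h * w)
    return gram


def normalize_batch(batch):
    # Map a batch of [0, 255] images to the ImageNet-normalized range that
    # torchvision's pretrained VGG16 expects.
    mean = batch.data.new(batch.data.size())
    std = batch.data.new(batch.data.size())
    mean[:, 0, :, :] = 0.485
    mean[:, 1, :, :] = 0.456
    mean[:, 2, :, :] = 0.406
    std[:, 0, :, :] = 0.229
    std[:, 1, :, :] = 0.224
    std[:, 2, :, :] = 0.225
    batch = batch.div(255.0)
    return (batch - Variable(mean)) / Variable(std)
```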
101 changes: 101 additions & 0 deletions fast_neural_style/neural_style/transformer_net.py
@@ -0,0 +1,101 @@
import torch


class TransformerNet(torch.nn.Module):
def __init__(self):
super(TransformerNet, self).__init__()
# Initial convolution layers
self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
# Residual layers
self.res1 = ResidualBlock(128)
self.res2 = ResidualBlock(128)
self.res3 = ResidualBlock(128)
self.res4 = ResidualBlock(128)
self.res5 = ResidualBlock(128)
# Upsampling Layers
self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
# Non-linearities
self.relu = torch.nn.ReLU()

def forward(self, X):
y = self.relu(self.in1(self.conv1(X)))
y = self.relu(self.in2(self.conv2(y)))
y = self.relu(self.in3(self.conv3(y)))
y = self.res1(y)
y = self.res2(y)
y = self.res3(y)
y = self.res4(y)
y = self.res5(y)
y = self.relu(self.in4(self.deconv1(y)))
y = self.relu(self.in5(self.deconv2(y)))
y = self.deconv3(y)
return y


class ConvLayer(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride):
super(ConvLayer, self).__init__()
reflection_padding = kernel_size // 2
self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

def forward(self, x):
out = self.reflection_pad(x)
out = self.conv2d(out)
return out


class ResidualBlock(torch.nn.Module):
"""ResidualBlock
introduced in: https://arxiv.org/abs/1512.03385
recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
"""

def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
self.relu = torch.nn.ReLU()

def forward(self, x):
residual = x
out = self.relu(self.in1(self.conv1(x)))
out = self.in2(self.conv2(out))
out = out + residual
return out


class UpsampleConvLayer(torch.nn.Module):
"""UpsampleConvLayer
Upsamples the input and then does a convolution. This method gives better results
compared to ConvTranspose2d.
ref: http://distill.pub/2016/deconv-checkerboard/
"""

def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
super(UpsampleConvLayer, self).__init__()
self.upsample = upsample
if upsample:
self.upsample_layer = torch.nn.UpsamplingNearest2d(scale_factor=upsample)
reflection_padding = kernel_size // 2
self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

def forward(self, x):
x_in = x
if self.upsample:
x_in = self.upsample_layer(x_in)
out = self.reflection_pad(x_in)
out = self.conv2d(out)
return out
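
A quick, hypothetical smoke test for ``TransformerNet`` (not part of this commit): feed a random 256x256 image through the network and check that the spatial size is preserved, since the two stride-2 convolutions are undone by the two x2 upsampling layers.

```python
# Hypothetical smoke test, assuming it is run from the neural_style directory
# so transformer_net is importable; uses the same Variable-era API as above.
import torch
from torch.autograd import Variable

from transformer_net import TransformerNet

net = TransformerNet()
x = Variable(torch.rand(1, 3, 256, 256))  # dummy 3-channel content image
y = net(x)
# The two stride-2 convolutions are undone by the two x2 upsampling layers,
# so the output has the same spatial size as the input.
print(y.size())  # expected: torch.Size([1, 3, 256, 256])
```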