Merge pull request tensorflow#35 from daviddao/master
Adding Spatial Transformer to Models
martinwicke committed Apr 1, 2016
2 parents 3000453 + 41c52d6 commit 6406a15
Showing 7 changed files with 625 additions and 0 deletions.
1 change: 1 addition & 0 deletions AUTHORS
@@ -7,3 +7,4 @@
# The email address is not required for organizations.

Google Inc.
David Dao <[email protected]>
64 changes: 64 additions & 0 deletions transformer/README.md
@@ -0,0 +1,64 @@
# Spatial Transformer Network

The Spatial Transformer Network [1] allows the spatial manipulation of data within the network.

<div align="center">
<img width="600px" src="http://i.imgur.com/ExGDVul.png"><br><br>
</div>

### API

A Spatial Transformer Network implemented in TensorFlow 0.7, based on [2].

#### How to use

<div align="center">
<img src="http://i.imgur.com/gfqLV3f.png"><br><br>
</div>

```python
transformer(U, theta, downsample_factor=1)
```

#### Parameters

    U : float
        The output of a convolutional net, with shape
        [num_batch, height, width, num_channels].
    theta : float
        The output of the localisation network, with shape [num_batch, 6].
    downsample_factor : float
        A value of 1 keeps the original size of the image.
        Values larger than 1 downsample the image; values below 1 upsample it.
        For example, with height = 100, width = 200, and
        downsample_factor = 2, the output image will be 50 x 100.
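
A minimal call might look like the sketch below; the 40x40 single-channel placeholder shapes are assumptions borrowed from the cluttered MNIST example in this repository.

```python
import tensorflow as tf
from spatial_transformer import transformer

# Assumed shapes, matching the cluttered MNIST setup below.
U = tf.placeholder(tf.float32, [None, 40, 40, 1])     # conv-net output
theta = tf.placeholder(tf.float32, [None, 6])         # localisation output
h_trans = transformer(U, theta, downsample_factor=1)  # same-size output
```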

#### Notes
To initialize the network to the identity transform, initialize ``theta`` to:

```python
import numpy as np
import tensorflow as tf

identity = np.array([[1., 0., 0.],
                     [0., 1., 0.]])
identity = identity.flatten()
theta = tf.Variable(initial_value=identity)
```
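
With this initialization the transformer starts out as a no-op, passing its input through unchanged, so training begins from an undistorted view of the data.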

#### Experiments

<div align="center">
<img width="600px" src="http://i.imgur.com/HtCBYk2.png"><br><br>
</div>

We used the cluttered MNIST dataset. The left column shows the input images; the right column shows the parts of each image the STN attends to.

All experiments were run with TensorFlow 0.7.

### References

[1] Jaderberg, Max, et al. "Spatial Transformer Networks." arXiv preprint arXiv:1506.02025 (2015)

[2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py
172 changes: 172 additions & 0 deletions transformer/cluttered_mnist.py
@@ -0,0 +1,172 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
from spatial_transformer import transformer
from scipy import ndimage
import numpy as np
import matplotlib.pyplot as plt
from tf_utils import conv2d, linear, weight_variable, bias_variable, dense_to_one_hot

# %% Load data
mnist_cluttered = np.load('./data/mnist_sequence1_sample_5distortions5x5.npz')

X_train = mnist_cluttered['X_train']
y_train = mnist_cluttered['y_train']
X_valid = mnist_cluttered['X_valid']
y_valid = mnist_cluttered['y_valid']
X_test = mnist_cluttered['X_test']
y_test = mnist_cluttered['y_test']

# %% Turn dense labels into one-hot representation
Y_train = dense_to_one_hot(y_train, n_classes=10)
Y_valid = dense_to_one_hot(y_valid, n_classes=10)
Y_test = dense_to_one_hot(y_test, n_classes=10)

# %% Graph representation of our network

# %% Placeholders for 40x40 resolution
x = tf.placeholder(tf.float32, [None, 1600])
y = tf.placeholder(tf.float32, [None, 10])

# %% Since x is currently [batch, height*width], we need to reshape to a
# 4-D tensor to use it in a convolutional graph. If one component of
# `shape` is the special value -1, the size of that dimension is
# computed so that the total size remains constant. Since we haven't
# defined the batch dimension's shape yet, we use -1 to denote this
# dimension should not change size.
x_tensor = tf.reshape(x, [-1, 40, 40, 1])

# %% We'll set up the two-layer localisation network to figure out the
# parameters for an affine transformation of the input
# %% Create variables for fully connected layer
W_fc_loc1 = weight_variable([1600, 20])
b_fc_loc1 = bias_variable([20])

W_fc_loc2 = weight_variable([20, 6])
initial = np.array([[1., 0., 0.], [0., 1., 0.]])  # use identity transformation as starting point
initial = initial.astype('float32')
initial = initial.flatten()
b_fc_loc2 = tf.Variable(initial_value=initial, name='b_fc_loc2')

# %% Define the two layer localisation network
h_fc_loc1 = tf.nn.tanh(tf.matmul(x, W_fc_loc1) + b_fc_loc1)
# %% We can add dropout for regularization and to reduce overfitting like so:
keep_prob = tf.placeholder(tf.float32)
h_fc_loc1_drop = tf.nn.dropout(h_fc_loc1, keep_prob)
# %% Second layer
h_fc_loc2 = tf.nn.tanh(tf.matmul(h_fc_loc1_drop, W_fc_loc2) + b_fc_loc2)

# %% We'll create a spatial transformer module to identify discriminative patches
h_trans = transformer(x_tensor, h_fc_loc2, downsample_factor=1)

# %% We'll set up the first convolutional layer
# Weight matrix is [height x width x input_channels x output_channels]
filter_size = 3
n_filters_1 = 16
W_conv1 = weight_variable([filter_size, filter_size, 1, n_filters_1])

# %% Bias is [output_channels]
b_conv1 = bias_variable([n_filters_1])

# %% Now we can build a graph which does the first layer of convolution:
# we define our stride as batch x height x width x channels
# instead of pooling, we use strides of 2 and more layers
# with smaller filters.

h_conv1 = tf.nn.relu(
    tf.nn.conv2d(input=h_trans,
                 filter=W_conv1,
                 strides=[1, 2, 2, 1],
                 padding='SAME') +
    b_conv1)

# %% And just like the first layer, add additional layers to create
# a deep net
n_filters_2 = 16
W_conv2 = weight_variable([filter_size, filter_size, n_filters_1, n_filters_2])
b_conv2 = bias_variable([n_filters_2])
h_conv2 = tf.nn.relu(
    tf.nn.conv2d(input=h_conv1,
                 filter=W_conv2,
                 strides=[1, 2, 2, 1],
                 padding='SAME') +
    b_conv2)

# %% We'll now reshape so we can connect to a fully-connected layer:
h_conv2_flat = tf.reshape(h_conv2, [-1, 10 * 10 * n_filters_2])

# %% Create a fully-connected layer:
n_fc = 1024
W_fc1 = weight_variable([10 * 10 * n_filters_2, n_fc])
b_fc1 = bias_variable([n_fc])
h_fc1 = tf.nn.relu(tf.matmul(h_conv2_flat, W_fc1) + b_fc1)

h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# %% And finally our softmax layer:
W_fc2 = weight_variable([n_fc, 10])
b_fc2 = bias_variable([10])
y_pred = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

# %% Define loss/eval/training functions
cross_entropy = -tf.reduce_sum(y * tf.log(y_pred))
opt = tf.train.AdamOptimizer()
optimizer = opt.minimize(cross_entropy)
grads = opt.compute_gradients(cross_entropy, [b_fc_loc2])

# %% Monitor accuracy
correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))

# %% We now create a new session to actually initialize the
# variables:
sess = tf.Session()
sess.run(tf.initialize_all_variables())


# %% We'll now train in minibatches and report accuracy, loss:
iter_per_epoch = 100
n_epochs = 500
train_size = 10000

indices = np.linspace(0, train_size - 1, iter_per_epoch)
indices = indices.astype('int')

for epoch_i in range(n_epochs):
    for iter_i in range(iter_per_epoch - 1):
        batch_xs = X_train[indices[iter_i]:indices[iter_i + 1]]
        batch_ys = Y_train[indices[iter_i]:indices[iter_i + 1]]

        if iter_i % 10 == 0:
            loss = sess.run(cross_entropy,
                            feed_dict={
                                x: batch_xs,
                                y: batch_ys,
                                keep_prob: 1.0
                            })
            print('Iteration: ' + str(iter_i) + ' Loss: ' + str(loss))

        sess.run(optimizer, feed_dict={
            x: batch_xs, y: batch_ys, keep_prob: 0.8})

    print('Accuracy: ' + str(sess.run(accuracy,
                                      feed_dict={
                                          x: X_valid,
                                          y: Y_valid,
                                          keep_prob: 1.0
                                      })))
#theta = sess.run(h_fc_loc2, feed_dict={
# x: batch_xs, keep_prob: 1.0})
#print(theta[0])
20 changes: 20 additions & 0 deletions transformer/data/README.md
@@ -0,0 +1,20 @@
### How to get the data

#### Cluttered MNIST

The cluttered MNIST dataset can be found here [1] or can be generated via [2].

Settings used for `cluttered_mnist.py`:

```python
ORG_SHP = [28, 28]
OUT_SHP = [40, 40]
NUM_DISTORTIONS = 8
dist_size = (5, 5)
```
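
Once downloaded or generated, the archive can be loaded with NumPy. A minimal sketch, assuming the path used in `cluttered_mnist.py`:

```python
import numpy as np

# Keys match those read by cluttered_mnist.py.
data = np.load('./data/mnist_sequence1_sample_5distortions5x5.npz')
X_train, y_train = data['X_train'], data['y_train']
X_valid, y_valid = data['X_valid'], data['y_valid']
X_test, y_test = data['X_test'], data['y_test']
print(X_train.shape)  # each row is a flattened 40x40 image (1600 values)
```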

[1] https://github.com/daviddao/spatial-transformer-tensorflow

[2] https://github.com/skaae/recurrent-spatial-transformer-code/blob/master/MNIST_SEQUENCE/create_mnist_sequence.py
58 changes: 58 additions & 0 deletions transformer/example.py
@@ -0,0 +1,58 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
from spatial_transformer import transformer
from scipy import ndimage
import numpy as np
import matplotlib.pyplot as plt
from tf_utils import conv2d, linear, weight_variable, bias_variable

# %% Create a batch of three images (1600 x 1200)
# %% Image retrieved from https://raw.githubusercontent.com/skaae/transformer_network/master/cat.jpg
im = ndimage.imread('cat.jpg')
im = im / 255.
im = im.reshape(1, 1200, 1600, 3)
im = im.astype('float32')

# %% Simulate batch
batch = np.append(im, im, axis=0)
batch = np.append(batch, im, axis=0)
num_batch = 3

x = tf.placeholder(tf.float32, [None, 1200, 1600, 3])

# %% Create localisation network and convolutional layer
with tf.variable_scope('spatial_transformer_0'):

    # %% Create a fully-connected layer with 6 output nodes
    n_fc = 6
    W_fc1 = tf.Variable(tf.zeros([1200 * 1600 * 3, n_fc]), name='W_fc1')

    # %% Zoom into the image
    initial = np.array([[0.5, 0., 0.], [0., 0.5, 0.]])
    initial = initial.astype('float32')
    initial = initial.flatten()

    b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
    h_fc1 = tf.matmul(tf.zeros([num_batch, 1200 * 1600 * 3]), W_fc1) + b_fc1
    h_trans = transformer(x, h_fc1, downsample_factor=2)

# %% Run session
sess = tf.Session()
sess.run(tf.initialize_all_variables())
y = sess.run(h_trans, feed_dict={x: batch})

# plt.imshow(y[0])