forked from tensorflow/models
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request tensorflow#35 from daviddao/master
Adding Spatial Transformer to Models
- Loading branch information
Showing
7 changed files
with
625 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,3 +7,4 @@ | |
# The email address is not required for organizations. | ||
|
||
Google Inc. | ||
David Dao <[email protected]> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Spatial Transformer Network | ||
|
||
The Spatial Transformer Network [1] allows the spatial manipulation of data within the network. | ||
|
||
<div align="center"> | ||
<img width="600px" src="http://i.imgur.com/ExGDVul.png"><br><br> | ||
</div> | ||
|
||
### API | ||
|
||
A Spatial Transformer Network implemented in Tensorflow 0.7 and based on [2]. | ||
|
||
#### How to use | ||
|
||
<div align="center"> | ||
<img src="http://i.imgur.com/gfqLV3f.png"><br><br> | ||
</div> | ||
|
||
```python | ||
transformer(U, theta, downsample_factor=1) | ||
``` | ||
|
||
#### Parameters | ||
|
||
U : float | ||
The output of a convolutional net should have the | ||
shape [num_batch, height, width, num_channels]. | ||
theta: float | ||
The output of the | ||
localisation network should be [num_batch, 6]. | ||
downsample_factor : float | ||
A value of 1 will keep the original size of the image | ||
Values larger than 1 will downsample the image. | ||
Values below 1 will upsample the image | ||
example image: height = 100, width = 200 | ||
downsample_factor = 2 | ||
output image will then be 50, 100 | ||
|
||
#### Notes | ||
To initialize the network to the identity transform init ``theta`` to : | ||
|
||
```python | ||
identity = np.array([[1., 0., 0.], | ||
[0., 1., 0.]]) | ||
identity = identity.flatten() | ||
theta = tf.Variable(initial_value=identity) | ||
``` | ||
|
||
#### Experiments | ||
|
||
<div align="center"> | ||
<img width="600px" src="http://i.imgur.com/HtCBYk2.png"><br><br> | ||
</div> | ||
|
||
We used cluttered MNIST. Left column are the input images, right are the attended parts of the image by an STN. | ||
|
||
All experiments were run in Tensorflow 0.7. | ||
|
||
### References | ||
|
||
[1] Jaderberg, Max, et al. "Spatial Transformer Networks." arXiv preprint arXiv:1506.02025 (2015) | ||
|
||
[2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
import tensorflow as tf | ||
from spatial_transformer import transformer | ||
from scipy import ndimage | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from tf_utils import conv2d, linear, weight_variable, bias_variable, dense_to_one_hot | ||
|
||
# %% Load data | ||
mnist_cluttered = np.load('./data/mnist_sequence1_sample_5distortions5x5.npz') | ||
|
||
X_train = mnist_cluttered['X_train'] | ||
y_train = mnist_cluttered['y_train'] | ||
X_valid = mnist_cluttered['X_valid'] | ||
y_valid = mnist_cluttered['y_valid'] | ||
X_test = mnist_cluttered['X_test'] | ||
y_test = mnist_cluttered['y_test'] | ||
|
||
# % turn from dense to one hot representation | ||
Y_train = dense_to_one_hot(y_train, n_classes=10) | ||
Y_valid = dense_to_one_hot(y_valid, n_classes=10) | ||
Y_test = dense_to_one_hot(y_test, n_classes=10) | ||
|
||
# %% Graph representation of our network | ||
|
||
# %% Placeholders for 40x40 resolution | ||
x = tf.placeholder(tf.float32, [None, 1600]) | ||
y = tf.placeholder(tf.float32, [None, 10]) | ||
|
||
# %% Since x is currently [batch, height*width], we need to reshape to a | ||
# 4-D tensor to use it in a convolutional graph. If one component of | ||
# `shape` is the special value -1, the size of that dimension is | ||
# computed so that the total size remains constant. Since we haven't | ||
# defined the batch dimension's shape yet, we use -1 to denote this | ||
# dimension should not change size. | ||
x_tensor = tf.reshape(x, [-1, 40, 40, 1]) | ||
|
||
# %% We'll setup the two-layer localisation network to figure out the parameters for an affine transformation of the input | ||
# %% Create variables for fully connected layer | ||
W_fc_loc1 = weight_variable([1600, 20]) | ||
b_fc_loc1 = bias_variable([20]) | ||
|
||
W_fc_loc2 = weight_variable([20, 6]) | ||
initial = np.array([[1.,0, 0],[0,1.,0]]) # Use identity transformation as starting point | ||
initial = initial.astype('float32') | ||
initial = initial.flatten() | ||
b_fc_loc2 = tf.Variable(initial_value=initial, name='b_fc_loc2') | ||
|
||
# %% Define the two layer localisation network | ||
h_fc_loc1 = tf.nn.tanh(tf.matmul(x, W_fc_loc1) + b_fc_loc1) | ||
# %% We can add dropout for regularizing and to reduce overfitting like so: | ||
keep_prob = tf.placeholder(tf.float32) | ||
h_fc_loc1_drop = tf.nn.dropout(h_fc_loc1, keep_prob) | ||
# %% Second layer | ||
h_fc_loc2 = tf.nn.tanh(tf.matmul(h_fc_loc1_drop, W_fc_loc2) + b_fc_loc2) | ||
|
||
# %% We'll create a spatial transformer module to identify discriminative patches | ||
h_trans = transformer(x_tensor, h_fc_loc2, downsample_factor=1) | ||
|
||
# %% We'll setup the first convolutional layer | ||
# Weight matrix is [height x width x input_channels x output_channels] | ||
filter_size = 3 | ||
n_filters_1 = 16 | ||
W_conv1 = weight_variable([filter_size, filter_size, 1, n_filters_1]) | ||
|
||
# %% Bias is [output_channels] | ||
b_conv1 = bias_variable([n_filters_1]) | ||
|
||
# %% Now we can build a graph which does the first layer of convolution: | ||
# we define our stride as batch x height x width x channels | ||
# instead of pooling, we use strides of 2 and more layers | ||
# with smaller filters. | ||
|
||
h_conv1 = tf.nn.relu( | ||
tf.nn.conv2d(input=h_trans, | ||
filter=W_conv1, | ||
strides=[1, 2, 2, 1], | ||
padding='SAME') + | ||
b_conv1) | ||
|
||
# %% And just like the first layer, add additional layers to create | ||
# a deep net | ||
n_filters_2 = 16 | ||
W_conv2 = weight_variable([filter_size, filter_size, n_filters_1, n_filters_2]) | ||
b_conv2 = bias_variable([n_filters_2]) | ||
h_conv2 = tf.nn.relu( | ||
tf.nn.conv2d(input=h_conv1, | ||
filter=W_conv2, | ||
strides=[1, 2, 2, 1], | ||
padding='SAME') + | ||
b_conv2) | ||
|
||
# %% We'll now reshape so we can connect to a fully-connected layer: | ||
h_conv2_flat = tf.reshape(h_conv2, [-1, 10 * 10 * n_filters_2]) | ||
|
||
# %% Create a fully-connected layer: | ||
n_fc = 1024 | ||
W_fc1 = weight_variable([10 * 10 * n_filters_2, n_fc]) | ||
b_fc1 = bias_variable([n_fc]) | ||
h_fc1 = tf.nn.relu(tf.matmul(h_conv2_flat, W_fc1) + b_fc1) | ||
|
||
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) | ||
|
||
# %% And finally our softmax layer: | ||
W_fc2 = weight_variable([n_fc, 10]) | ||
b_fc2 = bias_variable([10]) | ||
y_pred = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) | ||
|
||
# %% Define loss/eval/training functions | ||
cross_entropy = -tf.reduce_sum(y * tf.log(y_pred)) | ||
opt = tf.train.AdamOptimizer() | ||
optimizer = opt.minimize(cross_entropy) | ||
grads = opt.compute_gradients(cross_entropy, [b_fc_loc2]) | ||
|
||
# %% Monitor accuracy | ||
correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1)) | ||
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float')) | ||
|
||
# %% We now create a new session to actually perform the initialization the | ||
# variables: | ||
sess = tf.Session() | ||
sess.run(tf.initialize_all_variables()) | ||
|
||
|
||
# %% We'll now train in minibatches and report accuracy, loss: | ||
iter_per_epoch = 100 | ||
n_epochs = 500 | ||
train_size = 10000 | ||
|
||
indices = np.linspace(0,10000 - 1,iter_per_epoch) | ||
indices = indices.astype('int') | ||
|
||
for epoch_i in range(n_epochs): | ||
for iter_i in range(iter_per_epoch - 1): | ||
batch_xs = X_train[indices[iter_i]:indices[iter_i+1]] | ||
batch_ys = Y_train[indices[iter_i]:indices[iter_i+1]] | ||
|
||
if iter_i % 10 == 0: | ||
loss = sess.run(cross_entropy, | ||
feed_dict={ | ||
x: batch_xs, | ||
y: batch_ys, | ||
keep_prob: 1.0 | ||
}) | ||
print('Iteration: ' + str(iter_i) + ' Loss: ' + str(loss)) | ||
|
||
sess.run(optimizer, feed_dict={ | ||
x: batch_xs, y: batch_ys, keep_prob: 0.8}) | ||
|
||
|
||
print('Accuracy: ' + str(sess.run(accuracy, | ||
feed_dict={ | ||
x: X_valid, | ||
y: Y_valid, | ||
keep_prob: 1.0 | ||
}))) | ||
#theta = sess.run(h_fc_loc2, feed_dict={ | ||
# x: batch_xs, keep_prob: 1.0}) | ||
#print(theta[0]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
### How to get the data | ||
|
||
#### Cluttered MNIST | ||
|
||
The cluttered MNIST dataset can be found here [1] or can be generated via [2]. | ||
|
||
Settings used for `cluttered_mnist.py` : | ||
|
||
```python | ||
|
||
ORG_SHP = [28, 28] | ||
OUT_SHP = [40, 40] | ||
NUM_DISTORTIONS = 8 | ||
dist_size = (5, 5) | ||
|
||
``` | ||
|
||
[1] https://github.com/daviddao/spatial-transformer-tensorflow | ||
|
||
[2] https://github.com/skaae/recurrent-spatial-transformer-code/blob/master/MNIST_SEQUENCE/create_mnist_sequence.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
import tensorflow as tf | ||
from spatial_transformer import transformer | ||
from scipy import ndimage | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from tf_utils import conv2d, linear, weight_variable, bias_variable | ||
|
||
# %% Create a batch of three images (1600 x 1200) | ||
# %% Image retrieved from https://raw.githubusercontent.com/skaae/transformer_network/master/cat.jpg | ||
im = ndimage.imread('cat.jpg') | ||
im = im / 255. | ||
im = im.reshape(1, 1200, 1600, 3) | ||
im = im.astype('float32') | ||
|
||
# %% Simulate batch | ||
batch = np.append(im, im, axis=0) | ||
batch = np.append(batch, im, axis=0) | ||
num_batch = 3 | ||
|
||
x = tf.placeholder(tf.float32, [None, 1200, 1600, 3]) | ||
x = tf.cast(batch,'float32') | ||
|
||
# %% Create localisation network and convolutional layer | ||
with tf.variable_scope('spatial_transformer_0'): | ||
|
||
# %% Create a fully-connected layer with 6 output nodes | ||
n_fc = 6 | ||
W_fc1 = tf.Variable(tf.zeros([1200 * 1600 * 3, n_fc]), name='W_fc1') | ||
|
||
# %% Zoom into the image | ||
initial = np.array([[0.5,0, 0],[0,0.5,0]]) | ||
initial = initial.astype('float32') | ||
initial = initial.flatten() | ||
|
||
b_fc1 = tf.Variable(initial_value=initial, name='b_fc1') | ||
h_fc1 = tf.matmul(tf.zeros([num_batch ,1200 * 1600 * 3]), W_fc1) + b_fc1 | ||
h_trans = transformer(x, h_fc1, downsample_factor=2) | ||
|
||
# %% Run session | ||
sess = tf.Session() | ||
sess.run(tf.initialize_all_variables()) | ||
y = sess.run(h_trans, feed_dict={x: batch}) | ||
|
||
# plt.imshow(y[0]) |
Oops, something went wrong.