Skip to content

Commit

Permalink
Adding network3.py and expand_mnist.py
Browse files Browse the repository at this point in the history
  • Loading branch information
mnielsen committed Dec 30, 2014
1 parent 8037bb5 commit e1f1bf2
Show file tree
Hide file tree
Showing 3 changed files with 367 additions and 1 deletion.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@
*.org
*.pkl
*.pyc
.DS_Store
.DS_Store
loc.py
src/ec2
60 changes: 60 additions & 0 deletions src/expand_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""expand_mnist.py
~~~~~~~~~~~~~~~~~~
Take the 50,000 MNIST training images, and create an expanded set of
250,000 images, by displacing each training image up, down, left and
right, by one pixel. Save the resulting file to
../data/mnist_expanded.pkl.gz.
Note that this program is memory intensive, and may not run on small
systems.
"""

from __future__ import print_function

#### Libraries

# Standard library
import cPickle
import gzip
import os.path
import random

# Third-party libraries
import numpy as np

print("Expanding the MNIST training set")

if os.path.exists("../data/mnist_expanded.pkl.gz"):
print("The expanded training set already exists. Exiting.")
else:
f = gzip.open("../data/mnist.pkl.gz", 'rb')
training_data, validation_data, test_data = cPickle.load(f)
f.close()
expanded_training_pairs = []
j = 0 # counter
for x, y in zip(training_data[0], training_data[1]):
expanded_training_pairs.append((x, y))
image = np.reshape(x, (-1, 28))
j += 1
if j % 1000 == 0: print("Expanding image number", j)
# iterate over data telling us the details of how to
# do the displacement
for d, axis, index_position, index in [
(1, 0, "first", 0),
(-1, 0, "first", 27),
(1, 1, "last", 0),
(-1, 1, "last", 27)]:
new_img = np.roll(image, d, axis)
if index_position == "first":
new_img[index, :] = np.zeros(28)
else:
new_img[:, index] = np.zeros(28)
expanded_training_pairs.append((np.reshape(new_img, 784), y))
random.shuffle(expanded_training_pairs)
expanded_training_data = [list(d) for d in zip(*expanded_training_pairs)]
print("Saving expanded data. This may take a few minutes.")
f = gzip.open("../data/mnist_expanded.pkl.gz", "w")
cPickle.dump((expanded_training_data, validation_data, test_data), f)
f.close()
304 changes: 304 additions & 0 deletions src/network3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
"""network3.py
~~~~~~~~~~~~~~
A Theano-based program for training and running simple neural
networks.
Supports several layer types (fully connected, convolutional, max
pooling, softmax), and activation functions (sigmoid, tanh, and
rectified linear units, with more easily added).
When run on a CPU, this program is much faster than network.py and
network2.py. However, unlike network.py and network2.py it can also
be run on a GPU, which makes it faster still.
Because the code is based on Theano, the code is different in many
ways from network.py and network2.py. However, where possible I have
tried to maintain consistency with the earlier programs. In
particular, the API is similar to network2.py. Note that I have
focused on making the code simple, easily readable, and easily
modifiable. It is not optimized, and omits many desirable features.
"""

#### Libraries
# Standard library
import cPickle
import gzip

# Third-party libraries
import numpy as np
import theano
import theano.tensor as T
from theano.tensor.nnet import conv
from theano.tensor.nnet import softmax
from theano.tensor.signal import downsample

# Activation functions for neurons
def linear(z): return z
def ReLU(z): return T.maximum(0, z)
from theano.tensor.nnet import sigmoid
from theano.tensor import tanh


#### Constants
GPU = False
if GPU:
print "Trying to run under a GPU. If this is not desired, then modify "+\
"network3.py\nto set the GPU flag to False."
try: theano.config.device = 'gpu'
except: pass # it's already set
theano.config.floatX = 'float32'

def example(mini_batch_size=10):
print("Loading the MNIST data")
training_data, validation_data, test_data = load_data_shared("../data/mnist.pkl.gz")
print("Building the network")
net = create_net(10)
print("Training the network")
try:
net.SGD(training_data, 200, mini_batch_size, 0.1,
validation_data, test_data, lmbda=1.0)
except KeyboardInterrupt:
pass
return net

def create_net(mini_batch_size=10, activation_fn=tanh):
return Network(
[ConvPoolLayer(image_shape=(mini_batch_size, 1, 28, 28), filter_shape=(20, 1, 5, 5), poolsize=(2, 2), activation_fn=activation_fn),
#ConvPoolLayer(image_shape=(mini_batch_size, 20, 12, 12), filter_shape=(40, 20, 5, 5), poolsize=(2, 2), activation_fn=activation_fn),
#FullyConnectedLayer(n_in=40*4*4, n_out=100, mini_batch_size=mini_batch_size, activation_fn=activation_fn),
#FullyConnectedLayer(n_in=784, n_out=100, mini_batch_size=mini_batch_size, activation_fn=activation_fn),
#FullyConnectedLayer(n_in=20*12*12, n_out=100, mini_batch_size=mini_batch_size),
#FullyConnectedLayer(n_in=100, n_out=100, mini_batch_size=mini_batch_size, activation_fn=activation_fn),
#SoftmaxLayer(n_in=100, n_out=10, mini_batch_size=mini_batch_size)], mini_batch_size)
SoftmaxLayer(n_in=20*12*12, n_out=10)], mini_batch_size)

#### Load the MNIST data
def load_data_shared(filename="../data/mnist.pkl.gz"):
f = gzip.open(filename, 'rb')
training_data, validation_data, test_data = cPickle.load(f)
f.close()
def shared(data):
"""Place the data into shared variables. This allows Theano to copy
the data to the GPU, if one is available.
"""
shared_x = theano.shared(
np.asarray(data[0], dtype=theano.config.floatX), borrow=True)
shared_y = theano.shared(
np.asarray(data[1], dtype=theano.config.floatX), borrow=True)
return shared_x, T.cast(shared_y, "int32")
return [shared(training_data), shared(validation_data), shared(test_data)]

#### Main class used to construct and train networks
class Network():

def __init__(self, layers, mini_batch_size):
"""Takes a list of `layers`, describing the network architecture, and
a value for the `mini_batch_size` to be used during training
by stochastic gradient descent.
"""
self.layers = layers
self.mini_batch_size = mini_batch_size
self.params = [param for layer in self.layers for param in layer.params]
self.x = T.matrix("x")
self.y = T.ivector("y")
init_layer = self.layers[0]
init_layer.set_inpt(self.x, mini_batch_size)
for j in xrange(1, len(self.layers)):
prev_layer, layer = self.layers[j-1], self.layers[j]
layer.set_inpt(prev_layer.output, mini_batch_size)
self.output = self.layers[-1].output

def SGD(self, training_data, epochs, mini_batch_size, eta,
validation_data, test_data, lmbda=0.0):
"""Train the network using mini-batch stochastic gradient descent."""
training_x, training_y = training_data
validation_x, validation_y = validation_data
test_x, test_y = test_data

# compute number of minibatches for training, validation and testing
num_training_batches = size(training_data)/mini_batch_size
num_validation_batches = size(validation_data)/mini_batch_size
num_test_batches = size(test_data)/mini_batch_size

# define the (regularized) cost function, symbolic gradients, and updates
l2_norm_squared = sum([(layer.w**2).sum() for layer in self.layers])
cost = self.log_likelihood()+0.5*lmbda*l2_norm_squared/num_training_batches
grads = T.grad(cost, self.params)
updates = [(param, param-eta*grad)
for param, grad in zip(self.params, grads)]

# define functions to train a mini-batch, and to compute the
# accuracy in validation and test mini-batches.
i = T.lscalar() # mini-batch index
train_mb = theano.function(
[i], cost, updates=updates,
givens={
self.x:
training_x[i*self.mini_batch_size: (i+1)*self.mini_batch_size],
self.y:
training_y[i*self.mini_batch_size: (i+1)*self.mini_batch_size]
})
validate_mb_accuracy = theano.function(
[i], self.layers[-1].accuracy(self.y),
givens={
self.x:
validation_x[i*self.mini_batch_size: (i+1)*self.mini_batch_size],
self.y:
validation_y[i*self.mini_batch_size: (i+1)*self.mini_batch_size]
})
test_mb_accuracy = theano.function(
[i], self.layers[-1].accuracy(self.y),
givens={
self.x:
test_x[i*self.mini_batch_size: (i+1)*self.mini_batch_size],
self.y:
test_y[i*self.mini_batch_size: (i+1)*self.mini_batch_size]
})

# Do the actual training
best_validation_accuracy = 0.0
for epoch in xrange(epochs):
for minibatch_index in xrange(num_training_batches):
iteration = num_training_batches*epoch+minibatch_index
if iteration % 1000 == 0:
print("Training mini-batch number {0}".format(iteration))
cost_ij = train_mini_batch(minibatch_index)
if (iteration+1) % num_training_batches == 0:
validation_accuracy = np.mean(
[validate_mb_accuracy(j) for j in xrange(num_validation_batches)])
print("Epoch {0}: validation accuracy {1:.2%}".format(
epoch, validation_accuracy))
if validation_accuracy >= best_validation_accuracy:
print("This is the best validation accuracy to date.")
best_validation_accuracy = validation_accuracy
best_iteration = iteration
test_accuracy = np.mean(
[test_mb_accuracy(j) for j in xrange(num_test_batches)])
print('The corresponding test accuracy is {0:.2%}'.format(
test_accuracy))
print("Finished training network.")
print("Best validation accuracy of {0:.2%} obtained at iteration {1}".format(
best_validation_accuracy, best_iteration))
print("Corresponding test accuracy of {0:.2%}".format(test_accuracy))

def log_likelihood(self):
"Return the log-likelihood cost."
return -T.mean(T.log(self.output)[T.arange(self.y.shape[0]), self.y])


#### Define layer types

class ConvPoolLayer():
"""Used to create a combination of a convolutional and a max-pooling
layer. A more sophisticated implementation would separate the
two, but for our purposes we'll always use them together, and it
simplifies the code, so it makes sense to combine them.
"""

def __init__(self, filter_shape, image_shape, poolsize=(2, 2),
activation_fn=sigmoid):
"""`filter_shape` is a tuple of length 4, whose entries are the number
of filters, the number of input feature maps, the filter height, and the
filter width.
`image_shape` is a tuple of length 4, whose entries are the
mini-batch size, the number of input feature maps, the image
height, and the image width.
`poolsize` is a tuple of length 2, whose entries are the y and
x pooling sizes.
"""
self.inpt = None
self.output = None
self.filter_shape = filter_shape
self.image_shape = image_shape
self.poolsize = poolsize
self.activation_fn=activation_fn
# initialize weights and biases
n_out = (filter_shape[0]*np.prod(filter_shape[2:])/np.prod(poolsize))
self.w = theano.shared(
np.asarray(
np.random.normal(loc=0, scale=np.sqrt(1.0/n_out), size=filter_shape),
dtype=theano.config.floatX),
borrow=True)
self.b = theano.shared(
np.asarray(
np.random.normal(loc=0, scale=1.0, size=(filter_shape[0],)),
dtype=theano.config.floatX),
borrow=True)
self.params = [self.w, self.b]

def set_inpt(self, inpt, mini_batch_size):
self.inpt = inpt.reshape(self.image_shape)
conv_out = conv.conv2d(
input=self.inpt, filters=self.w, filter_shape=self.filter_shape,
image_shape=self.image_shape)
pooled_out = downsample.max_pool_2d(
input=conv_out, ds=self.poolsize, ignore_border=True)
self.output = self.activation_fn(
pooled_out + self.b.dimshuffle('x', 0, 'x', 'x'))


class FullyConnectedLayer():

def __init__(self, n_in, n_out, mini_batch_size=10, activation_fn=sigmoid):
self.n_in = n_in
self.n_out = n_out
self.activation_fn = activation_fn
self.inpt = None
self.output = None
# Initialize weights and biases
self.w = theano.shared(
np.asarray(
np.random.normal(
loc=0.0, scale=np.sqrt(1.0/n_out), size=(n_in, n_out)),
dtype=theano.config.floatX),
name='w', borrow=True)
self.b = theano.shared(
np.asarray(np.random.normal(loc=0.0, scale=1.0, size=(n_out,)),
dtype=theano.config.floatX),
name='b', borrow=True)
self.params = [self.w, self.b]

def set_inpt(self, inpt, mini_batch_size):
self.mini_batch_size = mini_batch_size
self.inpt = inpt.reshape((self.mini_batch_size, self.n_in))
self.output = self.activation_fn(T.dot(inpt, self.w)+self.b)

class SoftmaxLayer():

def __init__(self, n_in, n_out):
self.inpt = None
self.output = None
self.n_in = n_in
self.n_out = n_out
# Initialize weights and biases
self.w = theano.shared(
np.zeros((n_in, n_out), dtype=theano.config.floatX),
name='w', borrow=True)
self.b = theano.shared(
np.zeros((n_out,), dtype=theano.config.floatX),
name='b', borrow=True)
self.params = [self.w, self.b]

def set_inpt(self, inpt, mini_batch_size):
self.mini_batch_size = mini_batch_size
self.inpt = inpt.reshape((self.mini_batch_size, self.n_in))
self.output = softmax(T.dot(self.inpt, self.w)+self.b)
self.y_out = T.argmax(self.output, axis=1)

def accuracy(self, y):
"Return the accuracy for the mini-batch."
return T.mean(T.eq(y, self.y_out))


#### Miscellanea
def size(data):
"Return the size of the dataset `data`."
return data[0].get_value(borrow=True).shape[0]

0 comments on commit e1f1bf2

Please sign in to comment.