Skip to content

Commit

Permalink
generator example
Browse files Browse the repository at this point in the history
  • Loading branch information
maxpumperla committed Apr 20, 2016
1 parent 8017676 commit 7f85a6c
Showing 1 changed file with 114 additions and 0 deletions.
114 changes: 114 additions & 0 deletions examples/cifar_generator_cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from __future__ import print_function
from hyperopt import Trials, STATUS_OK, tpe
from hyperas import optim
from hyperas.distributions import choice, uniform


def data():
from keras.preprocessing.image import ImageDataGenerator
from keras.datasets import cifar10
from keras.utils import np_utils

nb_classes = 10
# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)

X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255

# this will do preprocessing and realtime data augmentation
datagen = ImageDataGenerator(
featurewise_center=False, # set input mean to 0 over the dataset
samplewise_center=False, # set each sample mean to 0
featurewise_std_normalization=False, # divide inputs by std of the dataset
samplewise_std_normalization=False, # divide each input by its std
zca_whitening=False, # apply ZCA whitening
rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180)
width_shift_range=0.1, # randomly shift images horizontally (fraction of total width)
height_shift_range=0.1, # randomly shift images vertically (fraction of total height)
horizontal_flip=True, # randomly flip images
vertical_flip=False) # randomly flip images

# compute quantities required for featurewise normalization
# (std, mean, and principal components if ZCA whitening is applied)
datagen.fit(X_train)

return datagen, X_train, Y_train, X_test, Y_test

def model(datagen, X_train, Y_train, X_test, Y_test):
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.optimizers import SGD

batch_size = 32
nb_epoch = 200

# input image dimensions
img_rows, img_cols = 32, 32
# the CIFAR10 images are RGB
img_channels = 3

model = Sequential()

model.add(Convolution2D(32, 3, 3, border_mode='same',
input_shape=(img_channels, img_rows, img_cols)))
model.add(Activation('relu'))
model.add(Convolution2D(32, 3, 3))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout({{uniform(0, 1)}}))

model.add(Convolution2D(64, 3, 3, border_mode='same'))
model.add(Activation('relu'))
model.add(Convolution2D(64, 3, 3))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout({{uniform(0, 1)}}))

model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(nb_classes))
model.add(Activation('softmax'))

# let's train the model using SGD + momentum (how original).
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy',
optimizer=sgd,
metrics=['accuracy'])

# fit the model on the batches generated by datagen.flow()
model.fit_generator(datagen.flow(X_train, Y_train,
batch_size=batch_size),
samples_per_epoch=X_train.shape[0],
nb_epoch=nb_epoch,
validation_data=(X_test, Y_test))

score, acc = model.evaluate(X_test, Y_test, verbose=0)

return {'loss': -acc, 'status': STATUS_OK, 'model': model}


if __name__ == '__main__':

datagen, X_train, Y_train, X_test, Y_test = data()

best_run, best_model = optim.minimize(model=model,
data=data,
algo=tpe.suggest,
max_evals=5,
trials=Trials())

print("Evalutation of best performing model:")
print(best_model.evaluate(X_test, Y_test))

0 comments on commit 7f85a6c

Please sign in to comment.