-
Notifications
You must be signed in to change notification settings - Fork 653
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
XifengGuo
committed
Nov 20, 2017
1 parent
3ddc9b4
commit e714267
Showing
4 changed files
with
174 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
""" | ||
Keras implementation of CapsNet in Hinton's paper Dynamic Routing Between Capsules. | ||
The current version may only work with the TensorFlow backend. Actually it will be straightforward to re-write to TF code. | ||
Adapting to other backends should be easy, but I have not tested this. | ||
Usage: | ||
python capsulenet-multi-gpu.py | ||
python capsulenet-multi-gpu.py --gpus 2 | ||
... ... | ||
Result: | ||
About 55 seconds per epoch on two GTX1080Ti GPU cards | ||
Author: Xifeng Guo, E-mail: `[email protected]`, Github: `https://github.com/XifengGuo/CapsNet-Keras` | ||
""" | ||
|
||
from keras import optimizers | ||
from keras import backend as K | ||
|
||
K.set_image_data_format('channels_last') | ||
|
||
from capsulenet import CapsNet, margin_loss, load_mnist | ||
|
||
|
||
def train(model, data, args):
    """
    Train a CapsuleNet, feeding it batches from a (optionally shift-augmented) generator.

    :param model: the CapsuleNet model to compile and fit
    :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
    :param args: parsed arguments; this function reads `save_dir`, `batch_size`, `debug`, `lr`,
                 `lam_recon`, `shift_fraction` and `epochs`
    :return: The trained model
    """
    # Imported locally so that `train` also works when this module is imported by
    # other code: at module level these names are only bound under the
    # `if __name__ == "__main__":` guard.
    from keras import callbacks
    from keras.preprocessing.image import ImageDataGenerator

    # unpacking the data
    (x_train, y_train), (x_test, y_test) = data

    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size, histogram_freq=args.debug)
    # exponential decay: the learning rate is multiplied by 0.9 after every epoch
    lr_decay = callbacks.LearningRateScheduler(schedule=lambda epoch: args.lr * (0.9 ** epoch))

    # compile the model: margin loss on the capsule output, MSE on the reconstruction,
    # with the reconstruction term scaled by `lam_recon`
    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., args.lam_recon])

    # Training without data augmentation (kept for reference):
    # model.fit([x_train, y_train], [y_train, x_train], batch_size=args.batch_size, epochs=args.epochs,
    #           validation_data=[[x_test, y_test], [y_test, x_test]], callbacks=[log, tb, lr_decay])

    # Begin: Training with data augmentation ---------------------------------------------------------------------#
    def train_generator(x, y, batch_size, shift_fraction=0.):
        """Yield `([x_batch, y_batch], [y_batch, x_batch])` forever from randomly shifted images."""
        train_datagen = ImageDataGenerator(width_shift_range=shift_fraction,
                                           height_shift_range=shift_fraction)  # shift up to 2 pixel for MNIST
        generator = train_datagen.flow(x, y, batch_size=batch_size)
        while True:
            x_batch, y_batch = next(generator)
            # the model takes (image, label) as inputs and predicts (label, reconstruction)
            yield ([x_batch, y_batch], [y_batch, x_batch])

    # Training with data augmentation. If shift_fraction=0., also no augmentation.
    model.fit_generator(generator=train_generator(x_train, y_train, args.batch_size, args.shift_fraction),
                        steps_per_epoch=int(y_train.shape[0] / args.batch_size),
                        epochs=args.epochs,
                        validation_data=[[x_test, y_test], [y_test, x_test]],
                        callbacks=[log, tb, lr_decay])
    # End: Training with data augmentation -----------------------------------------------------------------------#

    # plot the curves recorded by the CSVLogger callback above
    from utils import plot_log
    plot_log(args.save_dir + '/log.csv', show=True)

    return model
|
||
|
||
def test(model, data):
    """
    Print the test accuracy of `model` and save/show a grid of real vs. reconstructed images.

    :param model: the evaluation model; `model.predict(x_test)` must return `(y_pred, x_recon)`
    :param data: a tuple `(x_test, y_test)`; labels are compared through `argmax` over axis 1
    :return: None. Writes `real_and_recon.png` to the current directory and opens a plot window.
    """
    # Imported locally so that `test` also works when this module is imported by
    # other code: at module level `np` is only bound under the `__main__` guard.
    import numpy as np

    x_test, y_test = data
    y_pred, x_recon = model.predict(x_test, batch_size=100)
    print('-'*50)
    print('Test acc:', np.sum(np.argmax(y_pred, 1) == np.argmax(y_test, 1))/y_test.shape[0])

    import matplotlib.pyplot as plt
    from utils import combine_images
    from PIL import Image

    # first 50 real images followed by their 50 reconstructions, tiled into one picture
    img = combine_images(np.concatenate([x_test[:50], x_recon[:50]]))
    image = img * 255
    Image.fromarray(image.astype(np.uint8)).save("real_and_recon.png")
    print()
    print('Reconstructed images are saved to ./real_and_recon.png')
    print('-'*50)
    plt.imshow(plt.imread("real_and_recon.png"))
    plt.show()
|
||
|
||
if __name__ == "__main__":
    import numpy as np
    import tensorflow as tf
    import os
    # `ImageDataGenerator` and `callbacks` are looked up as module globals by train()
    from keras.preprocessing.image import ImageDataGenerator
    from keras import callbacks
    from keras.utils.vis_utils import plot_model
    from keras.utils import multi_gpu_model

    # setting the hyper parameters
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', default=300, type=int)
    parser.add_argument('--epochs', default=50, type=int)
    parser.add_argument('--lam_recon', default=0.392, type=float)  # 784 * 0.0005, paper uses sum of SE, here uses MSE
    parser.add_argument('--num_routing', default=3, type=int)  # num_routing should > 0
    parser.add_argument('--shift_fraction', default=0.1, type=float)
    parser.add_argument('--debug', default=0, type=int)  # debug>0 will save weights by TensorBoard
    parser.add_argument('--save_dir', default='./result')
    parser.add_argument('--is_training', default=1, type=int)
    parser.add_argument('--weights', default=None)
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--gpus', default=2, type=int)
    args = parser.parse_args()
    print(args)
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # load data
    (x_train, y_train), (x_test, y_test) = load_mnist()

    # define model
    # Keras documents `multi_gpu_model` as requiring gpus >= 2, so only take the
    # multi-GPU path in that case and otherwise train the plain model directly.
    use_multi_gpu = args.gpus >= 2
    n_class = len(np.unique(np.argmax(y_train, 1)))
    if use_multi_gpu:
        # build the template model under the CPU device scope, as the Keras
        # multi-GPU example does, before replicating it onto the GPUs
        with tf.device('/cpu:0'):
            model, eval_model = CapsNet(input_shape=x_train.shape[1:],
                                        n_class=n_class,
                                        num_routing=args.num_routing)
    else:
        model, eval_model = CapsNet(input_shape=x_train.shape[1:],
                                    n_class=n_class,
                                    num_routing=args.num_routing)
    model.summary()
    plot_model(model, to_file=args.save_dir+'/model.png', show_shapes=True)

    # define multi-gpu model
    if use_multi_gpu:
        multi_model = multi_gpu_model(model, gpus=args.gpus)
    else:
        multi_model = model

    # train or test
    if args.weights is not None:  # init the model weights with provided one
        model.load_weights(args.weights)
    if args.is_training:
        train(model=multi_model, data=((x_train, y_train), (x_test, y_test)), args=args)
        # save through the template model rather than the multi-GPU wrapper
        model.save_weights(args.save_dir + '/trained_model.h5')
        print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)
        test(model=eval_model, data=(x_test, y_test))
    else:  # testing only; runs even when no weights are given (a warning is printed)
        if args.weights is None:
            print('No weights are provided. Will test using random initialized weights.')
        test(model=eval_model, data=(x_test, y_test))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
Author: Xifeng Guo, E-mail: `[email protected]`, Github: `https://github.com/XifengGuo/CapsNet-Keras` | ||
""" | ||
|
||
import numpy as np | ||
from keras import layers, models, optimizers | ||
from keras import backend as K | ||
from keras.utils import to_categorical | ||
|
@@ -24,7 +25,7 @@ | |
K.set_image_data_format('channels_last') | ||
|
||
|
||
def CapsNet(input_shape, n_class, num_routing, batch_size): | ||
def CapsNet(input_shape, n_class, num_routing): | ||
""" | ||
A Capsule Network on MNIST. | ||
:param input_shape: data shape, 3d, [width, height, channels] | ||
|
@@ -42,7 +43,7 @@ def CapsNet(input_shape, n_class, num_routing, batch_size): | |
primarycaps = PrimaryCap(conv1, dim_capsule=8, n_channels=32, kernel_size=9, strides=2, padding='valid') | ||
|
||
# Layer 3: Capsule layer. Routing algorithm works here. | ||
digitcaps = CapsuleLayer(num_capsule=n_class, dim_capsule=16, batch_size=batch_size, num_routing=num_routing, | ||
digitcaps = CapsuleLayer(num_capsule=n_class, dim_capsule=16, num_routing=num_routing, | ||
name='digitcaps')(primarycaps) | ||
|
||
# Layer 4: This is an auxiliary layer to replace each capsule with its length. Just to match the true label's shape. | ||
|
@@ -170,7 +171,6 @@ def load_mnist(): | |
|
||
|
||
if __name__ == "__main__": | ||
import numpy as np | ||
import os | ||
from keras.preprocessing.image import ImageDataGenerator | ||
from keras import callbacks | ||
|
@@ -200,8 +200,7 @@ def load_mnist(): | |
# define model | ||
model, eval_model = CapsNet(input_shape=x_train.shape[1:], | ||
n_class=len(np.unique(np.argmax(y_train, 1))), | ||
num_routing=args.num_routing, | ||
batch_size=args.batch_size) | ||
num_routing=args.num_routing) | ||
model.summary() | ||
plot_model(model, to_file=args.save_dir+'/model.png', show_shapes=True) | ||
|
||
|