Skip to content

Commit

Permalink
Use SGD instead of Adam
Browse files Browse the repository at this point in the history
  • Loading branch information
d4nst committed Mar 28, 2018
1 parent b845ebf commit 5f8cf12
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions train/train_street_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import os
import sys

from keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard
from keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, ReduceLROnPlateau
from keras.applications.resnet50 import ResNet50
from keras.applications.imagenet_utils import preprocess_input
from keras.models import Model
from keras.layers import Dense, Flatten
from keras.optimizers import SGD

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import angle_error, RotNetDataGenerator
Expand Down Expand Up @@ -43,7 +44,7 @@

# model compilation
model.compile(loss='categorical_crossentropy',
optimizer='adam',
optimizer=SGD(lr=0.01, momentum=0.9),
metrics=[angle_error])

# training parameters
Expand All @@ -55,11 +56,14 @@
os.makedirs(output_folder)

# callbacks
monitor = 'val_angle_error'
checkpointer = ModelCheckpoint(
filepath=os.path.join(output_folder, model_name + '.hdf5'),
monitor=monitor,
save_best_only=True
)
early_stopping = EarlyStopping(patience=2)
reduce_lr = ReduceLROnPlateau(monitor=monitor, patience=3)
early_stopping = EarlyStopping(monitor=monitor, patience=5)
tensorboard = TensorBoard()

# training loop
Expand All @@ -84,8 +88,6 @@
crop_largest_rect=True
),
validation_steps=len(test_filenames) / batch_size,
callbacks=[checkpointer, early_stopping, tensorboard],
nb_worker=10,
pickle_safe=True,
verbose=1
callbacks=[checkpointer, reduce_lr, early_stopping, tensorboard],
workers=10
)

0 comments on commit 5f8cf12

Please sign in to comment.