Skip to content

Commit

Permalink
Merge pull request keras-team#1651 from tboquet/fit_gen_tensorb
Browse files Browse the repository at this point in the history
Support + tests for fit_generator + tensorboard
  • Loading branch information
fchollet committed Feb 8, 2016
2 parents cae797b + d68e331 commit 657b9fb
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 25 deletions.
5 changes: 3 additions & 2 deletions keras/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,14 +456,15 @@ def __init__(self, log_dir='./logs', histogram_freq=0):
'with the TensorFlow backend.')
self.log_dir = log_dir
self.histogram_freq = histogram_freq
self.merged = None

def _set_model(self, model):
import tensorflow as tf
import keras.backend.tensorflow_backend as KTF

self.model = model
self.sess = KTF._get_session()
if self.histogram_freq:
if self.histogram_freq and not self.merged:
mod_type = self.model.get_config()['name']
if mod_type == 'Sequential':
layers = {l.get_config()['name']: l for l in self.model.layers}
Expand Down Expand Up @@ -515,7 +516,7 @@ def on_epoch_end(self, epoch, logs={}):

all_values = self.totals.copy()
all_values.update(logs)

for name, value in all_values.items():
if name in ['batch', 'size']:
continue
Expand Down
31 changes: 24 additions & 7 deletions keras/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,8 +971,17 @@ def input_validation(generator_output):
_stop.set()
raise Exception('The generator output tuple must have '
'2 or 3 elements.')

sample_weight = standardize_weights(y, sample_weight=sample_weight,
sample_weight_mode=self.sample_weight_mode)
return X, y, sample_weight

if do_validation:
X_val, y_val, sample_weight_val = input_validation(validation_data)
self.validation_data = X_val + [y_val, sample_weight_val]
else:
self.validation_data = None

# start generator thread storing batches into a queue
generator_queue = queue.Queue()
_stop = threading.Event()
Expand Down Expand Up @@ -1044,10 +1053,9 @@ def generator_task():
raise NotImplementedError()
else:
# input validation
X, y, sample_weight = input_validation(validation_data)
val_outs = self.evaluate(X, y,
val_outs = self.evaluate(X_val, y_val,
show_accuracy=show_accuracy,
sample_weight=sample_weight,
sample_weight=sample_weight_val,
verbose=0)
if type(val_outs) != list:
val_outs = [val_outs]
Expand Down Expand Up @@ -1435,8 +1443,19 @@ def input_validation(generator_output):
[len(sample_weight[name]) for name in sample_weight.keys()])) != 1:
raise Exception('All input arrays and target arrays must have '
'the same number of samples.')
sample_weight = {name: standardize_weights(data[name],
sample_weight=sample_weight.get(name),
sample_weight_mode=self.sample_weight_modes.get(name)) for name in self.output_order}
return data, sample_weight

if do_validation:
data_val, sample_weight_val = input_validation(validation_data)
sample_weight_val_l = [sample_weight_val[name] for name in self.output_order]
y_val = [standardize_y(data_val[name]) for name in self.output_order]
self.validation_data = [data_val[name] for name in self.input_order] + y_val + sample_weight_val_l
else:
self.validation_data = None

# start generator thread storing batches into a queue
generator_queue = queue.Queue()
_stop = threading.Event()
Expand Down Expand Up @@ -1502,10 +1521,8 @@ def generator_task():
_stop.set()
raise NotImplementedError()
else:
# input validation
data, sample_weight = input_validation(validation_data)
val_outs = self.evaluate(data,
sample_weight=sample_weight,
val_outs = self.evaluate(data_val,
sample_weight=sample_weight_val,
verbose=0)
if type(val_outs) != list:
val_outs = [val_outs]
Expand Down
86 changes: 70 additions & 16 deletions tests/keras/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,23 +136,30 @@ def test_TensorBoard():
nb_class=nb_class)
y_test = np_utils.to_categorical(y_test)
y_train = np_utils.to_categorical(y_train)
# case 1 Sequential wo accuracy
with tf.Graph().as_default():
session = tf.Session('')
KTF._set_session(session)
model = Sequential()
model.add(Dense(nb_hidden, input_dim=input_dim, activation='relu'))
model.add(Dense(nb_class, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='sgd')

tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1)
cbks = [tsb]
model.fit(X_train, y_train, batch_size=batch_size, show_accuracy=True,
validation_data=(X_test, y_test), callbacks=cbks, nb_epoch=2)
assert os.path.exists(filepath)
shutil.rmtree(filepath)
def data_generator(train):
if train:
max_batch_index = len(X_train) // batch_size
else:
max_batch_index = len(X_test) // batch_size
i = 0
while 1:
if train:
yield (X_train[i * batch_size: (i + 1) * batch_size], y_train[i * batch_size: (i + 1) * batch_size])
else:
yield (X_test[i * batch_size: (i + 1) * batch_size], y_test[i * batch_size: (i + 1) * batch_size])
i += 1
i = i % max_batch_index

def data_generator_graph(train):
while 1:
if train:
yield {'X_vars': X_train, 'output': y_train}
else:
yield {'X_vars': X_test, 'output': y_test}

# case 1 Sequential

# case 2 Sequential w accuracy
with tf.Graph().as_default():
session = tf.Session('')
KTF._set_session(session)
Expand All @@ -163,12 +170,42 @@ def test_TensorBoard():

tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1)
cbks = [tsb]

# fit with validation data
model.fit(X_train, y_train, batch_size=batch_size, show_accuracy=False,
validation_data=(X_test, y_test), callbacks=cbks, nb_epoch=2)

# fit with validation data and accuracy
model.fit(X_train, y_train, batch_size=batch_size, show_accuracy=True,
validation_data=(X_test, y_test), callbacks=cbks, nb_epoch=2)

# fit generator with validation data
model.fit_generator(data_generator(True), len(X_train), nb_epoch=2,
show_accuracy=False,
validation_data=(X_test, y_test),
callbacks=cbks)

# fit generator without validation data
model.fit_generator(data_generator(True), len(X_train), nb_epoch=2,
show_accuracy=False,
callbacks=cbks)

# fit generator with validation data and accuracy
model.fit_generator(data_generator(True), len(X_train), nb_epoch=2,
show_accuracy=True,
validation_data=(X_test, y_test),
callbacks=cbks)

# fit generator without validation data and accuracy
model.fit_generator(data_generator(True), len(X_train), nb_epoch=2,
show_accuracy=True,
callbacks=cbks)

assert os.path.exists(filepath)
shutil.rmtree(filepath)

# case 3 Graph
# case 2 Graph

with tf.Graph().as_default():
session = tf.Session('')
KTF._set_session(session)
Expand All @@ -185,10 +222,27 @@ def test_TensorBoard():

tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1)
cbks = [tsb]

# fit with validation
model.fit({'X_vars': X_train, 'output': y_train},
batch_size=batch_size,
validation_data={'X_vars': X_test, 'output': y_test},
callbacks=cbks, nb_epoch=2)

# fit wo validation
model.fit({'X_vars': X_train, 'output': y_train},
batch_size=batch_size,
callbacks=cbks, nb_epoch=2)

# fit generator with validation
model.fit_generator(data_generator_graph(True), 1000, nb_epoch=2,
validation_data={'X_vars': X_test, 'output': y_test},
callbacks=cbks)

# fit generator wo validation
model.fit_generator(data_generator_graph(True), 1000, nb_epoch=2,
callbacks=cbks)

assert os.path.exists(filepath)
shutil.rmtree(filepath)

Expand Down

0 comments on commit 657b9fb

Please sign in to comment.