Save models with better validation hit accuracy (tflearn#265)
* Update documentation for correct default optimizer in regression

* Provide option to save model as validation hit rate increases

* Append accuracy rounded to 4 digits and multiplied by 10000 to remove all decimals

* Skip saving batch model if step number is not set

* Update documentation

* Remove duplicate param setting

* Update documentation
braddengross authored and aymericdamien committed Aug 10, 2016
1 parent 0a9adcd commit bc0d9fa
Showing 3 changed files with 44 additions and 10 deletions.
19 changes: 16 additions & 3 deletions tflearn/callbacks.py
@@ -207,11 +207,15 @@ def snapshot_termlogs(self):


class ModelSaver(object):
def __init__(self, save_func, training_step, snapshot_path, snapshot_epoch):
def __init__(self, save_func, training_step, snapshot_path, best_snapshot_path,
best_val_accuracy, snapshot_step, snapshot_epoch):
self.save_func = save_func
self.training_step = training_step
self.snapshot_path = snapshot_path
self.snapshot_epoch = snapshot_epoch
self.best_snapshot_path = best_snapshot_path
self.snapshot_step = snapshot_step
self.best_val_accuracy = best_val_accuracy

def on_epoch_begin(self):
pass
@@ -229,10 +233,14 @@ def on_sub_epoch_end(self):
def on_batch_begin(self):
pass

def on_batch_end(self, snapshot_model=False):
def on_batch_end(self, snapshot_model=False, best_checkpoint_path=None, val_accuracy=None):
self.training_step += 1
if snapshot_model:
if snapshot_model & (self.snapshot_step is not None):
self.save()
if None not in (best_checkpoint_path, val_accuracy, self.best_val_accuracy):
if val_accuracy > self.best_val_accuracy:
self.best_val_accuracy = val_accuracy
self.save_best(int(10000 * round(val_accuracy, 4)))

def on_sub_batch_begin(self):
pass
@@ -249,3 +257,8 @@ def on_train_end(self):
def save(self):
if self.snapshot_path:
self.save_func(self.snapshot_path, self.training_step)

def save_best(self, val_accuracy):
if self.best_snapshot_path:
snapshot_path = self.best_snapshot_path + str(val_accuracy)
self.save_func(snapshot_path)
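
A note on the naming scheme in `save_best` above: the validation accuracy is rounded to 4 decimal places, scaled by 10000 to drop the decimal point, and appended to `best_snapshot_path`. A minimal sketch of that arithmetic (the accuracy value and path prefix are made up for illustration, not taken from the commit):

```python
# Hypothetical values illustrating the suffix computed in ModelSaver.on_batch_end
# and appended to the path in ModelSaver.save_best.
val_accuracy = 0.91237
suffix = int(10000 * round(val_accuracy, 4))   # round to 4 digits, scale to an integer
print(suffix)                                  # 9124
print("model_best" + str(suffix))              # resulting checkpoint name, e.g. "model_best9124"
```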
21 changes: 17 additions & 4 deletions tflearn/helpers/trainer.py
@@ -41,6 +41,8 @@ class Trainer(object):
```
checkpoint_path: `str`. Path to store model checkpoints. If None,
no model checkpoint will be saved. Default: None.
best_checkpoint_path: `str`. Path to store the model when the validation accuracy reaches its
highest point of the current training session, provided it is also above best_val_accuracy. Default: None.
max_checkpoints: `int` or None. Maximum amount of checkpoints. If
None, no limit. Default: None.
keep_checkpoint_every_n_hours: `float`. Number of hours between each
Expand All @@ -50,15 +52,19 @@ class Trainer(object):
session: `Session`. A session for running ops. If None, a new one will
be created. Note: When providing a session, variables must have been
initialized already, otherwise an error will be raised.
best_val_accuracy: `float`. The minimum validation accuracy that must be
reached before the model's weights are saved to best_checkpoint_path. This
allows the user to skip early saves and also to set a minimum save point when
continuing to train a reloaded model. Default: 0.0.
"""

def __init__(self, train_ops, graph=None, clip_gradients=5.0,
tensorboard_dir="/tmp/tflearn_logs/",
tensorboard_verbose=0, checkpoint_path=None,
tensorboard_verbose=0, checkpoint_path=None, best_checkpoint_path=None,
max_checkpoints=None,
keep_checkpoint_every_n_hours=10000.0, random_seed=None,
session=None):
session=None, best_val_accuracy=0.0):

self.graph = tf.get_default_graph()
if graph:
@@ -85,6 +91,8 @@ def __init__(self, train_ops, graph=None, clip_gradients=5.0,
trainable=False)
self.incr_global_step = tf.assign(self.global_step,
tf.add(self.global_step, 1))
self.best_val_accuracy = best_val_accuracy
self.best_checkpoint_path = best_checkpoint_path

config = None
tflearn_conf = tf.get_collection(tf.GraphKeys.GRAPH_CONFIG)
@@ -230,6 +238,9 @@ def fit(self, feed_dicts, n_epoch=10, val_feed_dicts=None, show_metric=False,
modelsaver = callbacks.ModelSaver(self.save,
self.training_step,
self.checkpoint_path,
self.best_checkpoint_path,
self.best_val_accuracy,
snapshot_step,
snapshot_epoch)

for i, train_op in enumerate(self.train_ops):
@@ -279,7 +290,7 @@ def fit(self, feed_dicts, n_epoch=10, val_feed_dicts=None, show_metric=False,
modelsaver.on_sub_batch_begin()

snapshot = train_op._train(self.training_step,
snapshot_epoch,
(bool(self.best_checkpoint_path) | snapshot_epoch),
snapshot_step,
show_metric)
global_loss += train_op.loss_value
@@ -303,7 +314,9 @@ def fit(self, feed_dicts, n_epoch=10, val_feed_dicts=None, show_metric=False,
self.session.run(self.incr_global_step)
termlogger.on_batch_end(global_loss, global_acc,
snapshot)
modelsaver.on_batch_end(snapshot)
modelsaver.on_batch_end(snapshot, self.best_checkpoint_path, train_op.val_acc)
if self.best_checkpoint_path:
self.best_val_accuracy = modelsaver.best_val_accuracy

# Epoch end
termlogger.on_epoch_end()
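
For context, a rough sketch of how the new `Trainer` arguments could be wired up when driving training through the `Trainer`/`TrainOp` API directly. The network, optimizer, dataset, and paths below are illustrative assumptions, not part of this commit; the overall pattern follows TFLearn's existing extending-TensorFlow examples.

```python
import tensorflow as tf
import tflearn
import tflearn.datasets.mnist as mnist

trainX, trainY, testX, testY = mnist.load_data(one_hot=True)

# A deliberately small, hypothetical graph.
X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, 10])
net = tflearn.fully_connected(X, 10, activation='softmax')
loss = tflearn.categorical_crossentropy(net, Y)
acc = tflearn.metrics.accuracy_op(net, Y)
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)

trainop = tflearn.TrainOp(loss=loss, optimizer=optimizer, metric=acc, batch_size=128)
trainer = tflearn.Trainer(train_ops=trainop,
                          checkpoint_path='/tmp/model.ckpt',       # periodic snapshots (existing behaviour)
                          best_checkpoint_path='/tmp/model_best',  # new: best-validation snapshots
                          best_val_accuracy=0.80)                  # new: skip saves below 80% val accuracy
trainer.fit({X: trainX, Y: trainY}, val_feed_dicts={X: testX, Y: testY},
            n_epoch=10, snapshot_step=500, show_metric=True)
```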
14 changes: 11 additions & 3 deletions tflearn/models/dnn.py
@@ -27,11 +27,17 @@ class DNN(object):
Default: "/tmp/tflearn_logs/"
checkpoint_path: `str`. Path to store model checkpoints. If None,
no model checkpoint will be saved. Default: None.
best_checkpoint_path: `str`. Path to store the model when the validation accuracy reaches its
highest point of the current training session, provided it is also above best_val_accuracy. Default: None.
max_checkpoints: `int` or None. Maximum amount of checkpoints. If
None, no limit. Default: None.
session: `Session`. A session for running ops. If None, a new one will
be created. Note: When providing a session, variables must have been
initialized already, otherwise an error will be raised.
best_val_accuracy: `float`. The minimum validation accuracy that must be
reached before the model's weights are saved to best_checkpoint_path. This
allows the user to skip early saves and also to set a minimum save point when
continuing to train a reloaded model. Default: 0.0.
Attributes:
trainer: `Trainer`. Handle model training.
@@ -41,8 +47,8 @@ class DNN(object):
"""

def __init__(self, network, clip_gradients=5.0, tensorboard_verbose=0,
tensorboard_dir="/tmp/tflearn_logs/", checkpoint_path=None,
max_checkpoints=None, session=None):
tensorboard_dir="/tmp/tflearn_logs/", checkpoint_path=None, best_checkpoint_path=None,
max_checkpoints=None, session=None, best_val_accuracy=0.0):
assert isinstance(network, tf.Tensor), "'network' arg is not a Tensor!"
self.net = network
self.train_ops = tf.get_collection(tf.GraphKeys.TRAIN_OPS)
@@ -51,8 +57,10 @@ def __init__(self, network, clip_gradients=5.0, tensorboard_verbose=0,
tensorboard_dir=tensorboard_dir,
tensorboard_verbose=tensorboard_verbose,
checkpoint_path=checkpoint_path,
best_checkpoint_path=best_checkpoint_path,
max_checkpoints=max_checkpoints,
session=session)
session=session,
best_val_accuracy=best_val_accuracy)
self.session = self.trainer.session

self.inputs = tf.get_collection(tf.GraphKeys.INPUTS)
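
At the `DNN` level, the same feature might be used roughly as follows; the dataset, network, and paths are placeholders chosen for illustration, not taken from this commit:

```python
import tflearn
import tflearn.datasets.mnist as mnist

trainX, trainY, testX, testY = mnist.load_data(one_hot=True)

# An arbitrary small network for illustration.
net = tflearn.input_data(shape=[None, 784])
net = tflearn.fully_connected(net, 128, activation='relu')
net = tflearn.fully_connected(net, 10, activation='softmax')
net = tflearn.regression(net)

model = tflearn.DNN(net,
                    checkpoint_path='/tmp/mnist.ckpt',       # periodic snapshots (existing behaviour)
                    best_checkpoint_path='/tmp/mnist_best',  # new: saved as e.g. /tmp/mnist_best9124
                    best_val_accuracy=0.90)                  # new: no best-model saves until val acc exceeds 0.90
model.fit(trainX, trainY, validation_set=(testX, testY),
          n_epoch=5, snapshot_step=500, show_metric=True)
```

With this setup, regular checkpoints still go to checkpoint_path, while a snapshot is written to best_checkpoint_path only when the validation accuracy exceeds the running best, which is initialised from best_val_accuracy.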
