Skip to content

Commit

Permalink
allow empty savers by default (tflearn#663)
Browse files Browse the repository at this point in the history
  • Loading branch information
aymericdamien committed Mar 15, 2017
1 parent 8e5782c commit f57a533
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions tflearn/helpers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ def __init__(self, train_ops, graph=None, clip_gradients=5.0,
# Saver for saving a model
self.saver = tf.train.Saver(
max_to_keep=max_checkpoints,
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
allow_empty=True)
# Saver for restoring a model (With exclude variable list)
all_vars = variables.get_all_variables()
excl_vars = tf.get_collection(tf.GraphKeys.EXCL_RESTORE_VARS)
Expand All @@ -142,14 +143,16 @@ def __init__(self, train_ops, graph=None, clip_gradients=5.0,
self.restorer = tf.train.Saver(
var_list=to_restore,
max_to_keep=max_checkpoints,
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
allow_empty=True)
# A second Saver, that only restore trainable variables
to_restore_trainvars = [item for item in tf.trainable_variables()
if check_restore_tensor(item, excl_vars)]
self.restorer_trainvars = tf.train.Saver(
var_list=to_restore_trainvars,
max_to_keep=max_checkpoints,
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
allow_empty=True)

self.to_restore = to_restore
self.to_restore_trainvars = to_restore_trainvars
Expand Down

0 comments on commit f57a533

Please sign in to comment.