Skip to content

Commit

Permalink
Fix TF2 3D Unet to standard model garden recommended style.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 306752053
  • Loading branch information
tensorflower-gardener committed Apr 16, 2020
1 parent 5741cef commit 795a3f7
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
15 changes: 2 additions & 13 deletions official/nlp/nhnet/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from official.nlp.nhnet import optimizer
from official.nlp.transformer import metrics as transformer_metrics
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils

FLAGS = flags.FLAGS

Expand Down Expand Up @@ -122,18 +123,6 @@ def train_step(self, inputs):
}


class SimpleCheckpoint(tf.keras.callbacks.Callback):
"""Keras callback to save tf.train.Checkpoints."""

def __init__(self, checkpoint_manager):
super(SimpleCheckpoint, self).__init__()
self.checkpoint_manager = checkpoint_manager

def on_epoch_end(self, epoch, logs=None):
step_counter = self.checkpoint_manager._step_counter.numpy()
self.checkpoint_manager.save(checkpoint_number=step_counter)


def train(params, strategy, dataset=None):
"""Runs training."""

Expand Down Expand Up @@ -168,7 +157,7 @@ def train(params, strategy, dataset=None):
if checkpoint_manager.restore_or_initialize():
logging.info("Training restored from the checkpoints in: %s",
FLAGS.model_dir)
checkpoint_callback = SimpleCheckpoint(checkpoint_manager)
checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)

# Trains the model.
steps_per_epoch = min(FLAGS.train_steps, FLAGS.checkpoint_interval)
Expand Down
12 changes: 12 additions & 0 deletions official/utils/misc/keras_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,18 @@ def get_profiler_callback(model_dir, profile_steps, enable_tensorboard,
return ProfilerCallback(model_dir, start_step, stop_step, steps_per_epoch)


class SimpleCheckpoint(tf.keras.callbacks.Callback):
"""Keras callback to save tf.train.Checkpoints."""

def __init__(self, checkpoint_manager):
super(SimpleCheckpoint, self).__init__()
self.checkpoint_manager = checkpoint_manager

def on_epoch_end(self, epoch, logs=None):
step_counter = self.checkpoint_manager._step_counter.numpy() # pylint: disable=protected-access
self.checkpoint_manager.save(checkpoint_number=step_counter)


class ProfilerCallback(tf.keras.callbacks.Callback):
"""Save profiles in specified step range to log directory."""

Expand Down

0 comments on commit 795a3f7

Please sign in to comment.