Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 505726524
  • Loading branch information
tensorflower-gardener committed Jan 30, 2023
1 parent c2698ea commit 03c9bfb
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 36 deletions.
4 changes: 1 addition & 3 deletions orbit/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,8 @@ def train_and_evaluate(self,
interval = min(train_steps - current_step, eval_interval)
num_steps = current_step + interval
self.train(steps=num_steps, checkpoint_at_completion=False)
self.evaluate(steps=eval_steps)
current_step = self.global_step.numpy()
if current_step < train_steps:
self.evaluate(steps=eval_steps)
self.evaluate(steps=eval_steps)
self._maybe_save_checkpoint(check_interval=False)

def evaluate_continuously(self,
Expand Down
33 changes: 0 additions & 33 deletions orbit/controller_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,39 +807,6 @@ def steps_per_loop_fn(global_step):
test_controller.train(steps=10)
self.assertEqual(test_runner.global_step, 10)

def test_final_evaluation_is_done_even_no_training_needed(self):
test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(
model=test_runner.model, optimizer=test_runner.optimizer
)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step,
checkpoint_interval=10,
)
summary_dir = os.path.join(self.model_dir, "summaries")
summary_manager = orbit.utils.SummaryManager(
summary_dir, tf.summary.scalar, global_step=test_runner.global_step
)

test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
checkpoint_manager=checkpoint_manager,
summary_manager=summary_manager,
)

test_controller.train_and_evaluate(
train_steps=0, eval_interval=2
)
self.assertNotEmpty(tf.io.gfile.listdir(summary_dir))
self.assertNotEmpty(
summaries_with_matching_keyword("eval_loss", summary_dir)
)

if __name__ == "__main__":
tf.test.main()

0 comments on commit 03c9bfb

Please sign in to comment.