Skip to content

Commit

Permalink
removed duplicate steps_per_execution variable
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 390410210
  • Loading branch information
tensorflower-gardener committed Aug 12, 2021
1 parent d8f7958 commit e6ba837
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions keras/engine/data_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,11 +1144,9 @@ def __init__(self,
# `steps_per_execution` is mutable and may be changed by the DataAdapter
# to handle partial executions.
if steps_per_execution is None:
self._steps_per_execution = 1
self._steps_per_execution_value = 1
self._steps_per_execution = tf.Variable(1)
else:
self._steps_per_execution = steps_per_execution
self._steps_per_execution_value = steps_per_execution.numpy().item()

adapter_cls = select_data_adapter(x, y)
self._adapter = adapter_cls(
Expand All @@ -1168,7 +1166,7 @@ def __init__(self,
strategy = tf.distribute.get_strategy()

self._current_step = 0
self._step_increment = self._steps_per_execution_value - 1
self._step_increment = self._steps_per_execution.numpy().item() - 1
self._insufficient_data = False

self._configure_dataset_and_inferred_steps(strategy, x, steps_per_epoch,
Expand Down Expand Up @@ -1207,17 +1205,15 @@ def _truncate_execution_to_epoch(self):
"""Truncates steps per execution to at most one epoch."""
should_truncate = (
self._inferred_steps is not None and
self._steps_per_execution_value > self._inferred_steps)
original_value = self._steps_per_execution_value
self._steps_per_execution.numpy().item() > self._inferred_steps)
original_value = self._steps_per_execution.numpy().item()
try:
if should_truncate:
self._steps_per_execution.assign(self._inferred_steps)
self._steps_per_execution_value = self._inferred_steps
yield
finally:
if should_truncate:
self._steps_per_execution.assign(original_value)
self._steps_per_execution_value = original_value

def sync(self):
context.async_wait()
Expand Down Expand Up @@ -1250,25 +1246,25 @@ def steps(self):
self._current_step < self._inferred_steps):
if self._insufficient_data: # Set by `catch_stop_iteration`.
break

original_spe = self._steps_per_execution.numpy().item()
can_run_full_execution = (
self._steps_per_execution_value == 1 or
original_spe == 1 or
self._inferred_steps is None or
self._inferred_steps - self._current_step >=
self._steps_per_execution_value)
original_spe)

if can_run_full_execution:
self._step_increment = self._steps_per_execution_value - 1
self._step_increment = original_spe - 1
yield self._current_step
self._current_step += self._steps_per_execution_value
self._current_step += original_spe
else:
# Last partial execution.
steps_remaining = self._inferred_steps - self._current_step
self._steps_per_execution.assign(steps_remaining)
self._step_increment = steps_remaining - 1
yield self._current_step
self._current_step += steps_remaining
self._steps_per_execution.assign(self._steps_per_execution_value)
self._steps_per_execution.assign(original_spe)

@property
def step_increment(self):
Expand Down Expand Up @@ -1334,7 +1330,8 @@ def _samples(self):

def _validate_data_handler(self):
# TODO(b/152094471): Support this with DistIter.get_next_as_optional.
if self._steps_per_execution_value > 1 and self._inferred_steps is None:
if self._steps_per_execution.numpy().item(
) > 1 and self._inferred_steps is None:
raise ValueError(
"Could not infer the size of the data. With "
"`steps_per_execution > 1`, you must specify the number of steps "
Expand Down

0 comments on commit e6ba837

Please sign in to comment.