Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 295849975
  • Loading branch information
tensorflower-gardener committed Feb 19, 2020
1 parent f3600cd commit acea25b
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 9 deletions.
17 changes: 13 additions & 4 deletions official/modeling/training/distributed_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class SummaryWriter(object):
"""Simple SummaryWriter for writing dictionary of metrics.
Attributes:
_writer: The tf.SummaryWriter.
writer: The tf.SummaryWriter.
"""

def __init__(self, model_dir: Text, name: Text):
Expand All @@ -74,7 +74,7 @@ def __init__(self, model_dir: Text, name: Text):
model_dir: the model folder path.
name: the summary subfolder name.
"""
self._writer = tf.summary.create_file_writer(os.path.join(model_dir, name))
self.writer = tf.summary.create_file_writer(os.path.join(model_dir, name))

def __call__(self, metrics: Union[Dict[Text, float], float], step: int):
"""Write metrics to summary with the given writer.
Expand All @@ -88,10 +88,10 @@ def __call__(self, metrics: Union[Dict[Text, float], float], step: int):
logging.warning('Warning: summary writer prefer metrics as dictionary.')
metrics = {'metric': metrics}

with self._writer.as_default():
with self.writer.as_default():
for k, v in metrics.items():
tf.summary.scalar(k, v, step=step)
self._writer.flush()
self.writer.flush()


class DistributedExecutor(object):
Expand Down Expand Up @@ -122,6 +122,9 @@ def __init__(self,
self._strategy = strategy
self._checkpoint_name = 'ctl_step_{step}.ckpt'
self._is_multi_host = is_multi_host
self.train_summary_writer = None
self.eval_summary_writer = None
self.global_train_step = None

@property
def checkpoint_name(self):
Expand Down Expand Up @@ -395,7 +398,10 @@ def _run_callbacks_on_batch_end(batch):
eval_metric = eval_metric_fn()
train_metric = train_metric_fn()
train_summary_writer = summary_writer_fn(model_dir, 'eval_train')
self.train_summary_writer = train_summary_writer.writer

test_summary_writer = summary_writer_fn(model_dir, 'eval_test')
self.eval_summary_writer = test_summary_writer.writer

# Continue training loop.
train_step = self._create_train_step(
Expand All @@ -406,6 +412,7 @@ def _run_callbacks_on_batch_end(batch):
metric=train_metric)
test_step = None
if eval_input_fn and eval_metric:
self.global_train_step = model.optimizer.iterations
test_step = self._create_test_step(strategy, model, metric=eval_metric)

logging.info('Training started')
Expand Down Expand Up @@ -549,6 +556,7 @@ def terminate_eval():
return True

summary_writer = summary_writer_fn(model_dir, 'eval')
self.eval_summary_writer = summary_writer.writer

# Read checkpoints from the given model directory
# until `eval_timeout` seconds elapses.
Expand Down Expand Up @@ -615,6 +623,7 @@ def evaluate_checkpoint(self,
'checkpoint', checkpoint_path)
checkpoint.restore(checkpoint_path)

self.global_train_step = model.optimizer.iterations
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
eval_metric_result = self._run_evaluation(test_step, current_step,
eval_metric, eval_iterator)
Expand Down
3 changes: 3 additions & 0 deletions official/vision/detection/configs/retinanet_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@
'val_json_file': '',
'eval_file_pattern': '',
'input_sharding': True,
# When visualizing images, set evaluation batch size to 40 to avoid
# potential OOM.
'num_images_to_visualize': 0,
},
'predict': {
'predict_batch_size': 8,
Expand Down
35 changes: 30 additions & 5 deletions official/vision/detection/executor/detection_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import json
import tensorflow.compat.v2 as tf
from official.modeling.training import distributed_executor as executor
from official.vision.detection.utils import box_utils


class DetectionDistributedExecutor(executor.DistributedExecutor):
Expand All @@ -38,13 +39,19 @@ def __init__(self,
trainable_variables_filter=None,
**kwargs):
super(DetectionDistributedExecutor, self).__init__(**kwargs)
params = kwargs['params']
if predict_post_process_fn:
assert callable(predict_post_process_fn)
if trainable_variables_filter:
assert callable(trainable_variables_filter)
self._predict_post_process_fn = predict_post_process_fn
self._trainable_variables_filter = trainable_variables_filter
self.eval_steps = tf.Variable(
0,
trainable=False,
dtype=tf.int32,
synchronization=tf.VariableSynchronization.ON_READ,
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
shape=[])

def _create_replicated_step(self,
strategy,
Expand Down Expand Up @@ -90,31 +97,49 @@ def _create_test_step(self, strategy, model, metric):
"""Creates a distributed test step."""

@tf.function
def test_step(iterator):
def test_step(iterator, eval_steps):
"""Calculates evaluation metrics on distributed devices."""

def _test_step_fn(inputs):
def _test_step_fn(inputs, eval_steps):
"""Replicated accuracy calculation."""
inputs, labels = inputs
model_outputs = model(inputs, training=False)
if self._predict_post_process_fn:
labels, prediction_outputs = self._predict_post_process_fn(
labels, model_outputs)
num_remaining_visualizations = (
self._params.eval.num_images_to_visualize - eval_steps)
# If there are remaining number of visualizations that needs to be
# done, add next batch outputs for visualization.
#
# TODO(hongjunchoi): Once dynamic slicing is supported on TPU, only
# write correct slice of outputs to summary file.
if num_remaining_visualizations > 0:
box_utils.visualize_bounding_boxes(
inputs, prediction_outputs['detection_boxes'],
self.global_train_step, self.eval_summary_writer)

return labels, prediction_outputs

labels, outputs = strategy.experimental_run_v2(
_test_step_fn, args=(next(iterator),))
_test_step_fn, args=(
next(iterator),
eval_steps,
))
outputs = tf.nest.map_structure(strategy.experimental_local_results,
outputs)
labels = tf.nest.map_structure(strategy.experimental_local_results,
labels)

eval_steps.assign_add(self._params.eval.batch_size)
return labels, outputs

return test_step

def _run_evaluation(self, test_step, current_training_step, metric,
test_iterator):
"""Runs validation steps and aggregate metrics."""
self.eval_steps.assign(0)
if not test_iterator or not metric:
logging.warning(
'Both test_iterator (%s) and metrics (%s) must not be None.',
Expand All @@ -123,7 +148,7 @@ def _run_evaluation(self, test_step, current_training_step, metric,
logging.info('Running evaluation after step: %s.', current_training_step)
while True:
try:
labels, outputs = test_step(test_iterator)
labels, outputs = test_step(test_iterator, self.eval_steps)
if metric:
metric.update_state(labels, outputs)
except (StopIteration, tf.errors.OutOfRangeError):
Expand Down
1 change: 1 addition & 0 deletions official/vision/detection/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,4 +239,5 @@ def main(argv):

if __name__ == '__main__':
assert tf.version.VERSION.startswith('2.')
tf.config.set_soft_device_placement(True)
app.run(main)
16 changes: 16 additions & 0 deletions official/vision/detection/utils/box_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,22 @@
BBOX_XFORM_CLIP = np.log(1000. / 16.)


def visualize_images_with_bounding_boxes(images, box_outputs, step,
summary_writer):
"""Records subset of evaluation images with bounding boxes."""
image_shape = tf.shape(images[0])
image_height = tf.cast(image_shape[0], tf.float32)
image_width = tf.cast(image_shape[1], tf.float32)
normalized_boxes = normalize_boxes(box_outputs, [image_height, image_width])

bounding_box_color = tf.constant([[1.0, 1.0, 0.0, 1.0]])
image_summary = tf.image.draw_bounding_boxes(images, normalized_boxes,
bounding_box_color)
with summary_writer.as_default():
tf.summary.image('bounding_box_summary', image_summary, step=step)
summary_writer.flush()


def yxyx_to_xywh(boxes):
"""Converts boxes from ymin, xmin, ymax, xmax to xmin, ymin, width, height.
Expand Down

0 comments on commit acea25b

Please sign in to comment.