Skip to content

Commit

Permalink
Adding summaries to the resnet example.
Browse files Browse the repository at this point in the history
Also utilities to use summaries in graph mode.

PiperOrigin-RevId: 173483424
  • Loading branch information
alextp authored and tensorflower-gardener committed Oct 26, 2017
1 parent 6149fec commit ff7b9a6
Show file tree
Hide file tree
Showing 10 changed files with 182 additions and 42 deletions.
2 changes: 1 addition & 1 deletion tensorflow/contrib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ py_library(
"//tensorflow/contrib/staging",
"//tensorflow/contrib/stat_summarizer:stat_summarizer_py",
"//tensorflow/contrib/stateless",
"//tensorflow/contrib/summary:summary_ops",
"//tensorflow/contrib/summary:summary",
"//tensorflow/contrib/tensor_forest:init_py",
"//tensorflow/contrib/tensorboard",
"//tensorflow/contrib/testing:testing_py",
Expand Down
1 change: 1 addition & 0 deletions tensorflow/contrib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
from tensorflow.contrib.ndlstm import python as ndlstm
from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph
from tensorflow.contrib.specs import python as specs
from tensorflow.contrib.summary import summary

from tensorflow.python.util.lazy_loader import LazyLoader
ffmpeg = LazyLoader("ffmpeg",
Expand Down
1 change: 1 addition & 0 deletions tensorflow/contrib/cmake/tf_core_ops.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ set(tf_op_lib_names
"state_ops"
"stateless_random_ops"
"string_ops"
"summary_ops"
"training_ops"
)

Expand Down
3 changes: 3 additions & 0 deletions tensorflow/contrib/cmake/tf_python.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@ add_python_module("tensorflow/contrib/reduce_slice_ops/ops")
add_python_module("tensorflow/contrib/reduce_slice_ops/python")
add_python_module("tensorflow/contrib/reduce_slice_ops/python/kernel_tests")
add_python_module("tensorflow/contrib/reduce_slice_ops/python/ops")
add_python_module("tensorflow/contrib/summary")

# Generate the tensorflow.python.platform.build_info module.
set(BUILD_INFO_PY "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/platform/build_info.py")
Expand Down Expand Up @@ -812,6 +813,8 @@ GENERATE_PYTHON_OP_LIB("stateless_random_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/stateless/gen_stateless_random_ops.py)
GENERATE_PYTHON_OP_LIB("debug_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/debug/ops/gen_debug_ops.py)
GENERATE_PYTHON_OP_LIB("summary_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/summary/gen_summary_ops.py)

add_custom_target(tf_python_ops SOURCES ${tf_python_ops_generated_files} ${PYTHON_PROTO_GENFILES})
add_dependencies(tf_python_ops tf_python_op_gen_main)
Expand Down
25 changes: 25 additions & 0 deletions tensorflow/contrib/summary/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":summary_ops",
":summary_test_util",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:lib",
Expand Down Expand Up @@ -52,6 +53,16 @@ py_library(
],
)

py_library(
name = "summary",
srcs = ["summary.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
":summary_ops",
],
)

filegroup(
name = "all_files",
srcs = glob(
Expand All @@ -63,3 +74,17 @@ filegroup(
),
visibility = ["//tensorflow:__subpackages__"],
)

# NOTE: target cannot be testonly because it needs to be in the pip
# package. Sigh.
py_library(
name = "summary_test_util",
srcs = ["summary_test_util.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:lib",
"//tensorflow/python:platform",
],
)
39 changes: 39 additions & 0 deletions tensorflow/contrib/summary/summary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Contrib summary package.
The operations in this package are safe to use with eager execution turned or on
off.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import
from tensorflow.contrib.summary.summary_ops import all_summary_ops
from tensorflow.contrib.summary.summary_ops import always_record_summaries
from tensorflow.contrib.summary.summary_ops import audio
from tensorflow.contrib.summary.summary_ops import create_summary_file_writer
from tensorflow.contrib.summary.summary_ops import generic
from tensorflow.contrib.summary.summary_ops import histogram
from tensorflow.contrib.summary.summary_ops import image
from tensorflow.contrib.summary.summary_ops import never_record_summaries
from tensorflow.contrib.summary.summary_ops import record_summaries_every_n_global_steps
from tensorflow.contrib.summary.summary_ops import scalar
from tensorflow.contrib.summary.summary_ops import should_record_summaries
from tensorflow.contrib.summary.summary_ops import summary_writer_initializer_op
82 changes: 63 additions & 19 deletions tensorflow/contrib/summary/summary_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import summary_op_util
from tensorflow.python.training import training_util
from tensorflow.python.util import tf_contextlib
Expand All @@ -33,6 +35,9 @@
# Tensor. If this tensor is True the summary ops will record summaries.
_SHOULD_RECORD_SUMMARIES_NAME = "ShouldRecordSummaries"

_SUMMARY_COLLECTION_NAME = "_SUMMARY_V2"
_SUMMARY_WRITER_INIT_COLLECTION_NAME = "_SUMMARY_WRITER_V2"


def should_record_summaries():
"""Returns boolean Tensor which is true if summaries should be recorded."""
Expand Down Expand Up @@ -78,10 +83,15 @@ def never_record_summaries():


class SummaryWriter(object):
"""Encapsulates a summary writer."""

def __init__(self, resource):
self._resource = resource

def __del__(self):
if context.in_eager_mode():
resource_variable_ops.destroy_resource_op(self._resource)

def set_as_default(self):
context.context().summary_writer_resource = self._resource

Expand All @@ -90,6 +100,9 @@ def as_default(self):
old = context.context().summary_writer_resource
context.context().summary_writer_resource = self._resource
yield
# Flushes the summary writer in eager mode or in graph functions, but not in
# legacy graph mode (you're on your own there).
gen_summary_ops.flush_summary_writer(self._resource)
context.context().summary_writer_resource = old


Expand All @@ -108,14 +121,33 @@ def create_summary_file_writer(logdir,
resource = gen_summary_ops.summary_writer(shared_name=name)
# TODO(apassos) ensure the initialization op runs when in graph mode; consider
# calling session.run here.
gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue,
flush_secs, filename_suffix)
ops.add_to_collection(
_SUMMARY_WRITER_INIT_COLLECTION_NAME,
gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue,
flush_secs, filename_suffix))
return SummaryWriter(resource)


def _nothing():
"""Convenient else branch for when summaries do not record."""
return False
return constant_op.constant(False)


def all_summary_ops():
"""Graph-mode only. Returns all summary ops."""
if context.in_eager_mode():
raise RuntimeError(
"tf.contrib.summary.all_summary_ops is only supported in graph mode.")
return ops.get_collection(_SUMMARY_COLLECTION_NAME)


def summary_writer_initializer_op():
"""Graph-mode only. Returns the list of ops to create all summary writers."""
if context.in_eager_mode():
raise RuntimeError(
"tf.contrib.summary.summary_writer_initializer_op is only "
"supported in graph mode.")
return ops.get_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME)


def summary_writer_function(name, tensor, function, family=None):
Expand All @@ -133,30 +165,37 @@ def summary_writer_function(name, tensor, function, family=None):
def record():
with summary_op_util.summary_scope(
name, family, values=[tensor]) as (tag, scope):
function(tag, scope)
return True
with ops.control_dependencies([function(tag, scope)]):
return constant_op.constant(True)

return utils.smart_cond(
should_record_summaries(), record, _nothing, name="")
with ops.device("cpu:0"):
op = utils.smart_cond(
should_record_summaries(), record, _nothing, name="")
ops.add_to_collection(_SUMMARY_COLLECTION_NAME, op)
return op


def generic(name, tensor, metadata, family=None):
"""Writes a tensor summary if possible."""

def function(tag, scope):
gen_summary_ops.write_summary(context.context().summary_writer_resource,
training_util.get_global_step(), tensor,
tag, metadata, name=scope)
# Note the identity to move the tensor to the CPU.
return gen_summary_ops.write_summary(
context.context().summary_writer_resource,
training_util.get_global_step(), array_ops.identity(tensor),
tag, metadata, name=scope)
return summary_writer_function(name, tensor, function, family=family)


def scalar(name, tensor, family=None):
"""Writes a scalar summary if possible."""

def function(tag, scope):
gen_summary_ops.write_scalar_summary(
# Note the identity to move the tensor to the CPU.
return gen_summary_ops.write_scalar_summary(
context.context().summary_writer_resource,
training_util.get_global_step(), tag, tensor, name=scope)
training_util.get_global_step(), tag, array_ops.identity(tensor),
name=scope)

return summary_writer_function(name, tensor, function, family=family)

Expand All @@ -165,9 +204,11 @@ def histogram(name, tensor, family=None):
"""Writes a histogram summary if possible."""

def function(tag, scope):
gen_summary_ops.write_histogram_summary(
# Note the identity to move the tensor to the CPU.
return gen_summary_ops.write_histogram_summary(
context.context().summary_writer_resource,
training_util.get_global_step(), tag, tensor, name=scope)
training_util.get_global_step(), tag, array_ops.identity(tensor),
name=scope)

return summary_writer_function(name, tensor, function, family=family)

Expand All @@ -178,10 +219,12 @@ def image(name, tensor, bad_color=None, max_images=3, family=None):
def function(tag, scope):
if bad_color is None:
bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8)
gen_summary_ops.write_image_summary(
# Note the identity to move the tensor to the CPU.
return gen_summary_ops.write_image_summary(
context.context().summary_writer_resource,
training_util.get_global_step(), tag, tensor, bad_color_, max_images,
name=scope)
training_util.get_global_step(), tag, array_ops.identity(tensor),
bad_color_,
max_images, name=scope)

return summary_writer_function(name, tensor, function, family=family)

Expand All @@ -190,11 +233,12 @@ def audio(name, tensor, sample_rate, max_outputs, family=None):
"""Writes an audio summary if possible."""

def function(tag, scope):
gen_summary_ops.write_audio_summary(
# Note the identity to move the tensor to the CPU.
return gen_summary_ops.write_audio_summary(
context.context().summary_writer_resource,
training_util.get_global_step(),
tag,
tensor,
array_ops.identity(tensor),
sample_rate=sample_rate,
max_outputs=max_outputs,
name=scope)
Expand Down
29 changes: 7 additions & 22 deletions tensorflow/contrib/summary/summary_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,14 @@
from __future__ import division
from __future__ import print_function

import os
import tempfile

from tensorflow.contrib.summary import summary_ops
from tensorflow.core.util import event_pb2
from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import tf_record
from tensorflow.python.platform import gfile
from tensorflow.python.training import training_util

Expand Down Expand Up @@ -71,16 +69,9 @@ def write():
summary_ops.scalar('scalar', 2.0)

write()

self.assertTrue(gfile.Exists(logdir))
files = gfile.ListDirectory(logdir)
self.assertEqual(len(files), 1)
records = list(
tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
self.assertEqual(len(records), 2)
event = event_pb2.Event()
event.ParseFromString(records[1])
self.assertEqual(event.summary.value[0].simple_value, 2.0)
events = summary_test_util.events_from_file(logdir)
self.assertEqual(len(events), 2)
self.assertEqual(events[1].summary.value[0].simple_value, 2.0)

def testSummaryName(self):
training_util.get_or_create_global_step()
Expand All @@ -91,15 +82,9 @@ def testSummaryName(self):

summary_ops.scalar('scalar', 2.0)

self.assertTrue(gfile.Exists(logdir))
files = gfile.ListDirectory(logdir)
self.assertEqual(len(files), 1)
records = list(
tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
self.assertEqual(len(records), 2)
event = event_pb2.Event()
event.ParseFromString(records[1])
self.assertEqual(event.summary.value[0].tag, 'scalar')
events = summary_test_util.events_from_file(logdir)
self.assertEqual(len(events), 2)
self.assertEqual(events[1].summary.value[0].tag, 'scalar')


if __name__ == '__main__':
Expand Down
41 changes: 41 additions & 0 deletions tensorflow/contrib/summary/summary_test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities to test summaries."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.core.util import event_pb2
from tensorflow.python.lib.io import tf_record
from tensorflow.python.platform import gfile


def events_from_file(logdir):
"""Returns all events in the single eventfile in logdir."""
assert gfile.Exists(logdir)
files = gfile.ListDirectory(logdir)
assert len(files) == 1, "Found more than one file in logdir: %s" % files
records = list(
tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
result = []
for r in records:
event = event_pb2.Event()
event.ParseFromString(r)
result.append(event)
return result
Loading

0 comments on commit ff7b9a6

Please sign in to comment.