Skip to content

Commit

Permalink
update exporte changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
tombstone committed Oct 28, 2017
1 parent 9a811d9 commit 78cf0ae
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 40 deletions.
27 changes: 19 additions & 8 deletions research/object_detection/export_inference_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@
flags.DEFINE_string('input_type', 'image_tensor', 'Type of input node. Can be '
'one of [`image_tensor`, `encoded_image_string_tensor`, '
'`tf_example`]')
flags.DEFINE_list('input_shape', None,
'If input_type is `image_tensor`, this can explicitly set '
'the shape of this input tensor to a fixed size. The '
'dimensions are to be provided as a comma-separated list of '
'integers. A value of -1 can be used for unknown dimensions. '
'If not specified, for an `image_tensor, the default shape '
'will be partially specified as `[None, None, None, 3]`.')
flags.DEFINE_string('pipeline_config_path', None,
'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
'file.')
Expand All @@ -85,21 +92,25 @@
'path/to/model.ckpt')
flags.DEFINE_string('output_directory', None, 'Path to write outputs.')

tf.app.flags.mark_flag_as_required('pipeline_config_path')
tf.app.flags.mark_flag_as_required('trained_checkpoint_prefix')
tf.app.flags.mark_flag_as_required('output_directory')
FLAGS = flags.FLAGS


def main(_):
assert FLAGS.pipeline_config_path, '`pipeline_config_path` is missing'
assert FLAGS.trained_checkpoint_prefix, (
'`trained_checkpoint_prefix` is missing')
assert FLAGS.output_directory, '`output_directory` is missing'

pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
text_format.Merge(f.read(), pipeline_config)
exporter.export_inference_graph(
FLAGS.input_type, pipeline_config, FLAGS.trained_checkpoint_prefix,
FLAGS.output_directory)
if FLAGS.input_shape:
input_shape = [
int(dim) if dim != '-1' else None for dim in FLAGS.input_shape
]
else:
input_shape = None
exporter.export_inference_graph(FLAGS.input_type, pipeline_config,
FLAGS.trained_checkpoint_prefix,
FLAGS.output_directory, input_shape)


if __name__ == '__main__':
Expand Down
101 changes: 76 additions & 25 deletions research/object_detection/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
"""Functions to export object detection inference graph."""
import logging
import os
import tempfile
import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import saver as saver_lib
Expand All @@ -43,7 +43,7 @@ def freeze_graph_with_def_protos(
filename_tensor_name,
clear_devices,
initializer_nodes,
optimize_graph=False,
optimize_graph=True,
variable_names_blacklist=''):
"""Converts all variables in a graph and checkpoint into constants."""
del restore_op_name, filename_tensor_name # Unused by updated loading code.
Expand Down Expand Up @@ -111,20 +111,45 @@ def freeze_graph_with_def_protos(
return output_graph_def


def replace_variable_values_with_moving_averages(graph,
current_checkpoint_file,
new_checkpoint_file):
"""Replaces variable values in the checkpoint with their moving averages.
def _image_tensor_input_placeholder():
"""Returns placeholder and input node that accepts a batch of uint8 images."""
input_tensor = tf.placeholder(dtype=tf.uint8,
shape=(None, None, None, 3),
name='image_tensor')
If the current checkpoint has shadow variables maintaining moving averages of
the variables defined in the graph, this function generates a new checkpoint
where the variables contain the values of their moving averages.
Args:
graph: a tf.Graph object.
current_checkpoint_file: a checkpoint containing both original variables and
their moving averages.
new_checkpoint_file: file path to write a new checkpoint.
"""
with graph.as_default():
variable_averages = tf.train.ExponentialMovingAverage(0.0)
ema_variables_to_restore = variable_averages.variables_to_restore()
with tf.Session() as sess:
read_saver = tf.train.Saver(ema_variables_to_restore)
read_saver.restore(sess, current_checkpoint_file)
write_saver = tf.train.Saver()
write_saver.save(sess, new_checkpoint_file)


def _image_tensor_input_placeholder(input_shape=None):
"""Returns input placeholder and a 4-D uint8 image tensor."""
if input_shape is None:
input_shape = (None, None, None, 3)
input_tensor = tf.placeholder(
dtype=tf.uint8, shape=input_shape, name='image_tensor')
return input_tensor, input_tensor


def _tf_example_input_placeholder():
"""Returns input that accepts a batch of strings with tf examples.
Returns:
a tuple of placeholder and input nodes that output decoded images.
a tuple of input placeholder and the output decoded images.
"""
batch_tf_example_placeholder = tf.placeholder(
tf.string, shape=[None], name='tf_example')
Expand All @@ -145,7 +170,7 @@ def _encoded_image_string_tensor_input_placeholder():
"""Returns input that accepts a batch of PNG or JPEG strings.
Returns:
a tuple of placeholder and input nodes that output decoded images.
a tuple of input placeholder and the output decoded images.
"""
batch_image_str_placeholder = tf.placeholder(
dtype=tf.string,
Expand Down Expand Up @@ -301,7 +326,9 @@ def _export_inference_graph(input_type,
use_moving_averages,
trained_checkpoint_prefix,
output_directory,
optimize_graph=False,
additional_output_tensor_names=None,
input_shape=None,
optimize_graph=True,
output_collection_name='inference_op'):
"""Export helper."""
tf.gfile.MakeDirs(output_directory)
Expand All @@ -312,50 +339,69 @@ def _export_inference_graph(input_type,

if input_type not in input_placeholder_fn_map:
raise ValueError('Unknown input type: {}'.format(input_type))
placeholder_tensor, input_tensors = input_placeholder_fn_map[input_type]()
placeholder_args = {}
if input_shape is not None:
if input_type != 'image_tensor':
raise ValueError('Can only specify input shape for `image_tensor` '
'inputs.')
placeholder_args['input_shape'] = input_shape
placeholder_tensor, input_tensors = input_placeholder_fn_map[input_type](
**placeholder_args)
inputs = tf.to_float(input_tensors)
preprocessed_inputs = detection_model.preprocess(inputs)
output_tensors = detection_model.predict(preprocessed_inputs)
postprocessed_tensors = detection_model.postprocess(output_tensors)
outputs = _add_output_tensor_nodes(postprocessed_tensors,
output_collection_name)
# Add global step to the graph.
slim.get_or_create_global_step()

saver = None
if use_moving_averages:
variable_averages = tf.train.ExponentialMovingAverage(0.0)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
temp_checkpoint_file = tempfile.NamedTemporaryFile()
replace_variable_values_with_moving_averages(
tf.get_default_graph(), trained_checkpoint_prefix,
temp_checkpoint_file.name)
checkpoint_to_use = temp_checkpoint_file.name
else:
saver = tf.train.Saver()
checkpoint_to_use = trained_checkpoint_prefix

saver = tf.train.Saver()
input_saver_def = saver.as_saver_def()

_write_graph_and_checkpoint(
inference_graph_def=tf.get_default_graph().as_graph_def(),
model_path=model_path,
input_saver_def=input_saver_def,
trained_checkpoint_prefix=trained_checkpoint_prefix)
trained_checkpoint_prefix=checkpoint_to_use)

if additional_output_tensor_names is not None:
output_node_names = ','.join(outputs.keys()+additional_output_tensor_names)
else:
output_node_names = ','.join(outputs.keys())

frozen_graph_def = freeze_graph_with_def_protos(
input_graph_def=tf.get_default_graph().as_graph_def(),
input_saver_def=input_saver_def,
input_checkpoint=trained_checkpoint_prefix,
output_node_names=','.join(outputs.keys()),
input_checkpoint=checkpoint_to_use,
output_node_names=output_node_names,
restore_op_name='save/restore_all',
filename_tensor_name='save/Const:0',
clear_devices=True,
optimize_graph=optimize_graph,
initializer_nodes='')
_write_frozen_graph(frozen_graph_path, frozen_graph_def)
_write_saved_model(saved_model_path, frozen_graph_def, placeholder_tensor,
outputs)
_write_saved_model(saved_model_path, frozen_graph_def,
placeholder_tensor, outputs)


def export_inference_graph(input_type,
pipeline_config,
trained_checkpoint_prefix,
output_directory,
optimize_graph=False,
output_collection_name='inference_op'):
input_shape=None,
optimize_graph=True,
output_collection_name='inference_op',
additional_output_tensor_names=None):
"""Exports inference graph for the model specified in the pipeline config.
Args:
Expand All @@ -364,13 +410,18 @@ def export_inference_graph(input_type,
pipeline_config: pipeline_pb2.TrainAndEvalPipelineConfig proto.
trained_checkpoint_prefix: Path to the trained checkpoint file.
output_directory: Path to write outputs.
input_shape: Sets a fixed shape for an `image_tensor` input. If not
specified, will default to [None, None, None, 3].
optimize_graph: Whether to optimize graph using Grappler.
output_collection_name: Name of collection to add output tensors to.
If None, does not add output tensors to a collection.
additional_output_tensor_names: list of additional output
tensors to include in the frozen graph.
"""
detection_model = model_builder.build(pipeline_config.model,
is_training=False)
_export_inference_graph(input_type, detection_model,
pipeline_config.eval_config.use_moving_averages,
trained_checkpoint_prefix, output_directory,
optimize_graph, output_collection_name)
trained_checkpoint_prefix,
output_directory, additional_output_tensor_names,
input_shape, optimize_graph, output_collection_name)
Loading

0 comments on commit 78cf0ae

Please sign in to comment.