Skip to content

Commit

Permalink
improved model support
Browse files Browse the repository at this point in the history
  • Loading branch information
jaybdub committed Oct 3, 2018
1 parent 528d420 commit ad12439
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 150 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ classes, image size, non-max suppression parameters, but the performance may vary
TODO
----

- [ ] update download_detection_model to download latest from tensorflow/models
- [ ] add patches to download_detection_model to fix config version issue with mobilenet_v2 ssd model
- [ ] add model re-export to build_detection_graph to resolve TF version discrepancies
- [ ] add support for batch size > 1 to build_detection_graph
- [ ] add support for NMS score threshold parameter to build_detection_graph
- [x] update download_detection_model to download latest from tensorflow/models
- [x] add patches to download_detection_model to fix config version issue with mobilenet_v2 ssd model
- [x] add model re-export to build_detection_graph to resolve TF version discrepancies
- [x] add support for batch size > 1 to build_detection_graph
- [x] add support for NMS score threshold parameter to build_detection_graph
282 changes: 137 additions & 145 deletions tf_trt_models/detection.py
Original file line number Diff line number Diff line change
@@ -1,168 +1,160 @@
from object_detection.protos.pipeline_pb2 import TrainEvalPipelineConfig
from object_detection.builders import model_builder
from object_detection.protos import pipeline_pb2
from object_detection import exporter

import os
import tarfile
import subprocess

from collections import namedtuple
from google.protobuf import text_format

import tensorflow as tf

from .graph_utils import convert_relu6, remove_op

# Name given to the float32 input placeholder created by build_detection_graph.
input_name = 'input'
# Maps keys of the model's postprocessed output dict to the names of the
# identity ops that build_detection_graph exposes as graph outputs.
# Postprocessed keys that are not listed here are not exported.
output_map = {
'detection_scores': 'scores',
'detection_boxes': 'boxes',
'detection_classes': 'classes',
'detection_masks': 'masks'
}

nets = {
'ssd_mobilenet_v1_coco': {
'config_url': 'https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/samples/configs/ssd_mobilenet_v1_coco.config',
'checkpoint_url': 'http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2017_11_17.tar.gz',
},
'ssd_mobilenet_v2_coco': {
'config_url': 'https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/samples/configs/ssd_mobilenet_v2_coco.config',
'checkpoint_url': 'http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v2_coco_2018_03_29.tar.gz',
},
'ssd_inception_v2_coco': {
'config_url': 'https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/samples/configs/ssd_inception_v2_coco.config',
'checkpoint_url': 'http://download.tensorflow.org/models/object_detection/ssd_inception_v2_coco_2017_11_17.tar.gz',
},
from .graph_utils import force_nms_cpu as f_force_nms_cpu
from .graph_utils import replace_relu6 as f_replace_relu6
from .graph_utils import remove_assert as f_remove_assert

# Record describing one downloadable pre-trained model:
#   name        - key used to look the model up in MODELS
#   url         - location of the .tar.gz archive fetched by download_detection_model
#   extract_dir - directory (relative to the download output_dir) that is
#                 expected to exist once the archive has been unpacked
DetectionModel = namedtuple('DetectionModel', ['name', 'url', 'extract_dir'])

# Input name reported to callers of the built graph.
INPUT_NAME='image_tensor'
# Output names reported to callers of the built graph.
BOXES_NAME='detection_boxes'
CLASSES_NAME='detection_classes'
SCORES_NAME='detection_scores'
MASKS_NAME='detection_masks'
NUM_DETECTIONS_NAME='num_detections'
# File read back from the export directory after re-exporting the model.
FROZEN_GRAPH_NAME='frozen_inference_graph.pb'
# File names joined onto <output_dir>/<extract_dir> to form the config path
# and checkpoint prefix returned by download_detection_model.
PIPELINE_CONFIG_NAME='pipeline.config'
CHECKPOINT_PREFIX='model.ckpt'


# Registry of downloadable pre-trained detection models, keyed by the model
# name passed to download_detection_model. Each entry gives the archive URL
# and the directory name that is checked for (and joined with the pipeline
# config / checkpoint file names) under the download output directory.
MODELS = {
'ssd_mobilenet_v1_coco': DetectionModel(
'ssd_mobilenet_v1_coco',
'http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz',
'ssd_mobilenet_v1_coco_2018_01_28',
),
'ssd_mobilenet_v2_coco': DetectionModel(
'ssd_mobilenet_v2_coco',
'http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v2_coco_2018_03_29.tar.gz',
'ssd_mobilenet_v2_coco_2018_03_29',
),
'ssd_inception_v2_coco': DetectionModel(
'ssd_inception_v2_coco',
'http://download.tensorflow.org/models/object_detection/ssd_inception_v2_coco_2018_01_28.tar.gz',
'ssd_inception_v2_coco_2018_01_28',
),
# NOTE: the key/name uses "resnet_50" while the archive uses "resnet50_v1".
'ssd_resnet_50_fpn_coco': DetectionModel(
'ssd_resnet_50_fpn_coco',
'http://download.tensorflow.org/models/object_detection/ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03.tar.gz',
'ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03',
),
'faster_rcnn_resnet50_coco': DetectionModel(
'faster_rcnn_resnet50_coco',
'http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet50_coco_2018_01_28.tar.gz',
'faster_rcnn_resnet50_coco_2018_01_28',
),
# NOTE: the key/name omits the "_coco" suffix that the archive name carries.
'faster_rcnn_nas': DetectionModel(
'faster_rcnn_nas',
'http://download.tensorflow.org/models/object_detection/faster_rcnn_nas_coco_2018_01_28.tar.gz',
'faster_rcnn_nas_coco_2018_01_28',
),
# This exact key is special-cased by get_output_names to add the masks output.
'mask_rcnn_resnet50_atrous_coco': DetectionModel(
'mask_rcnn_resnet50_atrous_coco',
'http://download.tensorflow.org/models/object_detection/mask_rcnn_resnet50_atrous_coco_2018_01_28.tar.gz',
'mask_rcnn_resnet50_atrous_coco_2018_01_28',
)
}

def _download_file(url, path):
    """Fetch ``url`` into ``path`` with wget, discarding partial downloads."""
    # NOTE(review): '--no-check-certificate' disables TLS verification; it is
    # kept to preserve the existing behavior, but should be reconsidered.
    try:
        subprocess.check_call(['wget', '--no-check-certificate', url, '-O', path])
    except (subprocess.CalledProcessError, OSError):
        # wget -O creates the target even when the transfer fails; remove it
        # so the next call does not mistake it for a completed download.
        if os.path.isfile(path):
            os.remove(path)
        raise


def download_detection_model(model, output_dir='.'):
    """Download a default detection model configuration and checkpoint.

    This function downloads a default detection model configuration and
    checkpoint. This is only available for a subset of models in the
    TensorFlow object detection model zoo that are known to work on Jetson.

    The following models are available

        ssd_mobilenet_v1_coco
        ssd_mobilenet_v2_coco
        ssd_inception_v2_coco

    Files that already exist in ``output_dir`` are not downloaded again.

    :param model: the model name from the above list
    :type model: string
    :param output_dir: the directory where files are downloaded to
    :type output_dir: string
    :return config_path: path to the object detection pipeline config file
    :rtype string
    :return checkpoint_path: path to the checkpoint files prefix containing trained model params
    :rtype string
    :raises KeyError: if ``model`` is not a known model name
    :raises subprocess.CalledProcessError: if a wget download fails
    """
    # Look the model up first so an unknown name fails before touching disk.
    net = nets[model]

    # Creating the per-model directory also creates output_dir if needed.
    modeldir_path = os.path.join(output_dir, model)
    if not os.path.exists(modeldir_path):
        os.makedirs(modeldir_path)

    config_path = os.path.join(output_dir, model + '.config')
    if not os.path.isfile(config_path):
        _download_file(net['config_url'], config_path)

    modeltar_path = os.path.join(output_dir, os.path.basename(net['checkpoint_url']))
    if not os.path.isfile(modeltar_path):
        _download_file(net['checkpoint_url'], modeltar_path)

    # The context manager guarantees the archive handle is closed.
    with tarfile.open(modeltar_path) as tar_file:
        for member in tar_file.getmembers():
            file_name = os.path.basename(member.name)
            if 'model.ckpt' in file_name:
                # Drop the archive's directory prefix so the checkpoint files
                # land directly inside modeldir_path.
                member.name = file_name
                tar_file.extract(member, modeldir_path)

    checkpoint_path = os.path.join(modeldir_path, 'model.ckpt')

    return config_path, checkpoint_path

def build_detection_graph(config, checkpoint):
"""Build an object detection model from the TensorFlow model zoo.
This function creates an object detection model, sourced from the
TensorFlow object detection API.

It is necessary to use this function to generate a frozen graph that is
compatible with TensorFlow/TensorRT integration. In addition to generating
a graph that is compatible with TensorFlow's TensorRT package, this
function performs other graph modifications, such as forced device
placement, that improve performance on Jetson. These graph modifications
are tested with a subset of the object detection API and may or may not
work well with models not listed.
def get_input_names(model):
    """Return the list of input names for ``model``.

    The result is the same single-element list for every model; ``model`` is
    accepted only for symmetry with ``get_output_names``.
    """
    input_names = list((INPUT_NAME,))
    return input_names

The workflow when using this method is:

1. Train model using TensorFlow object detection API
2. Build graph configured for Jetson using this function
3. Optimize the graph output by this method with the TensorRT package in
TensorFlow
4. Execute in regular TensorFlow, or using the high level TFModel class
def get_output_names(model):
    """Return the list of output names for ``model``.

    Every model reports boxes, classes, scores and the number of detections
    (in that order); the ``mask_rcnn_resnet50_atrous_coco`` model additionally
    reports the masks output as the last entry.
    """
    common_outputs = [BOXES_NAME, CLASSES_NAME, SCORES_NAME, NUM_DETECTIONS_NAME]
    if model != 'mask_rcnn_resnet50_atrous_coco':
        return common_outputs
    return common_outputs + [MASKS_NAME]

:param config: path to the object detection pipeline config file
:type config: string
:param checkpoint: path to the checkpoint files prefix containing trained model params
:type checkpoint: string
:returns: the configured frozen graph representing object detection model
:rtype: a tensorflow GraphDef
"""
global input_name, output_map

if isinstance(config, str):
with open(config, 'r') as f:
config_str = f.read()
config = TrainEvalPipelineConfig()
text_format.Merge(config_str, config)


tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True

with tf.Graph().as_default() as tf_graph:
with tf.Session(config=tf_config) as tf_sess:

model = model_builder.build(model_config=config.model, is_training=False)
def download_detection_model(model, output_dir='.'):
"""Downloads a pre-trained object detection model"""
global MODELS

tf_input = tf.placeholder(tf.float32, [1, None, None, 3], name=input_name)
tf_preprocessed, tf_true_image_shapes = model.preprocess(tf_input)
tf_predictions = model.predict(preprocessed_inputs=tf_preprocessed,
true_image_shapes=tf_true_image_shapes)
tf_postprocessed = model.postprocess(
prediction_dict=tf_predictions,
true_image_shapes=tf_true_image_shapes
)
model_name = model

tf_saver = tf.train.Saver()
tf_saver.restore(save_path=checkpoint, sess=tf_sess)
model = MODELS[model_name]
subprocess.call(['mkdir', '-p', output_dir])
tar_file = os.path.join(output_dir, os.path.basename(model.url))

outputs = {}
for key, op in tf_postprocessed.items():
if key in output_map.keys():
outputs[output_map[key]] = \
tf.identity(op, name=output_map[key])
config_path = os.path.join(output_dir, model.extract_dir, PIPELINE_CONFIG_NAME)
checkpoint_path = os.path.join(output_dir, model.extract_dir, CHECKPOINT_PREFIX)

frozen_graph = tf.graph_util.convert_variables_to_constants(
tf_sess,
tf_sess.graph_def,
output_node_names=list(outputs.keys())
)
if not os.path.exists(os.path.join(output_dir, model.extract_dir)):
subprocess.call(['wget', model.url, '-O', tar_file])
subprocess.call(['tar', '-xzf', tar_file, '-C', output_dir])

frozen_graph = convert_relu6(frozen_graph)
# hack fix to handle mobilenet_v2 config bug
subprocess.call(['sed', '-i', '/batch_norm_trainable/d', config_path])

remove_op(frozen_graph, 'Assert')
return config_path, checkpoint_path

# force CPU device placement for NMS ops
for node in frozen_graph.node:
if 'NonMaxSuppression' in node.name:
node.device = '/device:CPU:0'

return frozen_graph, [input_name], list(outputs.keys())
def build_detection_graph(config, checkpoint,
                          batch_size=1,
                          score_threshold=None,
                          force_nms_cpu=True,
                          replace_relu6=True,
                          remove_assert=True,
                          output_dir='.generated_model'):
    """Build a frozen inference graph for a pre-trained object detection model.

    The pipeline config is parsed and a few fields are overridden, the model
    is re-exported through the object detection API exporter into
    ``output_dir``, the exported frozen graph is read back, and the optional
    graph modifications are applied to it.

    :param config: path to the object detection pipeline config file
    :type config: string
    :param checkpoint: checkpoint path (prefix) handed to the exporter
    :type checkpoint: string
    :param batch_size: leading dimension of the exported input shape
    :type batch_size: int
    :param score_threshold: when not None, overrides the non-max suppression
        score threshold of SSD and Faster R-CNN configs
    :type score_threshold: float or None
    :param force_nms_cpu: apply the ``force_nms_cpu`` graph modification
    :type force_nms_cpu: bool
    :param replace_relu6: apply the ``replace_relu6`` graph modification
    :type replace_relu6: bool
    :param remove_assert: apply the ``remove_assert`` graph modification
    :type remove_assert: bool
    :param output_dir: scratch directory used for the export. NOTE: this
        directory is deleted recursively before the function returns or raises.
    :type output_dir: string
    :returns: tuple ``(frozen_graph, input_names, output_names)``
    :rtype: (tensorflow GraphDef, list of string, list of string)
    """
    # Imported locally so this block does not depend on a file-level import.
    import shutil

    config_path = config
    checkpoint_path = checkpoint

    # parse config from file
    config = pipeline_pb2.TrainEvalPipelineConfig()
    with open(config_path, 'r') as f:
        text_format.Merge(f.read(), config, allow_unknown_extension=True)

    # override some config parameters
    if config.model.HasField('ssd'):
        config.model.ssd.feature_extractor.override_base_feature_extractor_hyperparams = True
        if score_threshold is not None:
            config.model.ssd.post_processing.batch_non_max_suppression.score_threshold = score_threshold
    elif config.model.HasField('faster_rcnn'):
        if score_threshold is not None:
            # The threshold is a field of the nested batch_non_max_suppression
            # message, exactly as in the SSD branch above; the post-processing
            # message itself has no score_threshold field.
            config.model.faster_rcnn.second_stage_post_processing.batch_non_max_suppression.score_threshold = score_threshold

    try:
        # export inference graph to file (initial); a fresh default graph
        # keeps the export isolated from any graph the caller already built
        with tf.Graph().as_default():
            exporter.export_inference_graph(
                'image_tensor',
                config,
                checkpoint_path,
                output_dir,
                input_shape=[batch_size, None, None, 3],
                write_inference_graph=False
            )

        # read frozen graph from file
        frozen_graph = tf.GraphDef()
        with open(os.path.join(output_dir, FROZEN_GRAPH_NAME), 'rb') as f:
            frozen_graph.ParseFromString(f.read())
    finally:
        # remove temporary directory, also when the export or parse failed
        shutil.rmtree(output_dir, ignore_errors=True)

    # apply graph modifications (the f_ aliases avoid the name clash with the
    # boolean parameters of the same name)
    if force_nms_cpu:
        frozen_graph = f_force_nms_cpu(frozen_graph)
    if replace_relu6:
        frozen_graph = f_replace_relu6(frozen_graph)
    if remove_assert:
        frozen_graph = f_remove_assert(frozen_graph)

    # get input names
    # TODO: handle mask_rcnn
    input_names = [INPUT_NAME]
    output_names = [BOXES_NAME, CLASSES_NAME, SCORES_NAME, NUM_DETECTIONS_NAME]

    return frozen_graph, input_names, output_names
16 changes: 16 additions & 0 deletions tf_trt_models/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,19 @@ def remove_op(graph_def, op_name):
matches = [node for node in graph_def.node if node.op == op_name]
for match in matches:
remove_node(graph_def, match)


def force_nms_cpu(frozen_graph):
    """Pin every node whose name contains ``NonMaxSuppression`` to the CPU.

    The graph is modified in place and also returned, so the call can be
    chained with the other graph modifications.
    """
    nms_nodes = (n for n in frozen_graph.node if 'NonMaxSuppression' in n.name)
    for nms_node in nms_nodes:
        nms_node.device = '/device:CPU:0'
    return frozen_graph


def replace_relu6(frozen_graph):
    """Run ``convert_relu6`` over ``frozen_graph`` and return its result.

    Thin wrapper that gives the relu6 conversion the same
    ``graph -> graph`` shape as the other graph modifications.
    """
    converted_graph = convert_relu6(frozen_graph)
    return converted_graph


def remove_assert(frozen_graph):
    """Remove every node whose op is ``Assert`` from ``frozen_graph``.

    The removal is delegated to ``remove_op``; the same graph object is
    returned so the call can be chained with the other graph modifications.
    """
    remove_op(graph_def=frozen_graph, op_name='Assert')
    return frozen_graph

0 comments on commit ad12439

Please sign in to comment.