Skip to content

Commit

Permalink
Merge pull request NVIDIA-AI-IOT#29 from NVIDIA-AI-IOT/improved_model…
Browse files Browse the repository at this point in the history
…_support

Improved model support
  • Loading branch information
John Welsh authored Oct 30, 2018
2 parents 72e4680 + 21376f8 commit 9ce6130
Show file tree
Hide file tree
Showing 4 changed files with 280 additions and 170 deletions.
150 changes: 119 additions & 31 deletions examples/detection/detection.ipynb

Large diffs are not rendered by default.

282 changes: 144 additions & 138 deletions tf_trt_models/detection.py
Original file line number Diff line number Diff line change
@@ -1,168 +1,174 @@
from object_detection.protos.pipeline_pb2 import TrainEvalPipelineConfig
from object_detection.builders import model_builder
from object_detection.protos import pipeline_pb2
from object_detection.protos import image_resizer_pb2
from object_detection import exporter

import os
import tarfile
import subprocess

from collections import namedtuple
from google.protobuf import text_format

import tensorflow as tf

from .graph_utils import convert_relu6, remove_op

input_name = 'input'
output_map = {
'detection_scores': 'scores',
'detection_boxes': 'boxes',
'detection_classes': 'classes',
'detection_masks': 'masks'
from .graph_utils import force_nms_cpu as f_force_nms_cpu
from .graph_utils import replace_relu6 as f_replace_relu6
from .graph_utils import remove_assert as f_remove_assert

# Record describing one downloadable pre-trained detection model: its short
# name, the tarball URL, and the directory name the tarball extracts to.
DetectionModel = namedtuple(
    'DetectionModel',
    ['name', 'url', 'extract_dir'],
)

# Tensor names reported to callers as the graph's input and outputs.
INPUT_NAME = 'image_tensor'
BOXES_NAME = 'detection_boxes'
CLASSES_NAME = 'detection_classes'
SCORES_NAME = 'detection_scores'
MASKS_NAME = 'detection_masks'
NUM_DETECTIONS_NAME = 'num_detections'

# File names looked up inside a model directory / export directory.
FROZEN_GRAPH_NAME = 'frozen_inference_graph.pb'
PIPELINE_CONFIG_NAME = 'pipeline.config'
CHECKPOINT_PREFIX = 'model.ckpt'


# Registry of supported pre-trained models, keyed by short model name.
# Each entry records the tarball URL and the directory the tarball unpacks to.
MODELS = {
    'ssd_mobilenet_v1_coco': DetectionModel(
        name='ssd_mobilenet_v1_coco',
        url='http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz',
        extract_dir='ssd_mobilenet_v1_coco_2018_01_28',
    ),
    'ssd_mobilenet_v2_coco': DetectionModel(
        name='ssd_mobilenet_v2_coco',
        url='http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v2_coco_2018_03_29.tar.gz',
        extract_dir='ssd_mobilenet_v2_coco_2018_03_29',
    ),
    'ssd_inception_v2_coco': DetectionModel(
        name='ssd_inception_v2_coco',
        url='http://download.tensorflow.org/models/object_detection/ssd_inception_v2_coco_2018_01_28.tar.gz',
        extract_dir='ssd_inception_v2_coco_2018_01_28',
    ),
    'ssd_resnet_50_fpn_coco': DetectionModel(
        name='ssd_resnet_50_fpn_coco',
        url='http://download.tensorflow.org/models/object_detection/ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03.tar.gz',
        extract_dir='ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03',
    ),
    'faster_rcnn_resnet50_coco': DetectionModel(
        name='faster_rcnn_resnet50_coco',
        url='http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet50_coco_2018_01_28.tar.gz',
        extract_dir='faster_rcnn_resnet50_coco_2018_01_28',
    ),
    'faster_rcnn_nas': DetectionModel(
        name='faster_rcnn_nas',
        url='http://download.tensorflow.org/models/object_detection/faster_rcnn_nas_coco_2018_01_28.tar.gz',
        extract_dir='faster_rcnn_nas_coco_2018_01_28',
    ),
    'mask_rcnn_resnet50_atrous_coco': DetectionModel(
        name='mask_rcnn_resnet50_atrous_coco',
        url='http://download.tensorflow.org/models/object_detection/mask_rcnn_resnet50_atrous_coco_2018_01_28.tar.gz',
        extract_dir='mask_rcnn_resnet50_atrous_coco_2018_01_28',
    ),
}

# Lookup table keyed by model name. Each entry gives the URL of the pipeline
# config ('config_url') and of the checkpoint tarball ('checkpoint_url') that
# the wget-based download_detection_model reads via nets[model][...].
# NOTE(review): in this diff view this looks like the pre-change table that
# the MODELS registry supersedes — confirm before relying on it.
nets = {
'ssd_mobilenet_v1_coco': {
'config_url': 'https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/samples/configs/ssd_mobilenet_v1_coco.config',
'checkpoint_url': 'http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2017_11_17.tar.gz',
},
'ssd_mobilenet_v2_coco': {
'config_url': 'https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/samples/configs/ssd_mobilenet_v2_coco.config',
'checkpoint_url': 'http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v2_coco_2018_03_29.tar.gz',
},
'ssd_inception_v2_coco': {
'config_url': 'https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/samples/configs/ssd_inception_v2_coco.config',
'checkpoint_url': 'http://download.tensorflow.org/models/object_detection/ssd_inception_v2_coco_2017_11_17.tar.gz',
},
}

def download_detection_model(model, output_dir='.'):
"""Download a default detection model configuration and checkpoint.
This function downloads a default detection model configuration and
checkpoint. This is only available for a subset of models in the
TensorFlow object detection model zoo that are known to work on Jetson.
The following models are available
ssd_mobilenet_v1_coco
ssd_mobilenet_v2_coco
ssd_inception_v2_coco
:param model: the model name from the above list
:type model: string
:param output_dir: the directory where files are downloaded to
:type output_dir: string
:return config_path: path to the object detection pipeline config file
:rtype string
:return checkpoint_path: path to the checkpoint files prefix containing trained model params
:rtype string
"""
global nets
config_path = ''
checkpoint_path = ''

if not os.path.exists(output_dir):
os.makedirs(output_dir)

modeldir_path = os.path.join(output_dir, model)
if not os.path.exists(modeldir_path):
os.makedirs(modeldir_path)

config_path = os.path.join(output_dir, model + '.config')
if not os.path.isfile(config_path):
subprocess.call(['wget', '--no-check-certificate', nets[model]['config_url'], '-O', config_path])

modeltar_path = os.path.join(output_dir, os.path.basename(nets[model]['checkpoint_url']))
if not os.path.isfile(modeltar_path):
subprocess.call(['wget', '--no-check-certificate', nets[model]['checkpoint_url'], '-O', modeltar_path])

tar_file = tarfile.open(modeltar_path)
for file in tar_file.getmembers():
file_name = os.path.basename(file.name)
if 'model.ckpt' in file_name:
file.name = file_name
tar_file.extract(file, modeldir_path)

checkpoint_path = os.path.join(modeldir_path, 'model.ckpt')
def get_input_names(model):
    """Return the list of input tensor names for a detection model.

    :param model: the model name; not consulted, the same single input name
        is returned for every model
    :type model: string
    :returns: a new one-element list holding INPUT_NAME
    :rtype: list of string
    """
    input_names = [INPUT_NAME]
    return input_names

return config_path, checkpoint_path

def build_detection_graph(config, checkpoint):
"""Build an object detection model from the TensorFlow model zoo.
def get_output_names(model):
    """Return the list of output tensor names for a detection model.

    Every model yields the boxes, classes, scores and num-detections names;
    'mask_rcnn_resnet50_atrous_coco' additionally yields the masks name.

    :param model: the model name
    :type model: string
    :returns: a new list of output tensor names
    :rtype: list of string
    """
    base_names = [BOXES_NAME, CLASSES_NAME, SCORES_NAME, NUM_DETECTIONS_NAME]
    if model != 'mask_rcnn_resnet50_atrous_coco':
        return base_names
    return base_names + [MASKS_NAME]

This function creates an object detection model, sourced from the
TensorFlow object detection API.

It is necessary to use this function to generate a frozen graph that is
compatible with TensorFlow/TensorRT integration. In addition to generating
a graph that is compatible with TensorFlow's TensorRT package, this
function performs other graph modifications, such as forced device
placement, that improve performance on Jetson. These graph modifications
are tested with a subset of the object detection API and may or may not
work well with models not listed.
The workflow when using this method is:
1. Train model using TensorFlow object detection API
2. Build graph configured for Jetson using this function
3. Optimize the graph output by this method with the TensorRT package in
TensorFlow
4. Execute in regular TensorFlow, or using the high level TFModel class
def download_detection_model(model, output_dir='.'):
"""Downloads a pre-trained object detection model"""
global MODELS

:param config: path to the object detection pipeline config file
:type config: string
:param checkpoint: path to the checkpoint files prefix containing trained model params
:type checkpoint: string
:returns: the configured frozen graph representing object detection model
:rtype: a tensorflow GraphDef
"""
global input_name, output_map
model_name = model

if isinstance(config, str):
with open(config, 'r') as f:
config_str = f.read()
config = TrainEvalPipelineConfig()
text_format.Merge(config_str, config)
model = MODELS[model_name]
subprocess.call(['mkdir', '-p', output_dir])
tar_file = os.path.join(output_dir, os.path.basename(model.url))

config_path = os.path.join(output_dir, model.extract_dir, PIPELINE_CONFIG_NAME)
checkpoint_path = os.path.join(output_dir, model.extract_dir, CHECKPOINT_PREFIX)

tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True
if not os.path.exists(os.path.join(output_dir, model.extract_dir)):
subprocess.call(['wget', model.url, '-O', tar_file])
subprocess.call(['tar', '-xzf', tar_file, '-C', output_dir])

with tf.Graph().as_default() as tf_graph:
with tf.Session(config=tf_config) as tf_sess:
# hack fix to handle mobilenet_v2 config bug
subprocess.call(['sed', '-i', '/batch_norm_trainable/d', config_path])

model = model_builder.build(model_config=config.model, is_training=False)
return config_path, checkpoint_path

tf_input = tf.placeholder(tf.float32, [1, None, None, 3], name=input_name)
tf_preprocessed, tf_true_image_shapes = model.preprocess(tf_input)
tf_predictions = model.predict(preprocessed_inputs=tf_preprocessed,
true_image_shapes=tf_true_image_shapes)
tf_postprocessed = model.postprocess(
prediction_dict=tf_predictions,
true_image_shapes=tf_true_image_shapes
)

tf_saver = tf.train.Saver()
tf_saver.restore(save_path=checkpoint, sess=tf_sess)
def build_detection_graph(config, checkpoint,
batch_size=1,
score_threshold=None,
force_nms_cpu=True,
replace_relu6=True,
remove_assert=True,
input_shape=None,
output_dir='.generated_model'):
"""Builds a frozen graph for a pre-trained object detection model"""

config_path = config
checkpoint_path = checkpoint

# parse config from file
config = pipeline_pb2.TrainEvalPipelineConfig()
with open(config_path, 'r') as f:
text_format.Merge(f.read(), config, allow_unknown_extension=True)

# override some config parameters
if config.model.HasField('ssd'):
config.model.ssd.feature_extractor.override_base_feature_extractor_hyperparams = True
if score_threshold is not None:
config.model.ssd.post_processing.batch_non_max_suppression.score_threshold = score_threshold
if input_shape is not None:
config.model.ssd.image_resizer.fixed_shape_resizer.height = input_shape[0]
config.model.ssd.image_resizer.fixed_shape_resizer.width = input_shape[1]
elif config.model.HasField('faster_rcnn'):
if score_threshold is not None:
config.model.faster_rcnn.second_stage_post_processing.score_threshold = score_threshold
if input_shape is not None:
config.model.faster_rcnn.image_resizer.fixed_shape_resizer.height = input_shape[0]
config.model.faster_rcnn.image_resizer.fixed_shape_resizer.width = input_shape[1]

if os.path.isdir(output_dir):
subprocess.call(['rm', '-rf', output_dir])

outputs = {}
for key, op in tf_postprocessed.items():
if key in output_map.keys():
outputs[output_map[key]] = \
tf.identity(op, name=output_map[key])
tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True

frozen_graph = tf.graph_util.convert_variables_to_constants(
tf_sess,
tf_sess.graph_def,
output_node_names=list(outputs.keys())
# export inference graph to file (initial)
with tf.Session(config=tf_config) as tf_sess:
with tf.Graph().as_default() as tf_graph:
exporter.export_inference_graph(
'image_tensor',
config,
checkpoint_path,
output_dir,
input_shape=[batch_size, None, None, 3]
)

frozen_graph = convert_relu6(frozen_graph)
# read frozen graph from file
frozen_graph = tf.GraphDef()
with open(os.path.join(output_dir, FROZEN_GRAPH_NAME), 'rb') as f:
frozen_graph.ParseFromString(f.read())

# apply graph modifications
if force_nms_cpu:
frozen_graph = f_force_nms_cpu(frozen_graph)
if replace_relu6:
frozen_graph = f_replace_relu6(frozen_graph)
if remove_assert:
frozen_graph = f_remove_assert(frozen_graph)

remove_op(frozen_graph, 'Assert')
# get input names
# TODO: handle mask_rcnn
input_names = [INPUT_NAME]
output_names = [BOXES_NAME, CLASSES_NAME, SCORES_NAME, NUM_DETECTIONS_NAME]

# force CPU device placement for NMS ops
for node in frozen_graph.node:
if 'NonMaxSuppression' in node.name:
node.device = '/device:CPU:0'
# remove temporary directory
subprocess.call(['rm', '-rf', output_dir])

return frozen_graph, [input_name], list(outputs.keys())
return frozen_graph, input_names, output_names
16 changes: 16 additions & 0 deletions tf_trt_models/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,19 @@ def remove_op(graph_def, op_name):
matches = [node for node in graph_def.node if node.op == op_name]
for match in matches:
remove_node(graph_def, match)


def force_nms_cpu(frozen_graph):
    """Pin every node whose name contains 'NonMaxSuppression' to the CPU.

    The nodes are modified in place by setting their ``device`` attribute to
    '/device:CPU:0'; all other nodes are left untouched.

    :param frozen_graph: object exposing an iterable ``node`` attribute whose
        items have ``name`` and ``device`` attributes
    :returns: the same ``frozen_graph`` object that was passed in
    """
    nms_nodes = (
        candidate for candidate in frozen_graph.node
        if 'NonMaxSuppression' in candidate.name
    )
    for nms_node in nms_nodes:
        nms_node.device = '/device:CPU:0'
    return frozen_graph


def replace_relu6(frozen_graph):
    """Pass the graph through convert_relu6 and return whatever it returns.

    :param frozen_graph: the graph handed on unchanged to convert_relu6
    :returns: the result of convert_relu6(frozen_graph)
    """
    converted_graph = convert_relu6(frozen_graph)
    return converted_graph


def remove_assert(frozen_graph):
    """Call remove_op for the 'Assert' op type and return the graph.

    The return value of remove_op is ignored; the object passed in is the
    object returned.

    :param frozen_graph: the graph handed to remove_op
    :returns: the same ``frozen_graph`` object that was passed in
    """
    remove_op(graph_def=frozen_graph, op_name='Assert')
    return frozen_graph
2 changes: 1 addition & 1 deletion third_party/models
Submodule models updated 642 files

0 comments on commit 9ce6130

Please sign in to comment.