forked from NVIDIA-AI-IOT/tf_trt_models
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
158 additions
and
150 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,168 +1,160 @@ | ||
from object_detection.protos.pipeline_pb2 import TrainEvalPipelineConfig | ||
from object_detection.builders import model_builder | ||
from object_detection.protos import pipeline_pb2 | ||
from object_detection import exporter | ||
|
||
import os | ||
import tarfile | ||
import subprocess | ||
|
||
from collections import namedtuple | ||
from google.protobuf import text_format | ||
|
||
import tensorflow as tf | ||
|
||
from .graph_utils import convert_relu6, remove_op | ||
|
||
input_name = 'input' | ||
output_map = { | ||
'detection_scores': 'scores', | ||
'detection_boxes': 'boxes', | ||
'detection_classes': 'classes', | ||
'detection_masks': 'masks' | ||
} | ||
|
||
nets = { | ||
'ssd_mobilenet_v1_coco': { | ||
'config_url': 'https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/samples/configs/ssd_mobilenet_v1_coco.config', | ||
'checkpoint_url': 'http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2017_11_17.tar.gz', | ||
}, | ||
'ssd_mobilenet_v2_coco': { | ||
'config_url': 'https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/samples/configs/ssd_mobilenet_v2_coco.config', | ||
'checkpoint_url': 'http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v2_coco_2018_03_29.tar.gz', | ||
}, | ||
'ssd_inception_v2_coco': { | ||
'config_url': 'https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/samples/configs/ssd_inception_v2_coco.config', | ||
'checkpoint_url': 'http://download.tensorflow.org/models/object_detection/ssd_inception_v2_coco_2017_11_17.tar.gz', | ||
}, | ||
from .graph_utils import force_nms_cpu as f_force_nms_cpu | ||
from .graph_utils import replace_relu6 as f_replace_relu6 | ||
from .graph_utils import remove_assert as f_remove_assert | ||
|
||
DetectionModel = namedtuple('DetectionModel', ['name', 'url', 'extract_dir']) | ||
|
||
INPUT_NAME='image_tensor' | ||
BOXES_NAME='detection_boxes' | ||
CLASSES_NAME='detection_classes' | ||
SCORES_NAME='detection_scores' | ||
MASKS_NAME='detection_masks' | ||
NUM_DETECTIONS_NAME='num_detections' | ||
FROZEN_GRAPH_NAME='frozen_inference_graph.pb' | ||
PIPELINE_CONFIG_NAME='pipeline.config' | ||
CHECKPOINT_PREFIX='model.ckpt' | ||
|
||
|
||
MODELS = { | ||
'ssd_mobilenet_v1_coco': DetectionModel( | ||
'ssd_mobilenet_v1_coco', | ||
'http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz', | ||
'ssd_mobilenet_v1_coco_2018_01_28', | ||
), | ||
'ssd_mobilenet_v2_coco': DetectionModel( | ||
'ssd_mobilenet_v2_coco', | ||
'http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v2_coco_2018_03_29.tar.gz', | ||
'ssd_mobilenet_v2_coco_2018_03_29', | ||
), | ||
'ssd_inception_v2_coco': DetectionModel( | ||
'ssd_inception_v2_coco', | ||
'http://download.tensorflow.org/models/object_detection/ssd_inception_v2_coco_2018_01_28.tar.gz', | ||
'ssd_inception_v2_coco_2018_01_28', | ||
), | ||
'ssd_resnet_50_fpn_coco': DetectionModel( | ||
'ssd_resnet_50_fpn_coco', | ||
'http://download.tensorflow.org/models/object_detection/ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03.tar.gz', | ||
'ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03', | ||
), | ||
'faster_rcnn_resnet50_coco': DetectionModel( | ||
'faster_rcnn_resnet50_coco', | ||
'http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet50_coco_2018_01_28.tar.gz', | ||
'faster_rcnn_resnet50_coco_2018_01_28', | ||
), | ||
'faster_rcnn_nas': DetectionModel( | ||
'faster_rcnn_nas', | ||
'http://download.tensorflow.org/models/object_detection/faster_rcnn_nas_coco_2018_01_28.tar.gz', | ||
'faster_rcnn_nas_coco_2018_01_28', | ||
), | ||
'mask_rcnn_resnet50_atrous_coco': DetectionModel( | ||
'mask_rcnn_resnet50_atrous_coco', | ||
'http://download.tensorflow.org/models/object_detection/mask_rcnn_resnet50_atrous_coco_2018_01_28.tar.gz', | ||
'mask_rcnn_resnet50_atrous_coco_2018_01_28', | ||
) | ||
} | ||
|
||
def download_detection_model(model, output_dir='.'): | ||
"""Download a default detection model configuration and checkpoint. | ||
This function downloads a default detection model configuration and | ||
checkpoint. This is only available for a subset of models in the | ||
TensorFlow object detection model zoo that are known to work on Jetson. | ||
The following models are available | ||
ssd_mobilenet_v1_coco | ||
ssd_mobilenet_v2_coco | ||
ssd_inception_v2_coco | ||
:param model: the model name from the above list | ||
:type model: string | ||
:param output_dir: the directory where files are downloaded to | ||
:type output_dir: string | ||
:return config_path: path to the object detection pipeline config file | ||
:rtype string | ||
:return checkpoint_path: path to the checkpoint files prefix containing trained model params | ||
:rtype string | ||
""" | ||
global nets | ||
config_path = '' | ||
checkpoint_path = '' | ||
|
||
if not os.path.exists(output_dir): | ||
os.makedirs(output_dir) | ||
|
||
modeldir_path = os.path.join(output_dir, model) | ||
if not os.path.exists(modeldir_path): | ||
os.makedirs(modeldir_path) | ||
|
||
config_path = os.path.join(output_dir, model + '.config') | ||
if not os.path.isfile(config_path): | ||
subprocess.call(['wget', '--no-check-certificate', nets[model]['config_url'], '-O', config_path]) | ||
|
||
modeltar_path = os.path.join(output_dir, os.path.basename(nets[model]['checkpoint_url'])) | ||
if not os.path.isfile(modeltar_path): | ||
subprocess.call(['wget', '--no-check-certificate', nets[model]['checkpoint_url'], '-O', modeltar_path]) | ||
|
||
tar_file = tarfile.open(modeltar_path) | ||
for file in tar_file.getmembers(): | ||
file_name = os.path.basename(file.name) | ||
if 'model.ckpt' in file_name: | ||
file.name = file_name | ||
tar_file.extract(file, modeldir_path) | ||
|
||
checkpoint_path = os.path.join(modeldir_path, 'model.ckpt') | ||
|
||
return config_path, checkpoint_path | ||
|
||
def build_detection_graph(config, checkpoint): | ||
"""Build an object detection model from the TensorFlow model zoo. | ||
This function creates an object detection model, sourced from the | ||
TensorFlow object detection API. | ||
|
||
It is necessary to use this function to generate a frozen graph that is | ||
compatible with TensorFlow/TensorRT integration. In addition to generating | ||
a graph that is compatible with TensorFlow's TensorRT package, this | ||
function performs other graph modifications, such as forced device | ||
placement, that improve performance on Jetson. These graph modifications | ||
are tested with a subset of the object detection API and may or may not | ||
work well with models not listed. | ||
def get_input_names(model): | ||
return [INPUT_NAME] | ||
|
||
The workflow when using this method is: | ||
|
||
1. Train model using TensorFlow object detection API | ||
2. Build graph configured for Jetson using this function | ||
3. Optimize the graph output by this method with the TensorRT package in | ||
TensorFlow | ||
4. Execute in regular TensorFlow, or using the high level TFModel class | ||
def get_output_names(model): | ||
output_names = [BOXES_NAME, CLASSES_NAME, SCORES_NAME, NUM_DETECTIONS_NAME] | ||
if model == 'mask_rcnn_resnet50_atrous_coco': | ||
output_names.append(MASKS_NAME) | ||
return output_names | ||
|
||
:param config: path to the object detection pipeline config file | ||
:type config: string | ||
:param checkpoint: path to the checkpoint files prefix containing trained model params | ||
:type checkpoint: string | ||
:returns: the configured frozen graph representing object detection model | ||
:rtype: a tensorflow GraphDef | ||
""" | ||
global input_name, output_map | ||
|
||
if isinstance(config, str): | ||
with open(config, 'r') as f: | ||
config_str = f.read() | ||
config = TrainEvalPipelineConfig() | ||
text_format.Merge(config_str, config) | ||
|
||
|
||
tf_config = tf.ConfigProto() | ||
tf_config.gpu_options.allow_growth = True | ||
|
||
with tf.Graph().as_default() as tf_graph: | ||
with tf.Session(config=tf_config) as tf_sess: | ||
|
||
model = model_builder.build(model_config=config.model, is_training=False) | ||
def download_detection_model(model, output_dir='.'): | ||
"""Downloads a pre-trained object detection model""" | ||
global MODELS | ||
|
||
tf_input = tf.placeholder(tf.float32, [1, None, None, 3], name=input_name) | ||
tf_preprocessed, tf_true_image_shapes = model.preprocess(tf_input) | ||
tf_predictions = model.predict(preprocessed_inputs=tf_preprocessed, | ||
true_image_shapes=tf_true_image_shapes) | ||
tf_postprocessed = model.postprocess( | ||
prediction_dict=tf_predictions, | ||
true_image_shapes=tf_true_image_shapes | ||
) | ||
model_name = model | ||
|
||
tf_saver = tf.train.Saver() | ||
tf_saver.restore(save_path=checkpoint, sess=tf_sess) | ||
model = MODELS[model_name] | ||
subprocess.call(['mkdir', '-p', output_dir]) | ||
tar_file = os.path.join(output_dir, os.path.basename(model.url)) | ||
|
||
outputs = {} | ||
for key, op in tf_postprocessed.items(): | ||
if key in output_map.keys(): | ||
outputs[output_map[key]] = \ | ||
tf.identity(op, name=output_map[key]) | ||
config_path = os.path.join(output_dir, model.extract_dir, PIPELINE_CONFIG_NAME) | ||
checkpoint_path = os.path.join(output_dir, model.extract_dir, CHECKPOINT_PREFIX) | ||
|
||
frozen_graph = tf.graph_util.convert_variables_to_constants( | ||
tf_sess, | ||
tf_sess.graph_def, | ||
output_node_names=list(outputs.keys()) | ||
) | ||
if not os.path.exists(os.path.join(output_dir, model.extract_dir)): | ||
subprocess.call(['wget', model.url, '-O', tar_file]) | ||
subprocess.call(['tar', '-xzf', tar_file, '-C', output_dir]) | ||
|
||
frozen_graph = convert_relu6(frozen_graph) | ||
# hack fix to handle mobilenet_v2 config bug | ||
subprocess.call(['sed', '-i', '/batch_norm_trainable/d', config_path]) | ||
|
||
remove_op(frozen_graph, 'Assert') | ||
return config_path, checkpoint_path | ||
|
||
# force CPU device placement for NMS ops | ||
for node in frozen_graph.node: | ||
if 'NonMaxSuppression' in node.name: | ||
node.device = '/device:CPU:0' | ||
|
||
return frozen_graph, [input_name], list(outputs.keys()) | ||
def build_detection_graph(config, checkpoint, | ||
batch_size=1, | ||
score_threshold=None, | ||
force_nms_cpu=True, | ||
replace_relu6=True, | ||
remove_assert=True, | ||
output_dir='.generated_model'): | ||
"""Builds a frozen graph for a pre-trained object detection model""" | ||
|
||
config_path = config | ||
checkpoint_path = checkpoint | ||
|
||
# parse config from file | ||
config = pipeline_pb2.TrainEvalPipelineConfig() | ||
with open(config_path, 'r') as f: | ||
text_format.Merge(f.read(), config, allow_unknown_extension=True) | ||
|
||
# override some config parameters | ||
if config.model.HasField('ssd'): | ||
config.model.ssd.feature_extractor.override_base_feature_extractor_hyperparams = True | ||
if score_threshold is not None: | ||
config.model.ssd.post_processing.batch_non_max_suppression.score_threshold = score_threshold | ||
elif config.model.HasField('faster_rcnn'): | ||
if score_threshold is not None: | ||
config.model.faster_rcnn.second_stage_post_processing.score_threshold = score_threshold | ||
|
||
# export inference graph to file (initial) | ||
with tf.Graph().as_default() as tf_graph: | ||
exporter.export_inference_graph( | ||
'image_tensor', | ||
config, | ||
checkpoint_path, | ||
output_dir, | ||
input_shape=[batch_size, None, None, 3], | ||
write_inference_graph=False | ||
) | ||
|
||
# read frozen graph from file | ||
frozen_graph = tf.GraphDef() | ||
with open(os.path.join(output_dir, FROZEN_GRAPH_NAME), 'rb') as f: | ||
frozen_graph.ParseFromString(f.read()) | ||
|
||
# apply graph modifications | ||
if force_nms_cpu: | ||
frozen_graph = f_force_nms_cpu(frozen_graph) | ||
if replace_relu6: | ||
frozen_graph = f_replace_relu6(frozen_graph) | ||
if remove_assert: | ||
frozen_graph = f_remove_assert(frozen_graph) | ||
|
||
# get input names | ||
# TODO: handle mask_rcnn | ||
input_names = [INPUT_NAME] | ||
output_names = [BOXES_NAME, CLASSES_NAME, SCORES_NAME, NUM_DETECTIONS_NAME] | ||
|
||
# remove temporary directory | ||
subprocess.call(['rm', '-rf', output_dir]) | ||
|
||
return frozen_graph, input_names, output_names |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters