wrapped export in TF session
jaybdub committed Oct 11, 2018
1 parent 0799854 commit 04d9fe8
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions tf_trt_models/detection.py
@@ -125,16 +125,23 @@ def build_detection_graph(config, checkpoint,
     if score_threshold is not None:
         config.model.faster_rcnn.second_stage_post_processing.score_threshold = score_threshold
 
+    if os.path.isdir(output_dir):
+        subprocess.call(['rm', '-rf', output_dir])
+
+    tf_config = tf.ConfigProto()
+    tf_config.gpu_options.allow_growth = True
+
     # export inference graph to file (initial)
-    with tf.Graph().as_default() as tf_graph:
-        exporter.export_inference_graph(
-            'image_tensor',
-            config,
-            checkpoint_path,
-            output_dir,
-            input_shape=[batch_size, None, None, 3],
-            write_inference_graph=False
-        )
+    with tf.Session(config=tf_config) as tf_sess:
+        with tf.Graph().as_default() as tf_graph:
+            exporter.export_inference_graph(
+                'image_tensor',
+                config,
+                checkpoint_path,
+                output_dir,
+                input_shape=[batch_size, None, None, 3],
+                write_inference_graph=False
+            )
 
     # read frozen graph from file
     frozen_graph = tf.GraphDef()
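For context: the added tf_config lines use TensorFlow 1.x's standard mechanism for avoiding an up-front reservation of all GPU memory, where gpu_options.allow_growth = True asks the session to allocate memory on demand. A minimal, self-contained sketch of the same session-configuration pattern (using a hypothetical toy graph in place of the object-detection exporter) could look like this:

import tensorflow as tf  # TensorFlow 1.x API, as used in this repository

# Allocate GPU memory on demand instead of reserving the whole device
# when the session starts (the same setting introduced in this commit).
tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True

# Hypothetical toy graph standing in for the exported detection model.
graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, shape=[None, 3], name='input')
    y = tf.multiply(x, 2.0, name='output')

# The session picks up the memory-growth setting via config=.
with tf.Session(graph=graph, config=tf_config) as sess:
    print(sess.run(y, feed_dict={x: [[1.0, 2.0, 3.0]]}))  # [[2. 4. 6.]]

Note that in detection.py the session is opened around the exporter.export_inference_graph call itself; the sketch above only illustrates how the ConfigProto is passed to tf.Session.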
