Benchmark (google#235)
* refactor saved model inference

* add benchmark

* use perf_counter instead of time
fsx950223 authored Apr 15, 2020
1 parent e579bab commit 207797b
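A note on the third bullet: time.time() reads the wall clock, which can have coarse resolution and can jump if the system clock is adjusted mid-measurement, while time.perf_counter() is a monotonic, high-resolution timer intended for measuring intervals, which is what the benchmark loop below needs. A minimal standalone sketch of the pattern (not part of this diff):

from time import perf_counter, sleep

start = perf_counter()
sleep(0.05)  # stand-in for the work being timed
elapsed = perf_counter() - start  # monotonic, high-resolution interval
print('elapsed: %.4f s' % elapsed)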
Showing 2 changed files with 60 additions and 16 deletions.
32 changes: 32 additions & 0 deletions efficientdet/inference.py
@@ -26,6 +26,7 @@
from PIL import Image
import tensorflow.compat.v1 as tf
from typing import Text, Dict, Any, List, Tuple, Union
from time import perf_counter

import anchors
import dataloader
@@ -440,6 +441,27 @@ def serve_files(self, image_files: List[Text]):
        feed_dict={self.signitures['image_files']: image_files})
    return predictions

  def benchmark(self, image_arrays):
    """Measure and print average inference latency and FPS."""
    if not self.sess:
      self.build()

    # Warm-up run: the first sess.run pays one-time costs that
    # should not be included in the measurement.
    self.sess.run(
        self.signitures['prediction'],
        feed_dict={self.signitures['image_arrays']: image_arrays})

    start = perf_counter()
    for _ in range(10):
      self.sess.run(
          self.signitures['prediction'],
          feed_dict={self.signitures['image_arrays']: image_arrays})
    end = perf_counter()
    inference_time = (end - start) / 10

    print('Inference time: ', inference_time)
    print('FPS: ', 1 / inference_time)


  def serve_images(self, image_arrays):
    """Serve a list of image arrays.
@@ -457,6 +479,16 @@ def serve_images(self, image_arrays):
        feed_dict={self.signitures['image_arrays']: image_arrays})
    return predictions

  def load(self, saved_model_dir):
    """Load an exported saved model's graph and weights into self.sess."""
    if not self.sess:
      self.sess = tf.Session()
    self.signitures = {
        'image_files': 'image_files:0',
        'image_arrays': 'image_arrays:0',
        'prediction': 'detections:0',
    }
    return tf.saved_model.load(self.sess, ['serve'], saved_model_dir)

  def export(self, output_dir):
    """Export a saved model."""
    signitures = self.signitures
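Taken together, the new load() and benchmark() methods let a ServingDriver run against an exported saved model without rebuilding the graph from a checkpoint. A hedged usage sketch, where the model name, checkpoint path, and image path are placeholders and the constructor arguments simply mirror the ServingDriver(...) calls in model_inspect.py below:

import numpy as np
from PIL import Image

import inference

driver = inference.ServingDriver('efficientdet-d0', '/path/to/ckpt')
driver.load('/path/to/saved_model')  # restores the exported graph and weights into driver.sess
raw_images = [np.array(Image.open('/path/to/image.jpg'))]  # a batch of HxWxC uint8 arrays
detections = driver.serve_images(raw_images)  # single prediction pass
driver.benchmark(raw_images)  # warm-up run, then prints latency averaged over 10 runs and FPS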
44 changes: 28 additions & 16 deletions efficientdet/model_inspect.py
@@ -155,22 +155,32 @@ def export_saved_model(self, **kwargs):

  def saved_model_inference(self, image_path_pattern, output_dir, **kwargs):
    """Perform inference for the given saved model."""
-    with tf.Session() as sess:
-      tf.saved_model.load(sess, ['serve'], self.saved_model_dir)
-      raw_images = []
-      image = Image.open(image_path_pattern)
-      raw_images.append(np.array(image))
-      detections_bs = sess.run('detections:0', {'image_arrays:0': raw_images})
-      driver = inference.ServingDriver(
-          self.model_name,
-          self.ckpt_path,
-          enable_ema=self.enable_ema)
-      for i, detections in enumerate(detections_bs):
-        print('detections[:10]=', detections[:10])
-        img = driver.visualize(raw_images[i], detections, **kwargs)
-        output_image_path = os.path.join(output_dir, str(i) + '.jpg')
-        Image.fromarray(img).save(output_image_path)
-        logging.info('writing file to %s', output_image_path)
+    driver = inference.ServingDriver(
+        self.model_name,
+        self.ckpt_path,
+        enable_ema=self.enable_ema)
+    driver.load(self.saved_model_dir)
+    raw_images = []
+    image = Image.open(image_path_pattern)
+    raw_images.append(np.array(image))
+    detections_bs = driver.serve_images(raw_images)
+    for i, detections in enumerate(detections_bs):
+      img = driver.visualize(raw_images[i], detections, **kwargs)
+      output_image_path = os.path.join(output_dir, str(i) + '.jpg')
+      Image.fromarray(img).save(output_image_path)
+      logging.info('writing file to %s', output_image_path)

  def saved_model_benchmark(self, image_path_pattern):
    """Benchmark inference latency for the given saved model."""
    driver = inference.ServingDriver(
        self.model_name,
        self.ckpt_path,
        enable_ema=self.enable_ema)
    driver.load(self.saved_model_dir)
    raw_images = []
    image = Image.open(image_path_pattern)
    raw_images.append(np.array(image))
    driver.benchmark(raw_images)

  def build_and_save_model(self):
    """build and save the model into self.logdir."""
@@ -342,6 +352,8 @@ def run_model(self, runmode, threads=0):
      self.freeze_model()
    elif runmode == 'ckpt':
      self.eval_ckpt()
    elif runmode == 'saved_model_benchmark':
      self.saved_model_benchmark(FLAGS.input_image)
    elif runmode in ('infer', 'saved_model', 'saved_model_infer'):
      config_dict = {}
      if FLAGS.line_thickness:
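For reference, the new runmode would be invoked from the command line roughly as follows. --runmode and --input_image appear in this diff; the model and saved-model flags are assumptions based on the ModelInspector attributes used above, and all paths are placeholders:

python model_inspect.py --runmode=saved_model_benchmark \
    --model_name=efficientdet-d0 \
    --saved_model_dir=/path/to/saved_model \
    --input_image=/path/to/image.jpg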
