forked from machrisaa/tensorflow-vgg
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_vgg16.py
executable file
·34 lines (26 loc) · 884 Bytes
/
test_vgg16.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import numpy as np
import tensorflow as tf
import vgg16
import utils
import time
# Reference point for all timing output below. perf_counter() is a monotonic,
# high-resolution clock, so the reported intervals are not disturbed by
# system clock adjustments the way time.time() differences can be.
start = time.perf_counter()


def tick():
    """Print the number of seconds elapsed since ``start`` was recorded."""
    print(time.perf_counter() - start)
# Images fed through the network. The two samples are listed twice, which
# yields a batch of four.
filenames = [
    "./test_data/tiger.jpeg",
    "./test_data/puzzle.jpeg",
    "./test_data/tiger.jpeg",
    "./test_data/puzzle.jpeg",
]

# Load every file, give each image a leading batch axis of size 1, then join
# them along that axis into a single (len(filenames), 224, 224, 3) array.
# NOTE(review): the reshape assumes utils.load_image returns exactly
# 224*224*3 values per image -- confirm against utils.
loaded_images = [utils.load_image(f) for f in filenames]
batch = np.concatenate(
    [im.reshape((1, 224, 224, 3)) for im in loaded_images], 0)
tick()

# Cap this process at 70% of GPU memory.
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    # Placeholder the network is built on; the first dimension matches the
    # number of input files so it stays in sync with `batch`.
    images = tf.placeholder(tf.float32, [len(filenames), 224, 224, 3])
    feed_dict = {images: batch}
    tick()

    # Build the VGG16 graph under its own name scope.
    vgg = vgg16.Vgg16()
    with tf.name_scope("content_vgg"):
        vgg.build(images)
    tick()

    # Evaluate vgg.prob for the whole batch in one run.
    probs = sess.run(vgg.prob, feed_dict=feed_dict)
    tick()

    # One print_prob call per row of the result.
    # presumably prints the top labels from synset.txt -- verify in utils.
    for pr in probs:
        utils.print_prob(pr, './synset.txt')