# detectvideo.py (forked from hunglc007/tensorflow-yolov4-tflite)
import time
from absl import app, flags, logging
from absl.flags import FLAGS
import core.utils as utils
from core.yolov4 import YOLOv4, YOLOv3, YOLOv3_tiny, decode
from PIL import Image
from core.config import cfg
import cv2
import numpy as np
import tensorflow as tf
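
# The core.* modules (utils, yolov4, config) are provided by the
# hunglc007/tensorflow-yolov4-tflite repository this file is forked from.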
flags.DEFINE_string('framework', 'tf', '(tf, tflite)')
flags.DEFINE_string('weights', './data/yolov4.weights',
                    'path to weights file')
flags.DEFINE_integer('size', 608, 'resize images to')
flags.DEFINE_boolean('tiny', False, 'yolo or yolo-tiny')
flags.DEFINE_string('model', 'yolov4', 'yolov3 or yolov4')
flags.DEFINE_string('video', './data/road.avi', 'path to input video')
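
# Example invocation (paths are illustrative):
#   python detectvideo.py --framework tf --weights ./data/yolov4.weights \
#       --size 608 --model yolov4 --video ./data/road.avi
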
def main(_argv):
    if FLAGS.tiny:
        STRIDES = np.array(cfg.YOLO.STRIDES_TINY)
        ANCHORS = utils.get_anchors(cfg.YOLO.ANCHORS_TINY, FLAGS.tiny)
    else:
        STRIDES = np.array(cfg.YOLO.STRIDES)
        if FLAGS.model == 'yolov4':
            ANCHORS = utils.get_anchors(cfg.YOLO.ANCHORS, FLAGS.tiny)
        else:
            ANCHORS = utils.get_anchors(cfg.YOLO.ANCHORS_V3, FLAGS.tiny)
    NUM_CLASS = len(utils.read_class_names(cfg.YOLO.CLASSES))
    XYSCALE = cfg.YOLO.XYSCALE
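    # STRIDES are the output strides of the detection heads, ANCHORS the prior
    # box sizes per head, and XYSCALE the YOLOv4 grid-sensitivity factors used
    # when decoding box centers.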
    input_size = FLAGS.size
    video_path = FLAGS.video

    print("Video from: ", video_path)
    vid = cv2.VideoCapture(video_path)
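    # cv2.VideoCapture also accepts an integer camera index (e.g. 0 for the
    # default webcam) in place of a file path.
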
    if FLAGS.framework == 'tf':
        input_layer = tf.keras.layers.Input([input_size, input_size, 3])
        if FLAGS.tiny:
            feature_maps = YOLOv3_tiny(input_layer, NUM_CLASS)
            bbox_tensors = []
            for i, fm in enumerate(feature_maps):
                bbox_tensor = decode(fm, NUM_CLASS, i)
                bbox_tensors.append(bbox_tensor)
            model = tf.keras.Model(input_layer, bbox_tensors)
            utils.load_weights_tiny(model, FLAGS.weights)
        else:
            if FLAGS.model == 'yolov3':
                feature_maps = YOLOv3(input_layer, NUM_CLASS)
                bbox_tensors = []
                for i, fm in enumerate(feature_maps):
                    bbox_tensor = decode(fm, NUM_CLASS, i)
                    bbox_tensors.append(bbox_tensor)
                model = tf.keras.Model(input_layer, bbox_tensors)
                utils.load_weights_v3(model, FLAGS.weights)
            elif FLAGS.model == 'yolov4':
                feature_maps = YOLOv4(input_layer, NUM_CLASS)
                bbox_tensors = []
                for i, fm in enumerate(feature_maps):
                    bbox_tensor = decode(fm, NUM_CLASS, i)
                    bbox_tensors.append(bbox_tensor)
                model = tf.keras.Model(input_layer, bbox_tensors)
                # Darknet .weights files are loaded layer by layer; anything
                # else is treated as a Keras checkpoint.
                if FLAGS.weights.split(".")[-1] == "weights":
                    utils.load_weights(model, FLAGS.weights)
                else:
                    model.load_weights(FLAGS.weights).expect_partial()
        model.summary()
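    # TFLite branch: FLAGS.weights is expected to point at a converted
    # .tflite model rather than raw Darknet weights.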
    else:
        # Load the TFLite model and allocate tensors.
        interpreter = tf.lite.Interpreter(model_path=FLAGS.weights)
        interpreter.allocate_tensors()
        # Get input and output tensor details.
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        print(input_details)
        print(output_details)
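
    # Read frames until the stream ends or 'q' is pressed.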
    while True:
        return_value, frame = vid.read()
        if return_value:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(frame)
        else:
            raise ValueError("No image! Try with another video format")
        frame_size = frame.shape[:2]
        image_data = utils.image_preprocess(np.copy(frame), [input_size, input_size])
        image_data = image_data[np.newaxis, ...].astype(np.float32)
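        # As implemented in this repo's core/utils.py, image_preprocess does a
        # padded (letterbox) resize to input_size x input_size and scales
        # pixels to [0, 1]; the leading axis added above is the batch dimension.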
        prev_time = time.time()

        if FLAGS.framework == 'tf':
            pred_bbox = model.predict(image_data)
        else:
            interpreter.set_tensor(input_details[0]['index'], image_data)
            interpreter.invoke()
            pred_bbox = [interpreter.get_tensor(output_details[i]['index'])
                         for i in range(len(output_details))]
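        # Decode the raw head outputs into boxes (postprocess_bbbox is the
        # spelling used in this repo's core/utils.py); YOLOv4 additionally
        # applies the XYSCALE grid-sensitivity correction.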
        if FLAGS.model == 'yolov4':
            pred_bbox = utils.postprocess_bbbox(pred_bbox, ANCHORS, STRIDES, XYSCALE)
        else:
            pred_bbox = utils.postprocess_bbbox(pred_bbox, ANCHORS, STRIDES)
        # Discard detections scoring below 0.25, rescale boxes to the original
        # frame, then suppress overlaps above an IoU of 0.213.
        bboxes = utils.postprocess_boxes(pred_bbox, frame_size, input_size, 0.25)
        bboxes = utils.nms(bboxes, 0.213, method='nms')
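
        # Draw the surviving boxes on the RGB frame and report per-frame latency.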
        image = utils.draw_bbox(frame, bboxes)
        curr_time = time.time()
        exec_time = curr_time - prev_time
        info = "time: %.2f ms" % (1000 * exec_time)
        print(info)

        cv2.namedWindow("result", cv2.WINDOW_AUTOSIZE)
        result = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        cv2.imshow("result", result)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

if __name__ == '__main__':
    try:
        app.run(main)
    except SystemExit:
        pass