Skip to content

Commit

Permalink
Update evaluate.py
Browse files Browse the repository at this point in the history
  • Loading branch information
lucas-ecm authored Aug 12, 2023
1 parent 175bf57 commit cb320e9
Showing 1 changed file with 41 additions and 3 deletions.
44 changes: 41 additions & 3 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import os
import shutil
import numpy as np
import pandas as pd
from timeit import default_timer as timer

import tensorflow as tf
from core.yolov4 import filter_boxes
from tensorflow.python.saved_model import tag_constants
Expand All @@ -22,7 +25,22 @@
flags.DEFINE_float('iou', 0.5, 'iou threshold')
flags.DEFINE_float('score', 0.25, 'score threshold')

label_map = {
0: 1,
1: 2,
2: 3,
3: 4,
4: 0,
5: 5,
}

def main(_argv):
pred_df = []

preprocess_time = 0
inference_time = 0
eval_time = 0

INPUT_SIZE = FLAGS.size
STRIDES, ANCHORS, NUM_CLASS, XYSCALE = utils.load_config(FLAGS)
CLASSES = utils.read_class_names(cfg.YOLO.CLASSES)
Expand Down Expand Up @@ -56,6 +74,7 @@ def main(_argv):
# for a,b in enumerate(annotation_file):
# print(a,b)
for num, line in enumerate(annotation_file):
start = timer()
# print('..')
annotation = line.strip().split()
image_path = annotation[0]
Expand Down Expand Up @@ -88,7 +107,9 @@ def main(_argv):
image_data = cv2.resize(np.copy(image), (INPUT_SIZE, INPUT_SIZE))
image_data = image_data / 255.
image_data = image_data[np.newaxis, ...].astype(np.float32)

end = timer()
preprocess_time += end - start
start = timer()
if FLAGS.framework == 'tflite':
interpreter.set_tensor(input_details[0]['index'], image_data)
interpreter.invoke()
Expand Down Expand Up @@ -116,12 +137,24 @@ def main(_argv):
)
boxes, scores, classes, valid_detections = [boxes.numpy(), scores.numpy(), classes.numpy(), valid_detections.numpy()]
pred_bbox = [boxes, scores, classes, valid_detections]
end = timer()
inference_time += end - start
# print(pred_bbox)

# if cfg.TEST.DECTECTED_IMAGE_PATH is not None:
# image_result = utils.draw_bbox(np.copy(image), [boxes, scores, classes, valid_detections])
# cv2.imwrite(cfg.TEST.DECTECTED_IMAGE_PATH + image_name, image_result)

for i in range(valid_detections[0]):
curr_pred_row = {
'ImageID':'/content/test/'+image_name,
'LabelName': label_map[int(classes[0][i])],
'Conf':scores[0][i] ,
'XMin':boxes[0][i][1] ,
'XMax':boxes[0][i][3] ,
'YMin':boxes[0][i][0] ,
'YMax':boxes[0][i][2] ,
}
pred_df.append(curr_pred_row)
with open(predict_result_path, 'w') as f:
image_h, image_w, _ = image.shape
for i in range(valid_detections[0]):
Expand All @@ -140,11 +173,16 @@ def main(_argv):
bbox_mess = ' '.join([class_name, score, xmin, ymin, xmax, ymax]) + '\n'
f.write(bbox_mess)
print('\t' + str(bbox_mess).strip())


print(num, num_lines)
pred_df = pd.DataFrame(pred_df)
pred_df.to_csv('/content/predictions.csv',index=False)

print(preprocess_time)
print(inference_time)
if __name__ == '__main__':
try:
app.run(main)
except SystemExit:
pass

0 comments on commit cb320e9

Please sign in to comment.