Skip to content

Commit

Permalink
add true positives info to plot
Browse files Browse the repository at this point in the history
  • Loading branch information
Cartucho committed Apr 17, 2018
1 parent aff0929 commit a115590
Showing 1 changed file with 150 additions and 79 deletions.
229 changes: 150 additions & 79 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,36 +149,43 @@ def draw_text_in_image(img, text, pos, color, line_width):
"""
Draw plot using Matplotlib
"""
def draw_plot_func(dictionary, n_classes, window_title, plot_title, y_label, output_path, to_show, plot_color):
# sort the dictionary by decreasing value (reverse=True), into a list of tuples
sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1), reverse=True)
# unpacking the list of tuples into two lists
sorted_keys, sorted_values = zip(*sorted_dic_by_value)
plt.bar(range(n_classes), sorted_values, align='center', color=plot_color)
# write classes in x axis "vertically"
if n_classes <= 30:
plt.xticks(range(n_classes), sorted_keys, rotation='vertical', fontsize=12)
else:
# if there are more than 30 classes we need to use the default font size
# otherwise the labels start to overlap each other
plt.xticks(range(n_classes), sorted_keys, rotation='vertical')
# set window title
fig = plt.gcf() # gcf - get current figure
fig.canvas.set_window_title(window_title)
# set plot title
plt.title(plot_title, fontsize=14)
# set axis titles
# plt.xlabel('classes')
plt.ylabel(y_label, fontsize='large')
# adjust size of window
fig.tight_layout()
# save the plot
fig.savefig(output_path)
# show image
if to_show:
plt.show()
# clear the plot
plt.clf()
def draw_plot_func(dictionary, n_classes, window_title, plot_title, y_label, output_path, to_show, plot_color, stack_bar):
# sort the dictionary by decreasing value (reverse=True), into a list of tuples
sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1), reverse=True)
# unpacking the list of tuples into two lists
sorted_keys, sorted_values = zip(*sorted_dic_by_value)
plt.bar(range(n_classes), sorted_values, align='center', color=plot_color, label='True Predictions')
# special case to draw in red false detections
if stack_bar != "":
stack_list = []
for key in sorted_keys:
stack_list.append(dictionary[key] - stack_bar[key])
plt.bar(range(n_classes), stack_list, align='center', color='crimson', label='False Predictions')
plt.legend(loc='best')
# write classes in x axis "vertically"
if n_classes <= 30:
plt.xticks(range(n_classes), sorted_keys, rotation='vertical', fontsize=12)
else:
# if there are more than 30 classes we need to use the default font size
# otherwise the labels start to overlap each other
plt.xticks(range(n_classes), sorted_keys, rotation='vertical')
# set window title
fig = plt.gcf() # gcf - get current figure
fig.canvas.set_window_title(window_title)
# set plot title
plt.title(plot_title, fontsize=14)
# set axis titles
# plt.xlabel('classes')
plt.ylabel(y_label, fontsize='large')
# adjust size of window
fig.tight_layout()
# save the plot
fig.savefig(output_path)
# show image
if to_show:
plt.show()
# clear the plot
plt.clf()

"""
Create a "tmp_files/" and "results/" directory
Expand All @@ -200,15 +207,13 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, y_label, out
"""
Ground-Truth
Load each of the ground-truth files into a temporary ".json" file.
Create a list of all the class names present in the ground-truth (unique_classes).
Create a list of all the class names present in the ground-truth (gt_classes).
"""
# get a list with the ground-truth files
ground_truth_files_list = glob.glob('ground-truth/*.txt')
ground_truth_files_list.sort()
gt_counter_per_class = {}

unique_classes = set([])
# dictionary with counter per class
gt_counter_per_class = {}

for txt_file in ground_truth_files_list:
#print(txt_file)
Expand Down Expand Up @@ -241,15 +246,15 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, y_label, out
else:
# if class didn't exist yet
gt_counter_per_class[class_name] = 1
unique_classes.add(class_name)
# dump bounding_boxes into a ".json" file
with open(tmp_files_path + "/" + file_id + "_ground_truth.json", 'w') as outfile:
json.dump(bounding_boxes, outfile)

gt_classes = list(gt_counter_per_class.keys())
# let's sort the classes alphabetically
unique_classes = sorted(unique_classes)
n_classes = len(unique_classes)
#print(unique_classes)
gt_classes = sorted(gt_classes)
n_classes = len(gt_classes)
#print(gt_classes)
#print(gt_counter_per_class)

"""
Expand All @@ -270,7 +275,7 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, y_label, out
if len(specific_iou_classes) != len(iou_list):
error('Error, missing arguments. Flag usage:' + error_msg)
for tmp_class in specific_iou_classes:
if tmp_class not in unique_classes:
if tmp_class not in gt_classes:
error('Error, unknown class \"' + tmp_class + '\". Flag usage:' + error_msg)
for num in iou_list:
if not is_float_between_0_and_1(num):
Expand All @@ -287,7 +292,17 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, y_label, out
output_path = results_files_path + "/Ground-Truth Info.png"
to_show = False
plot_color = 'forestgreen'
draw_plot_func(gt_counter_per_class, n_classes, window_title, plot_title, y_label, output_path, to_show, plot_color)
draw_plot_func(
gt_counter_per_class,
n_classes,
window_title,
plot_title,
y_label,
output_path,
to_show,
plot_color,
'',
)

"""
Write number of ground-truth objects per class to results.txt
Expand All @@ -304,13 +319,17 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, y_label, out
# get a list with the predicted files
predicted_files_list = glob.glob('predicted/*.txt')
predicted_files_list.sort()
pred_counter_per_class = {}

for class_name in unique_classes:
for class_index, class_name in enumerate(gt_classes):
bounding_boxes = []
pred_counter_per_class[class_name] = 0
for txt_file in predicted_files_list:
#print txt_file
#print(txt_file)
# the first time it checks if all the corresponding ground-truth files exist
file_id = txt_file.split(".txt",1)[0]
file_id = os.path.basename(os.path.normpath(file_id))
if class_index == 0:
if not os.path.exists('ground-truth/' + file_id + ".txt"):
error("Error. File not found: ground-truth/" + file_id + ".txt")
lines = file_lines_to_list(txt_file)
for line in lines:
try:
Expand All @@ -322,47 +341,14 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, y_label, out
error(error_msg)
if tmp_class_name == class_name:
#print("match")
file_id = txt_file.split(".txt",1)[0]
file_id = os.path.basename(os.path.normpath(file_id))
bbox = left + " " + top + " " + right + " " +bottom
bounding_boxes.append({"confidence":confidence, "file_id":file_id, "bbox":bbox})
#print(bounding_boxes)
# count that object
if class_name in pred_counter_per_class:
pred_counter_per_class[class_name] += 1
else:
# if class didn't exist yet
pred_counter_per_class[class_name] = 1
# sort predictions by decreasing confidence
bounding_boxes.sort(key=lambda x:x['confidence'], reverse=True)
with open(tmp_files_path + "/" + class_name + "_predictions.json", 'w') as outfile:
json.dump(bounding_boxes, outfile)

"""
Plot the total number of occurences of each class in the "predicted" folder
"""
if draw_plot:
window_title = "Predicted Objects Info"
# Plot title
plot_title = "Predicted Objects\n"
plot_title += "(" + str(len(predicted_files_list)) + " files and "
count_non_zero_values_in_dictionary = sum(x > 0 for x in list(pred_counter_per_class.values()))
plot_title += str(count_non_zero_values_in_dictionary) + " detected classes)"
# end Plot title
y_label = "Number of objects per class"
output_path = results_files_path + "/Predicted Objects Info.png"
to_show = False
plot_color = 'forestgreen'
draw_plot_func(pred_counter_per_class, n_classes, window_title, plot_title, y_label, output_path, to_show, plot_color)

"""
Write number of predicted objects per class to results.txt
"""
with open(results_files_path + "/results.txt", 'a') as results_file:
results_file.write("\n# Number of predicted objects per class\n")
for class_name in sorted(pred_counter_per_class):
results_file.write(class_name + ": " + str(pred_counter_per_class[class_name]) + "\n")

"""
Calculate the AP for each class
"""
Expand All @@ -371,7 +357,9 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, y_label, out
# open file to store the results
with open(results_files_path + "/results.txt", 'a') as results_file:
results_file.write("\n# AP and precision/recall per class\n")
for class_index, class_name in enumerate(unique_classes):
count_true_positives = {}
for class_index, class_name in enumerate(gt_classes):
count_true_positives[class_name] = 0
"""
Load predictions of that class
"""
Expand Down Expand Up @@ -440,6 +428,7 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, y_label, out
# true positive
tp[idx] = 1
gt_match["used"] = True
count_true_positives[class_name] += 1
# update the ".json" file
with open(gt_file, 'w') as f:
f.write(json.dumps(ground_truth_data))
Expand Down Expand Up @@ -576,6 +565,78 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, y_label, out
# remove the tmp_files directory
shutil.rmtree(tmp_files_path)

"""
Count total of Predictions
"""
# iterate through all the files
pred_counter_per_class = {}
#all_classes_predicted_files = set([])
for txt_file in predicted_files_list:
# get lines to list
lines_list = file_lines_to_list(txt_file)
for line in lines_list:
class_name = line.split()[0]
# check if class is in the ignore list, if yes skip
if class_name in args.ignore:
continue
# count that object
if class_name in pred_counter_per_class:
pred_counter_per_class[class_name] += 1
else:
# if class didn't exist yet
pred_counter_per_class[class_name] = 1
#print(pred_counter_per_class)
pred_classes = list(pred_counter_per_class.keys())

"""
Finish counting true positives
"""
for class_name in pred_classes:
# if class exists in predictions but not in ground-truth then there are no true positives in that class
if class_name not in gt_classes:
count_true_positives[class_name] = 0
#print(count_true_positives)

"""
Plot the total number of occurences of each class in the "predicted" folder
"""
if draw_plot:
window_title = "Predicted Objects Info"
# Plot title
plot_title = "Predicted Objects\n"
plot_title += "(" + str(len(predicted_files_list)) + " files and "
count_non_zero_values_in_dictionary = sum(x > 0 for x in list(pred_counter_per_class.keys()))
plot_title += str(count_non_zero_values_in_dictionary) + " detected classes)"
# end Plot title
y_label = "Number of objects per class"
output_path = results_files_path + "/Predicted Objects Info.png"
to_show = False
plot_color = 'forestgreen'
stack_bar = count_true_positives
draw_plot_func(
pred_counter_per_class,
len(pred_counter_per_class),
window_title,
plot_title,
y_label,
output_path,
to_show,
plot_color,
stack_bar
)

"""
Write number of predicted objects per class to results.txt
"""
with open(results_files_path + "/results.txt", 'a') as results_file:
results_file.write("\n# Number of predicted objects per class\n")
for class_name in sorted(pred_classes):
n_pred = pred_counter_per_class[class_name]
text = class_name + ": " + str(n_pred)
text += " (tp:" + str(count_true_positives[class_name]) + ""
text += ", fp:" + str(n_pred - count_true_positives[class_name]) + ")\n"
results_file.write(text)

"""
Draw mAP plot (Show AP's of all classes in decreasing order)
"""
Expand All @@ -586,4 +647,14 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, y_label, out
output_path = results_files_path + "/mAP.png"
to_show = True
plot_color = 'royalblue'
draw_plot_func(ap_dictionary, n_classes, window_title, plot_title, y_label, output_path, to_show, plot_color)
draw_plot_func(
ap_dictionary,
n_classes,
window_title,
plot_title,
y_label,
output_path,
to_show,
plot_color,
""
)

0 comments on commit a115590

Please sign in to comment.