Skip to content

Commit

Permalink
write numbers on top of bars
Browse files Browse the repository at this point in the history
  • Loading branch information
Cartucho committed Apr 17, 2018
1 parent a115590 commit 0c3431e
Showing 1 changed file with 44 additions and 18 deletions.
62 changes: 44 additions & 18 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,29 +149,55 @@ def draw_text_in_image(img, text, pos, color, line_width):
"""
Draw plot using Matplotlib
"""
def draw_plot_func(dictionary, n_classes, window_title, plot_title, y_label, output_path, to_show, plot_color, stack_bar):
def draw_plot_func(dictionary, n_classes, window_title, plot_title, y_label, output_path, to_show, plot_color, true_p_bar):
# set window title
fig = plt.gcf() # gcf - get current figure
fig.canvas.set_window_title(window_title)
# sort the dictionary by decreasing value (reverse=True), into a list of tuples
sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1), reverse=True)
# unpacking the list of tuples into two lists
sorted_keys, sorted_values = zip(*sorted_dic_by_value)
plt.bar(range(n_classes), sorted_values, align='center', color=plot_color, label='True Predictions')
# special case to draw in red false detections
if stack_bar != "":
stack_list = []
# special case to draw in (green, true predictions) & (red, false predictions)
if true_p_bar != "":
fp_sorted = []
tp_sorted = []
for key in sorted_keys:
stack_list.append(dictionary[key] - stack_bar[key])
plt.bar(range(n_classes), stack_list, align='center', color='crimson', label='False Predictions')
fp_sorted.append(dictionary[key] - true_p_bar[key])
tp_sorted.append(true_p_bar[key])
plt.bar(range(n_classes), fp_sorted, align='center', color='crimson', label='False Predictions')
plt.bar(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Predictions', bottom=fp_sorted)
# add legend
plt.legend(loc='best')
# write classes in x axis "vertically"
if n_classes <= 30:
plt.xticks(range(n_classes), sorted_keys, rotation='vertical', fontsize=12)
# write number on top of bar
for i, val in enumerate(sorted_values):
if i == 0:
# re-scale plot to write numbers on top
axes = plt.gca() # get current axes
axes.set_ylim([0,sorted_values[0]*1.25]) # this only works since the first value is the largest
fp_val = fp_sorted[i]
y_fp_offset = len(str(fp_val) * 7) # 'hack' to find top offset
tp_val = tp_sorted[i]
y_tp_offset = y_fp_offset + len(str(tp_val)) * 7 + 2
# (- 0.4) because the bar's default width=0.8
plt.annotate(str(fp_val), xy=(i - 0.4,val), color='crimson', fontweight='bold', rotation='vertical', xytext=(0,y_fp_offset), textcoords='offset points')
plt.annotate(str(tp_val), xy=(i - 0.4,val), color='forestgreen', fontweight='bold', rotation='vertical', xytext=(0,y_tp_offset), textcoords='offset points')
else:
# if there are more than 30 classes we need to use the default font size
# otherwise the labels start to overlap each other
plt.xticks(range(n_classes), sorted_keys, rotation='vertical')
# set window title
fig = plt.gcf() # gcf - get current figure
fig.canvas.set_window_title(window_title)
plt.bar(range(n_classes), sorted_values, align='center', color=plot_color)
# write number on top of bar
for i, val in enumerate(sorted_values):
if i == 0:
# re-scale plot to write numbers on top
axes = plt.gca() # get current axes
axes.set_ylim([0,sorted_values[0]*1.25]) # this only works since the first value is the largest
str_val = str(val)
if val <= 1.0:
str_val = "{0:.2f}".format(val)
y_offset = len(str_val) * 7 # 'hack' to find top offset
# (- 0.4) because the bar's default width=0.8
plt.annotate(str_val, xy=(i - 0.4,val), color=plot_color, fontweight='bold', rotation='vertical', xytext=(0,y_offset), textcoords='offset points')

# write classes in x axis "vertically"
plt.xticks(range(n_classes), sorted_keys, rotation='vertical')
# set plot title
plt.title(plot_title, fontsize=14)
# set axis titles
Expand Down Expand Up @@ -612,7 +638,7 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, y_label, out
output_path = results_files_path + "/Predicted Objects Info.png"
to_show = False
plot_color = 'forestgreen'
stack_bar = count_true_positives
true_p_bar = count_true_positives
draw_plot_func(
pred_counter_per_class,
len(pred_counter_per_class),
Expand All @@ -622,7 +648,7 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, y_label, out
output_path,
to_show,
plot_color,
stack_bar
true_p_bar
)

"""
Expand Down

0 comments on commit 0c3431e

Please sign in to comment.