forked from NVlabs/DREAM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoks_plots.py
181 lines (132 loc) · 5.27 KB
/
oks_plots.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
# Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
# This work is licensed under the NVIDIA Source Code License - Non-commercial. Full
# text can be found in LICENSE.md
import pandas as pd
import argparse
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# Select the white-grid seaborn look for every figure drawn by this script.
# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-<name>"
# and later releases dropped the old names, so try the new name first and fall
# back to the legacy one on older Matplotlib installations.
try:
    plt.style.use("seaborn-v0_8-whitegrid")
except OSError:
    plt.style.use("seaborn-whitegrid")
# Example of running the script
# python oks_plots.py --data all_dataset_keypoints.csv all_dataset_keypoints.csv --labels 1 2
# pythonw oks_plots.py --data deep-arm-cal-paper/data/dope/3cam_real_keypoints.csv deep-arm-cal-paper/data/dream_hg/3cam_real_keypoints.csv deep-arm-cal-paper/data/dream_hg_deconv/3cam_real_keypoints.csv deep-arm-cal-paper/data/resimple/3cam_real_keypoints.csv --labels DOPE DREAM AE resnet
# Command-line interface of the script.
parser = argparse.ArgumentParser(description="OKS for DREAM")
parser.add_argument(
    "--data",
    nargs="+",
    # The default has to be a list: with nargs="+" the parsed value is a list of
    # paths, and the main loop iterates over it. A plain string default would be
    # iterated character by character.
    default=["all_dataset_keypoints.csv"],
    help="list of csv files",
)
parser.add_argument(
    "--labels",
    nargs="+",
    default=None,
    help="names for each dataset to be added as label",
)
parser.add_argument(
    "--styles",
    nargs="+",
    default=None,
    help="line style for each dataset: 0 solid, 1 dashed, 2 dotted (anything else: solid)",
)
parser.add_argument(
    "--colours",
    nargs="+",
    default=None,
    help="for each dataset, an integer index into the Matplotlib colour cycle",
)
# Parsed as an int so the value has the same type whether it comes from the
# default or from the command line (the script converts it with int()/float()).
parser.add_argument(
    "--pixel",
    type=int,
    default=20,
    help="largest PCK threshold, in pixels (also the x-axis limit)",
)
parser.add_argument("--output", default="output.pdf", help="path of the saved figure")
parser.add_argument(
    "--show",
    default=False,
    action="store_true",
    help="display the figure after saving it",
)
parser.add_argument("--title", default=None, help="title of the plot")

args = parser.parse_args()
print(args)
# Figure and axes shared by every curve drawn below.
fig = plt.figure()
ax = plt.axes()
handles = []  # not read by any live code in this script; kept for compatibility

# Evaluation constants (previously inlined as magic numbers).
_NUM_KEYPOINTS = 7  # the CSV files provide columns kp0* .. kp6*
_IMAGE_WIDTH, _IMAGE_HEIGHT = 640, 480  # bounds used to keep in-frame ground truth
_PCK_STEP = 0.01  # spacing between consecutive PCK thresholds, in pixels
_MAX_ERROR_PX = 1000  # errors at or above this are left out of mean/median/std
_LINE_STYLES = {"0": "-", "1": "--", "2": ":"}  # --styles code -> Matplotlib style

# NumPy 2.0 renamed np.trapz to np.trapezoid and deprecated the old name; use
# whichever one this NumPy provides.
_trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz

# TODO: consolidate the calculations below with pnp_metrics
# Imported here (once) rather than on every loop iteration.
from dream import analysis as dream_analysis

for i_csv, csv_file in enumerate(args.data):
    print(csv_file)

    if csv_file == "666":
        # "666" is not a file: it adds an empty, invisible line so that its label
        # shows up as a text-only row in the legend.
        plt.plot([], [], " ", label=args.labels[i_csv].replace("_", " "))
        continue

    name_csv = csv_file.replace(".csv", "")
    df = pd.read_csv(csv_file)

    # PCK: percentage of correct keypoints.
    all_gt = []
    all_pred = []
    dist_per_keypoint = []
    for i in range(_NUM_KEYPOINTS):
        gt = df[[f"kp{i}x_gt", f"kp{i}y_gt"]].values
        pred = df[[f"kp{i}x", f"kp{i}y"]].values
        all_gt.append(gt)
        all_pred.append(pred)

        # Keep only the rows whose ground truth lies strictly inside the image.
        # A boolean mask also copes with a keypoint that is never in frame
        # (empty selection), which a Python-list filter followed by
        # norm(axis=1) does not.
        in_frame = (
            (gt[:, 0] > 0)
            & (gt[:, 0] < _IMAGE_WIDTH)
            & (gt[:, 1] > 0)
            & (gt[:, 1] < _IMAGE_HEIGHT)
        )
        error = gt[in_frame].astype(float) - pred[in_frame].astype(float)
        dist_per_keypoint.append(np.linalg.norm(error, axis=1))

    # One flat array with the L2 pixel error of every in-frame keypoint.
    all_dist = np.concatenate(dist_per_keypoint)
    print("detected", len(all_dist))
    if all_dist.size == 0:
        # Without this the accuracy computation below divides by zero.
        raise ValueError(
            f"{csv_file}: no ground-truth keypoint lies inside the "
            f"{_IMAGE_WIDTH}x{_IMAGE_HEIGHT} image, cannot compute PCK"
        )

    # Fraction of keypoints whose error is below each threshold.
    pck_values = np.arange(0, int(args.pixel), _PCK_STEP)
    n_keypoints = len(all_dist)
    y_values = [np.count_nonzero(all_dist < value) / n_keypoints for value in pck_values]

    # Area under the PCK curve, normalised by the largest threshold.
    auc = _trapezoid(y_values, dx=_PCK_STEP) / float(args.pixel)
    print("auc", auc)

    # Summary statistics ignore very large errors (the AUC above does not).
    all_dist = all_dist[all_dist < _MAX_ERROR_PX]
    print("mean", np.mean(all_dist))
    print("median", np.median(all_dist))
    print("std", np.std(all_dist))

    # Cross-check the numbers above against dream.analysis.keypoint_metrics,
    # fed with all keypoints flattened to one (rows, coords) array.
    stacked_pred = np.array(all_pred)
    stacked_gt = np.array(all_gt)
    dim = stacked_pred.shape
    temp_pred = np.reshape(stacked_pred, (dim[0] * dim[1], dim[2]))
    temp_gt = np.reshape(stacked_gt, (dim[0] * dim[1], dim[2]))
    kp_metrics = dream_analysis.keypoint_metrics(
        temp_pred, temp_gt, (_IMAGE_WIDTH, _IMAGE_HEIGHT)
    )
    assert kp_metrics["l2_error_mean_px"] == np.mean(all_dist)
    assert kp_metrics["l2_error_median_px"] == np.median(all_dist)
    assert kp_metrics["l2_error_std_px"] == np.std(all_dist)
    assert np.abs(auc - kp_metrics["l2_error_auc"]) < 1.0e-15

    # Legend label: the user-supplied one, else the file name without ".csv".
    # TypeError covers --labels not being given (None); IndexError covers fewer
    # labels than data files.
    try:
        label = args.labels[i_csv]
    except (TypeError, IndexError):
        label = name_csv

    # Line colour: an index into the active colour cycle. Fall back to None (not
    # an empty string, which is not a valid Matplotlib colour) so the axes pick
    # the next colour of their own cycle.
    cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    try:
        colour = cycle[int(args.colours[i_csv])]
    except (TypeError, IndexError, ValueError):
        colour = None

    # Line style: "0"/"1"/"2" select solid/dashed/dotted; anything else, or a
    # missing entry, means solid.
    try:
        style = _LINE_STYLES.get(args.styles[i_csv], "-")
    except (TypeError, IndexError):
        style = "-"

    label = f"{label} ({auc:.3f})"
    ax.plot(pck_values, y_values, style, color=colour, label=label)
# Decorate the axes that now hold every PCK curve.
ax.set_xlabel("PCK threshold distance (pixels)")
ax.set_ylabel("Accuracy")
ax.set_title(args.title)

# Build the legend, then restyle the rows that belong to "666" placeholder
# entries: their text is left-aligned and shifted to the left.
legend = ax.legend(loc="lower right", frameon=True, fancybox=True, framealpha=0.8)
for source, text in zip(args.data, legend.get_texts()):
    if source != "666":
        continue
    text.set_ha("left")  # ha is alias for horizontalalignment
    text.set_position((-30, 0))

# Accuracy spans [0, 1]; the x-axis stops at the largest PCK threshold.
max_threshold = int(args.pixel)
ax.set_xlim(0, max_threshold)
ax.set_ylim(0, 1)

# Write the figure to disk and, when requested, open an interactive window.
plt.savefig(args.output)
if args.show:
    plt.show()