-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathboxplots.py
67 lines (56 loc) · 2.24 KB
/
boxplots.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
def main() -> None:
    """Draw one train-accuracy box plot per toy dataset from the run log.

    Reads ``output/log.txt``, keeps the rows tagged ``stage:post`` and
    ``scenario:train`` (dropping the ``weight`` criterion), and for each of
    the datasets ``moon``, ``circle`` and ``mult`` saves a box plot of
    "Train Accuracy" against the number of reference samples, grouped by
    criterion, as ``toy_experiment_train-<dset>.pdf`` and ``.png``.

    NOTE(review): the indentation of the original script was lost, so two
    placements are inferred and should be confirmed: the legend is removed
    only for the second and third figures, and ``plt.show()`` runs once
    after all figures have been saved.
    """
    weight_as_hline = False
    columns = ["dset", "Criterion", "n", "s", "scenario", "stage",
               "Train Accuracy"]
    # The separator is a regular expression ("-" or " "), which only the
    # python parser engine supports; request it explicitly instead of
    # relying on pandas' warn-and-fall-back behaviour.
    df = pd.read_csv("output/log.txt", sep="-| ", names=columns,
                     engine="python")

    # Accuracy of the weight criterion per dataset; only drawn (as a dashed
    # horizontal line) when weight_as_hline is enabled.
    w_acc = {
        "moon": 99.6,
        "circle": 97.1,
        "mult": 91.0,
    }

    # Remove weight criterion (does not depend on n), is added as hline.
    # All row filters are combined into ONE mask built from the same frame,
    # so no boolean key ever has to be re-aligned to a filtered index.
    keep = (
        ~df["Criterion"].str.contains("criterion:weight", regex=False)
        & (df["stage"] == "stage:post")
        & (df["scenario"] == "scenario:train")
    )
    # Explicit copy: the column assignments below must not write to a view.
    df = df[keep].copy()

    # Clean up dataframe: strip the "key:" prefixes (literal replacement,
    # not regex) and abbreviate each criterion to its upper-cased initial.
    df["Criterion"] = (
        df["Criterion"]
        .str.replace("criterion:", "", regex=False)
        .str[0]
        .str.upper()
    )
    df["dset"] = df["dset"].str.replace("dataset:", "", regex=False)
    df["n"] = df["n"].str.replace("n:", "", regex=False).astype(int)

    # Map n in dataframe to number of reference samples: the logged n is
    # floor-divided by this per-dataset factor.
    n_mapping = {
        "moon": 2,
        "circle": 2,
        "mult": 4,
    }
    num_reference_samples = [1, 5, 20, 100]

    # Define colormap (keys are the one-letter criterion codes).
    color_map = {"G": "yellowgreen", "T": "cornflowerblue", "L": "indianred"}

    for i, dset in enumerate(["moon", "circle", "mult"]):
        # Select this dataset first, then rescale n on the selection only;
        # assign() returns a new frame, so the shared df is never mutated.
        df_dset = df[df["dset"] == dset]
        df_dset = df_dset.assign(n=df_dset["n"] // n_mapping[dset])
        df_dset = df_dset[df_dset["n"].isin(num_reference_samples)]

        fig, ax = plt.subplots(figsize=(3.5, 3.5))
        fig.subplots_adjust(left=0.19, right=0.99, top=0.93, bottom=0.13)
        ax.set_ylim(33, 100)
        if weight_as_hline:
            # Weight criterion as horizontal line
            ax.axhline(y=w_acc[dset], linestyle="--", color="k", label="W",
                       alpha=0.5)
        sns.boxplot(data=df_dset, y="Train Accuracy", x="n", hue="Criterion",
                    showmeans=False, ax=ax, hue_order=list("GTL"),
                    fliersize=0, palette=color_map)

        if i > 0:
            # Only the first figure keeps its y axis decoration and legend.
            ax.set_yticks([])
            ax.set_ylabel(None)
            legend = ax.get_legend()
            # No legend object exists if nothing was drawn with a hue.
            if legend is not None:
                legend.remove()

        # Only the middle figure carries the x axis label.
        ax.set_xlabel("samples used to compute criteria" if i == 1 else None)
        ax.set_title(dset)
        fig.savefig(f"toy_experiment_train-{dset}.pdf")
        fig.savefig(f"toy_experiment_train-{dset}.png")

    plt.show()


if __name__ == '__main__':
    main()