-
Notifications
You must be signed in to change notification settings - Fork 29
/
play_styles.py
83 lines (67 loc) · 2.34 KB
/
play_styles.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
#import seaborn as sns
import umap
import pickle, sys, math
from collections import defaultdict
#sns.set(style='white', context='poster', rc={'figure.figsize':(14,10)})
# Load the pickled style data whose path is the first command-line argument.
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file;
# only point this script at pickles you produced yourself.
with open(sys.argv[1], 'rb') as pickle_file:
    data = pickle.load(pickle_file)

styles = data['styles']
if len(styles.shape) == 4:
    # Keep the first two axes and take index 0 of the last two.
    # NOTE(review): presumably the trailing axes are singleton spatial dims
    # (N, C, 1, 1) — confirm against the producer of the pickle.
    styles = styles[:, :, 0, 0]
authors = data['authors']

# Accumulators filled by the pairwise-distance pass.
inter_author_dist = []
intra_author_dist = []
author_d = defaultdict(list)  # (i, j) index pair -> list of distances
def dist(vA, vB):
    """Return the Euclidean (L2) distance between two style arrays.

    The squared differences are summed over every element, so inputs of any
    (matching or broadcastable) shape are accepted. Returns a Python float.
    """
    diff = vA - vB
    return math.sqrt((diff * diff).sum())
# Compute the Euclidean style distance for every unordered pair of samples
# (i < j) and split it by whether the two samples share an author:
#   intra-author = same author, inter-author = different authors.
# NOTE(review): assumes row k of `styles` belongs to authors[k] — confirm
# against whatever produced the pickle.
n_samples = len(authors)
for i in range(n_samples):
    for j in range(i + 1, n_samples):
        d = dist(styles[i], styles[j])
        if authors[j] == authors[i]:
            # Same writer -> within-author ("intra") distance.
            intra_author_dist.append(d)
        else:
            # Different writers -> between-author ("inter") distance.
            inter_author_dist.append(d)
        # Keyed by the sample-index pair. NOTE(review): the original
        # indentation was lost; this records every pair (not only
        # different-author ones) — confirm intended placement.
        author_d[(i, j)].append(d)

print('inter dist mean: {}, stddev: {}'.format(np.mean(inter_author_dist), np.std(inter_author_dist)))
print('intra dist mean: {}, stddev: {}'.format(np.mean(intra_author_dist), np.std(intra_author_dist)))

# Stop after the summary statistics. sys.exit is used because the bare
# exit() is injected by the `site` module for interactive use and is not
# available when Python runs with -S.
sys.exit()
# Fold each pair's list of distances into symmetric mean / std-dev matrices.
# The (a1, a2) keys of author_d are used directly as row/column indices.
# NOTE(review): unreachable while the script exits after printing the
# summary statistics; kept for ad-hoc inspection.
n_authors = len(authors)
author_mean = np.zeros((n_authors, n_authors))
author_std = np.zeros((n_authors, n_authors))
for (a1, a2), pair_dists in author_d.items():
    mean_d, std_d = np.mean(pair_dists), np.std(pair_dists)
    author_mean[a1, a2] = author_mean[a2, a1] = mean_d
    author_std[a1, a2] = author_std[a2, a1] = std_d
# Draw one heatmap (with colorbar) per statistic matrix, then show them all.
# NOTE(review): unreachable while the script exits after printing the
# summary statistics; kept for ad-hoc inspection.
cmap = plt.cm.Blues
for matrix, title in ((author_mean, 'mean'), (author_std, 'std')):
    fig, ax = plt.subplots()
    im = ax.imshow(matrix, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # One tick per row/column index; tick labels are left numeric.
    ax.set(
        xticks=np.arange(matrix.shape[1]),
        yticks=np.arange(matrix.shape[0]),
        title=title,
    )
plt.show()