-
Notifications
You must be signed in to change notification settings - Fork 267
/
utils.py
97 lines (83 loc) · 3.51 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import numpy as np
# Facies class abbreviations (9 classes).
facies_labels = [
    'SS', 'CSiS', 'FSiS', 'SiSh', 'MS',
    'WS', 'D', 'PS', 'BS',
]

# Hex plot colours, one per facies class (presumably matched to
# facies_labels by position -- confirm against the plotting code).
facies_colors = [
    '#F4D03F', '#F5B041', '#DC7633', '#A569BD', '#000000',
    '#000080', '#2E86C1', '#AED6F1', '#196F3D',
]
def display_cm(cm, labels, hide_zeros=False, display_metrics=False):
    """Print a confusion matrix as a labelled text table.

    Optionally appends per-class Precision, Recall and F1 rows.
    Based on Zach Guo's print_cm gist at
    https://gist.github.com/zachguo/10296432

    Parameters
    ----------
    cm : array_like, shape (n_classes, n_classes)
        Confusion matrix. Rows are true classes and columns are predicted
        classes: recall is computed along rows, precision along columns.
    labels : sequence of str
        Class names, one per row/column of ``cm``.
    hide_zeros : bool, optional
        If True, cells equal to zero are printed as blanks.
    display_metrics : bool, optional
        If True, also print Precision, Recall and F1 per class. The
        ``Total`` column then holds their averages weighted by each
        class's true count (row total).

    Returns
    -------
    None
        The table is written to stdout with ``print``.
    """
    cm = np.asarray(cm)
    diag = np.diagonal(cm)
    true_totals = cm.sum(axis=1)   # samples per true class (row sums)
    pred_totals = cm.sum(axis=0)   # samples per predicted class (column sums)
    n_samples = cm.sum()

    # A class that is never predicted (or never present) gives 0/0 -> NaN.
    # Those NaNs are deliberately mapped to 0 below, so silence numpy's
    # RuntimeWarning instead of letting it leak to the caller.
    with np.errstate(divide='ignore', invalid='ignore'):
        precision = diag / pred_totals.astype(float)
        recall = diag / true_totals.astype(float)
        F1 = 2 * (precision * recall) / (precision + recall)
    precision[np.isnan(precision)] = 0
    recall[np.isnan(recall)] = 0
    F1[np.isnan(F1)] = 0

    # Support-weighted averages (weights = true count per class).
    with np.errstate(divide='ignore', invalid='ignore'):
        total_precision = np.sum(precision * true_totals) / n_samples
        total_recall = np.sum(recall * true_totals) / n_samples
        total_F1 = np.sum(F1 * true_totals) / n_samples

    columnwidth = max([len(x) for x in labels] + [5])  # 5 is value length
    # The left-hand stub column holds row labels AND metric names, so it
    # must be wide enough for both; otherwise the metric rows drift out of
    # line with the data columns.
    stubwidth = max(columnwidth, len("Precision"))
    empty_cell = " " * columnwidth
    stub_fmt = "%{0}s".format(stubwidth)
    str_fmt = "%{0}s".format(columnwidth)
    int_fmt = "%{0}d".format(columnwidth)
    float_fmt = "%{0}.2f".format(columnwidth)

    # Print header
    print(stub_fmt % "Pred", end=' ')
    for label in labels:
        print(str_fmt % label, end=' ')
    print(str_fmt % 'Total')
    print(stub_fmt % "True")

    # Print rows
    for i, row_label in enumerate(labels):
        print(stub_fmt % row_label, end=' ')
        for j in range(len(labels)):
            if hide_zeros and float(cm[i, j]) == 0:
                cell = empty_cell
            else:
                cell = int_fmt % cm[i, j]
            print(cell, end=' ')
        print(int_fmt % true_totals[i])

    # Print per-class metrics, with the weighted average under "Total".
    if display_metrics:
        print()
        for name, values, total in (("Precision", precision, total_precision),
                                    ("Recall", recall, total_recall),
                                    ("F1", F1, total_F1)):
            print(stub_fmt % name, end=' ')
            for j in range(len(labels)):
                print(float_fmt % values[j], end=' ')
            print(float_fmt % total)
def display_adj_cm(
        cm, labels, adjacent_facies, hide_zeros=False, display_metrics=False):
    """Print a confusion matrix that counts adjacent facies as correct.

    For every row ``row`` of ``cm``, each count in a column listed in
    ``adjacent_facies[row]`` is added onto the diagonal cell
    ``[row, row]`` and then zeroed. The columns are processed one at a
    time in list order. The adjusted copy is printed with ``display_cm``;
    the caller's ``cm`` is not modified.

    Parameters
    ----------
    cm : numpy.ndarray
        Square confusion matrix (rows are the true classes in
        ``display_cm``'s layout).
    labels : sequence of str
        Class names, forwarded to ``display_cm``.
    adjacent_facies : sequence of sequences of int
        ``adjacent_facies[row]`` lists the column indices that count as
        correct for row ``row``.
    hide_zeros, display_metrics : bool, optional
        Forwarded unchanged to ``display_cm``.
    """
    collapsed = np.copy(cm)
    for row in range(cm.shape[0]):
        for col in adjacent_facies[row]:
            # Move the "adjacent" predictions onto the diagonal.
            collapsed[row, row] += collapsed[row, col]
            collapsed[row, col] = 0.0
    display_cm(collapsed, labels, hide_zeros, display_metrics)
def accuracy(conf):
    """Return the fraction of samples on the diagonal of a confusion matrix.

    Parameters
    ----------
    conf : array_like, 2-D
        Confusion matrix of counts (numpy array or nested sequence).

    Returns
    -------
    float
        ``trace(conf) / sum(conf)``. Yields NaN (with a numpy
        RuntimeWarning) when ``conf`` sums to zero.
    """
    conf = np.asarray(conf)
    # Vectorised equivalent of summing conf[i][i] in a Python loop and
    # dividing by sum(sum(conf)).
    return np.trace(conf) / conf.sum()
def accuracy_adjacent(conf, adjacent_facies=None):
    """Return accuracy where predicting an adjacent facies also counts as correct.

    Parameters
    ----------
    conf : array_like, 2-D
        Confusion matrix of counts (numpy array or nested sequence).
    adjacent_facies : sequence of sequences of int, optional
        ``adjacent_facies[i]`` lists the column indices whose counts in
        row ``i`` are added to the diagonal count as correct. Defaults to
        the 9-class adjacency this module has always used.

    Returns
    -------
    float
        (diagonal counts + adjacent counts) / total count.
    """
    if adjacent_facies is None:
        # Keep this as plain nested lists: the rows have different lengths,
        # and np.array() on such ragged input raises ValueError on
        # NumPy >= 1.24 (it was only a deprecation warning before).
        adjacent_facies = [[1], [0, 2], [1], [4], [3, 5], [4, 6, 7],
                           [5, 7], [5, 6, 8], [6, 7]]
    conf = np.asarray(conf)
    total_correct = 0.
    for i in range(conf.shape[0]):
        total_correct += conf[i, i]
        for j in adjacent_facies[i]:
            total_correct += conf[i, j]
    return total_correct / conf.sum()