Skip to content

Commit 37a5e90

Browse files
authored
Merge pull request metafy-social#360 from lucasocarvalhos/add-project
Add Confusion Matrix
2 parents a1dde3b + 2ab7ecf commit 37a5e90

File tree

2 files changed

+79
-0
lines changed

2 files changed

+79
-0
lines changed

scripts/Confusion_Matrix/README.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
The function makes a labelled confusion matrix comparing predictions and ground truth labels.
2+
3+
If classes is passed, confusion matrix will be labelled, if not, integer class values will be used.
4+
5+
Args:
6+
7+
* `y_true`: Array of truth labels (must be same shape as y_pred).
8+
* `y_pred`: Array of predicted labels (must be same shape as y_true).
9+
* `classes`: Array of class labels (e.g. string form). If `None`, integer labels are used.
10+
* `figsize`: Size of output figure (default=(10, 10)).
11+
* `text_size`: Size of output figure text (default=15).
12+
* `norm`: normalize values or not (default=False).
13+
* `savefig`: save confusion matrix to file (default=False).
14+
15+
Returns: A labelled confusion matrix plot comparing y_true and y_pred.
16+
17+
### Example usage:
18+
19+
> """make_confusion_matrix(y_true=test_labels, # ground truth test labels
20+
y_pred=y_preds, # predicted labels
21+
classes=class_names, # array of class label names
22+
figsize=(15, 15),
23+
text_size=10)"""
24+
25+
#### CODE BY ZeroToMastery TensorFlow course.
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import itertools
2+
import matplotlib.pyplot as plt
3+
import numpy as np
4+
from sklearn.metrics import confusion_matrix
5+
6+
def make_confusion_matrix(y_true, y_pred, classes=None, figsize=(10, 10), text_size=15, norm=False, savefig=False):
7+
# Create the confustion matrix
8+
cm = confusion_matrix(y_true, y_pred)
9+
cm_norm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] # normalize it
10+
n_classes = cm.shape[0] # find the number of classes we're dealing with
11+
12+
# Plot the figure and make it pretty
13+
fig, ax = plt.subplots(figsize=figsize)
14+
cax = ax.matshow(cm, cmap=plt.cm.Blues) # colors will represent how 'correct' a class is, darker == better
15+
fig.colorbar(cax)
16+
17+
# Are there a list of classes?
18+
if classes:
19+
labels = classes
20+
else:
21+
labels = np.arange(cm.shape[0])
22+
23+
# Label the axes
24+
ax.set(title="Confusion Matrix",
25+
xlabel="Predicted label",
26+
ylabel="True label",
27+
xticks=np.arange(n_classes), # create enough axis slots for each class
28+
yticks=np.arange(n_classes),
29+
xticklabels=labels, # axes will labeled with class names (if they exist) or ints
30+
yticklabels=labels)
31+
32+
# Make x-axis labels appear on bottom
33+
ax.xaxis.set_label_position("bottom")
34+
ax.xaxis.tick_bottom()
35+
36+
# Set the threshold for different colors
37+
threshold = (cm.max() + cm.min()) / 2.
38+
39+
# Plot the text on each cell
40+
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
41+
if norm:
42+
plt.text(j, i, f"{cm[i, j]} ({cm_norm[i, j]*100:.1f}%)",
43+
horizontalalignment="center",
44+
color="white" if cm[i, j] > threshold else "black",
45+
size=text_size)
46+
else:
47+
plt.text(j, i, f"{cm[i, j]}",
48+
horizontalalignment="center",
49+
color="white" if cm[i, j] > threshold else "black",
50+
size=text_size)
51+
52+
# Save the figure to the current working directory
53+
if savefig:
54+
fig.savefig("confusion_matrix.png")

0 commit comments

Comments
 (0)