1+ import itertools
2+ import matplotlib .pyplot as plt
3+ import numpy as np
4+ from sklearn .metrics import confusion_matrix
5+
6+ def make_confusion_matrix (y_true , y_pred , classes = None , figsize = (10 , 10 ), text_size = 15 , norm = False , savefig = False ):
7+ # Create the confustion matrix
8+ cm = confusion_matrix (y_true , y_pred )
9+ cm_norm = cm .astype ("float" ) / cm .sum (axis = 1 )[:, np .newaxis ] # normalize it
10+ n_classes = cm .shape [0 ] # find the number of classes we're dealing with
11+
12+ # Plot the figure and make it pretty
13+ fig , ax = plt .subplots (figsize = figsize )
14+ cax = ax .matshow (cm , cmap = plt .cm .Blues ) # colors will represent how 'correct' a class is, darker == better
15+ fig .colorbar (cax )
16+
17+ # Are there a list of classes?
18+ if classes :
19+ labels = classes
20+ else :
21+ labels = np .arange (cm .shape [0 ])
22+
23+ # Label the axes
24+ ax .set (title = "Confusion Matrix" ,
25+ xlabel = "Predicted label" ,
26+ ylabel = "True label" ,
27+ xticks = np .arange (n_classes ), # create enough axis slots for each class
28+ yticks = np .arange (n_classes ),
29+ xticklabels = labels , # axes will labeled with class names (if they exist) or ints
30+ yticklabels = labels )
31+
32+ # Make x-axis labels appear on bottom
33+ ax .xaxis .set_label_position ("bottom" )
34+ ax .xaxis .tick_bottom ()
35+
36+ # Set the threshold for different colors
37+ threshold = (cm .max () + cm .min ()) / 2.
38+
39+ # Plot the text on each cell
40+ for i , j in itertools .product (range (cm .shape [0 ]), range (cm .shape [1 ])):
41+ if norm :
42+ plt .text (j , i , f"{ cm [i , j ]} ({ cm_norm [i , j ]* 100 :.1f} %)" ,
43+ horizontalalignment = "center" ,
44+ color = "white" if cm [i , j ] > threshold else "black" ,
45+ size = text_size )
46+ else :
47+ plt .text (j , i , f"{ cm [i , j ]} " ,
48+ horizontalalignment = "center" ,
49+ color = "white" if cm [i , j ] > threshold else "black" ,
50+ size = text_size )
51+
52+ # Save the figure to the current working directory
53+ if savefig :
54+ fig .savefig ("confusion_matrix.png" )
0 commit comments