---
jupyter:
  jupytext:
    notebook_metadata_filter: all
    text_representation:
      extension: .md
      format_name: markdown
      format_version: '1.1'
    jupytext_version: 1.1.1
  kernelspec:
    display_name: Python 3
    language: python
    name: python3
  language_info:
    codemirror_mode:
      name: ipython
      version: 3
    file_extension: .py
    mimetype: text/x-python
    name: python
    nbconvert_exporter: python
    pygments_lexer: ipython3
    version: 3.7.6
  plotly:
    description: Interpret the results of your classification using Receiver Operating
      Characteristic (ROC) and Precision-Recall (PR) curves in Python with Plotly.
    display_as: ai_ml
    language: python
    layout: base
    name: ROC and PR Curves
    order: 3
    page_type: example_index
    permalink: python/roc-and-pr-curves/
    thumbnail: thumbnail/ml-roc-pr.png
---

## Basic Binary ROC Curve

```python
import plotly.express as px
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve, auc
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=500, random_state=0)

# Fit the model and keep the predicted probability of the positive class
model = LogisticRegression()
model.fit(X, y)
y_score = model.predict_proba(X)[:, 1]

fpr, tpr, thresholds = roc_curve(y, y_score)

fig = px.area(
    x=fpr, y=tpr,
    title=f'ROC Curve (AUC={auc(fpr, tpr):.4f})',
    labels=dict(x='False Positive Rate', y='True Positive Rate')
)
# Dashed diagonal: the expected curve of a random classifier
fig.add_shape(
    type='line', line=dict(dash='dash'),
    x0=0, x1=1, y0=0, y1=1
)
fig.show()
```
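
`roc_curve` also returns a `thresholds` array, which the plot above does not use. A minimal sketch for inspecting how the two rates trade off as the decision threshold moves, reusing the variables from the block above (the `df_thresh` name is just illustrative):

```python
import pandas as pd

# One row per threshold: the operating point the classifier would
# reach if you cut the predicted probability at that value
df_thresh = pd.DataFrame({
    'False Positive Rate': fpr,
    'True Positive Rate': tpr
}, index=thresholds)
df_thresh.index.name = 'Threshold'
print(df_thresh.head())
```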

## Multiclass ROC Curve

When you have more than two classes, you will need to plot the ROC curve for each class separately. Make sure that you use a [one-versus-rest](https://scikit-learn.org/stable/modules/multiclass.html#one-vs-the-rest) model, or that your problem has a [multi-label](https://scikit-learn.org/stable/modules/multiclass.html#multilabel-classification-format) format; otherwise, your ROC curve might not return the expected results. A sketch of the one-versus-rest setup follows.
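
If your estimator is not inherently one-versus-rest, scikit-learn's `OneVsRestClassifier` wrapper can make it so. A minimal sketch, assuming a hypothetical synthetic dataset only so the snippet runs on its own:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# Hypothetical 3-class dataset, a stand-in for your own data
X, y = make_classification(
    n_samples=500, n_classes=3, n_informative=4, random_state=0
)

# Fit one binary classifier per class (that class versus all the others)
ovr = OneVsRestClassifier(LogisticRegression(max_iter=200))
ovr.fit(X, y)

# One column of scores per class, ready for per-class roc_curve calls
y_scores = ovr.predict_proba(X)
```

The example below instead takes the per-class columns of `predict_proba` from a single multinomial model, which amounts to the same one-versus-rest treatment of the scores.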

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve, roc_auc_score
import plotly.graph_objects as go
import plotly.express as px

np.random.seed(0)

# Artificially add noise to make the task harder
df = px.data.iris()
samples = df.species.sample(n=50, random_state=0)
np.random.shuffle(samples.values)
df.loc[samples.index, 'species'] = samples.values

# Define the inputs and outputs
X = df.drop(columns=['species', 'species_id'])
y = df['species']

# Fit the model
model = LogisticRegression(max_iter=200)
model.fit(X, y)
y_scores = model.predict_proba(X)

# One-hot encode the labels so each class gets its own binary column,
# in the same (sorted) order as the columns of y_scores
y_onehot = pd.get_dummies(y)

# Create an empty figure, and iteratively add new lines
# every time we compute a new class
fig = go.Figure()
fig.add_shape(
    type='line', line=dict(dash='dash'),
    x0=0, x1=1, y0=0, y1=1
)

for i in range(y_scores.shape[1]):
    y_true = y_onehot.iloc[:, i]
    y_score = y_scores[:, i]

    fpr, tpr, _ = roc_curve(y_true, y_score)
    auc_score = roc_auc_score(y_true, y_score)

    name = f"{y_onehot.columns[i]} (AUC={auc_score:.2f})"
    fig.add_trace(go.Scatter(x=fpr, y=tpr, name=name, mode='lines'))

fig.update_layout(
    xaxis_title='False Positive Rate',
    yaxis_title='True Positive Rate'
)
fig.show()
```

## Precision-Recall Curves

Plotting the PR curve is very similar to plotting the ROC curve. The following examples are slightly modified versions of the ROC examples above:

```python
import plotly.express as px
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve, auc
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=500, random_state=0)

model = LogisticRegression()
model.fit(X, y)
y_score = model.predict_proba(X)[:, 1]

precision, recall, thresholds = precision_recall_curve(y, y_score)

fig = px.area(
    x=recall, y=precision,
    title=f'Precision-Recall Curve (AUC={auc(recall, precision):.4f})',
    labels=dict(x='Recall', y='Precision')
)
fig.add_shape(
    type='line', line=dict(dash='dash'),
    x0=0, x1=1, y0=1, y1=0
)
fig.show()
```
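
One interpretive caveat: unlike ROC, the chance baseline of a PR curve is not a diagonal but a horizontal line at the positive-class prevalence, so the dashed anti-diagonal above is only a visual reference. A minimal sketch that adds a true chance baseline, reusing `y` and `fig` from the block above:

```python
import numpy as np

# A random classifier's precision equals the fraction of positive samples,
# so draw the chance baseline as a horizontal line at that prevalence
prevalence = float(np.mean(y == 1))
fig.add_shape(
    type='line', line=dict(dash='dot'),
    x0=0, x1=1, y0=prevalence, y1=prevalence
)
fig.show()
```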

The next example uses the [average precision](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html) metric, which is an alternative scoring method to the area under the PR curve. The two are related but generally not equal, as the quick check right below shows.
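
A minimal check, reusing `y` and `y_score` from the binary PR example above (the exact values depend on the data and model):

```python
from sklearn.metrics import auc, average_precision_score, precision_recall_curve

precision, recall, _ = precision_recall_curve(y, y_score)

# Trapezoidal area under the PR curve vs. average precision: AP weights the
# precision at each threshold by the recall gained there, so the two numbers
# are usually close but not identical
print(f"AUC (trapezoidal): {auc(recall, precision):.4f}")
print(f"Average precision: {average_precision_score(y, y_score):.4f}")
```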

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve, average_precision_score
import plotly.graph_objects as go
import plotly.express as px

np.random.seed(0)

# Artificially add noise to make the task harder
df = px.data.iris()
samples = df.species.sample(n=30, random_state=0)
np.random.shuffle(samples.values)
df.loc[samples.index, 'species'] = samples.values

# Define the inputs and outputs
X = df.drop(columns=['species', 'species_id'])
y = df['species']

# Fit the model
model = LogisticRegression(max_iter=200)
model.fit(X, y)
y_scores = model.predict_proba(X)

# One-hot encode the labels so each class gets its own binary column,
# in the same (sorted) order as the columns of y_scores
y_onehot = pd.get_dummies(y)

# Create an empty figure, and iteratively add new lines
# every time we compute a new class
fig = go.Figure()
fig.add_shape(
    type='line', line=dict(dash='dash'),
    x0=0, x1=1, y0=1, y1=0
)

for i in range(y_scores.shape[1]):
    y_true = y_onehot.iloc[:, i]
    y_score = y_scores[:, i]

    precision, recall, _ = precision_recall_curve(y_true, y_score)
    ap_score = average_precision_score(y_true, y_score)

    name = f"{y_onehot.columns[i]} (AP={ap_score:.2f})"
    fig.add_trace(go.Scatter(x=recall, y=precision, name=name, mode='lines'))

fig.update_layout(
    xaxis_title='Recall',
    yaxis_title='Precision'
)
fig.show()
```