Skip to content

Commit 77a3d82

Browse files
author
xhlulu
committed
ML Docs: Start ROC/PR section
1 parent 3857aea commit 77a3d82

File tree

1 file changed

+201
-0
lines changed

1 file changed

+201
-0
lines changed

doc/python/ml-roc-pr.md

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
---
2+
jupyter:
3+
jupytext:
4+
notebook_metadata_filter: all
5+
text_representation:
6+
extension: .md
7+
format_name: markdown
8+
format_version: '1.1'
9+
jupytext_version: 1.1.1
10+
kernelspec:
11+
display_name: Python 3
12+
language: python
13+
name: python3
14+
language_info:
15+
codemirror_mode:
16+
name: ipython
17+
version: 3
18+
file_extension: .py
19+
mimetype: text/x-python
20+
name: python
21+
nbconvert_exporter: python
22+
pygments_lexer: ipython3
23+
version: 3.7.6
24+
plotly:
25+
description: Interpret the results of your classification using Receiver Operating
26+
Characteristic (ROC) and Precision-Recall (PR) curves with Plotly in Python.
27+
display_as: ai_ml
28+
language: python
29+
layout: base
30+
name: ROC and PR Curves
31+
order: 3
32+
page_type: example_index
33+
permalink: python/roc-and-pr-curves/
34+
thumbnail: thumbnail/ml-roc-pr.png
35+
---
36+
37+
## Basic Binary ROC Curve
38+
39+
```python
40+
# Binary ROC curve: fit a logistic regression on a synthetic two-class
# dataset, then plot the true positive rate against the false positive rate.
import plotly.express as px
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_curve

X, y = make_classification(n_samples=500, random_state=0)

# Scores are column 1 of predict_proba, evaluated on the same data the
# model was fitted on.
model = LogisticRegression()
model.fit(X, y)
y_score = model.predict_proba(X)[:, 1]

fpr, tpr, thresholds = roc_curve(y, y_score)
roc_auc = auc(fpr, tpr)

fig = px.area(
    x=fpr,
    y=tpr,
    title=f'ROC Curve (AUC={roc_auc:.4f})',
    labels={'x': 'False Positive Rate', 'y': 'True Positive Rate'},
)
# Dashed reference diagonal from (0, 0) to (1, 1).
fig.add_shape(
    type='line',
    x0=0, y0=0,
    x1=1, y1=1,
    line={'dash': 'dash'},
)
fig.show()
63+
```
64+
65+
## Multiclass ROC Curve
66+
67+
When you have more than 2 classes, you will need to plot the ROC curve for each class separately. Make sure that you use a [one-versus-rest](https://scikit-learn.org/stable/modules/multiclass.html#one-vs-the-rest) model, or that your problem has a [multi-label](https://scikit-learn.org/stable/modules/multiclass.html#multilabel-classification-format) format; otherwise, your ROC curves might not show the expected results.
68+
69+
```python
70+
# Multiclass ROC curves: fit one model on the iris data and draw a
# one-vs-rest ROC curve for every class it predicts.
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve, roc_auc_score
import plotly.graph_objects as go
import plotly.express as px

np.random.seed(0)

# Artificially add noise to make task harder: sample 50 rows and shuffle
# their species labels among themselves.
df = px.data.iris()
samples = df.species.sample(n=50, random_state=0)
# permutation() shuffles a copy, so the Series' backing array is never
# mutated in place; for the same seed it yields the same order as shuffle().
df.loc[samples.index, 'species'] = np.random.permutation(samples.values)

# Define the inputs and outputs
X = df.drop(columns=['species', 'species_id'])
y = df['species']

# Fit the model
model = LogisticRegression(max_iter=200)
model.fit(X, y)
y_scores = model.predict_proba(X)

# One-hot encode the labels only after fitting, so that model.classes_ refers
# to the model trained above, and order the columns like predict_proba's.
y_onehot = pd.get_dummies(y)[model.classes_]

# Create an empty figure, and iteratively add new lines
# every time we compute a new class
fig = go.Figure()
fig.add_shape(
    type='line', line=dict(dash='dash'),
    x0=0, x1=1, y0=0, y1=1
)

for i, label in enumerate(model.classes_):
    # Column i of y_scores holds the scores for model.classes_[i].
    y_true = y_onehot[label]
    y_score = y_scores[:, i]

    fpr, tpr, _ = roc_curve(y_true, y_score)
    auc_score = roc_auc_score(y_true, y_score)

    name = f"{label} (AUC={auc_score:.2f})"
    fig.add_trace(go.Scatter(x=fpr, y=tpr, name=name, mode='lines'))

fig.update_layout(
    xaxis_title='False Positive Rate',
    yaxis_title='True Positive Rate'
)
fig.show()
118+
```
119+
120+
## Precision-Recall Curves
121+
122+
Plotting the PR curve is very similar to plotting the ROC curve. The following examples are slightly modified from the previous examples:
123+
124+
```python
125+
# Binary precision-recall curve: fit a logistic regression on a synthetic
# two-class dataset, then plot precision against recall.
import plotly.express as px
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve, auc
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=500, random_state=0)

model = LogisticRegression()
model.fit(X, y)
# Scores are column 1 of predict_proba.
y_score = model.predict_proba(X)[:, 1]

precision, recall, thresholds = precision_recall_curve(y, y_score)

# The area reported in the title must come from the plotted PR points
# (x=recall, y=precision), not from ROC variables of another snippet.
pr_auc = auc(recall, precision)

fig = px.area(
    x=recall, y=precision,
    title=f'Precision-Recall Curve (AUC={pr_auc:.4f})',
    labels=dict(x='Recall', y='Precision')
)
# Dashed reference diagonal from (0, 1) to (1, 0).
fig.add_shape(
    type='line', line=dict(dash='dash'),
    x0=0, x1=1, y0=1, y1=0
)
fig.show()
148+
```
149+
150+
In the multiclass example below, we score each class with the [average precision](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html) metric, which is an alternative scoring method to the area under the PR curve.
151+
152+
```python
153+
# Multiclass precision-recall curves: fit one model on the iris data and draw
# a one-vs-rest PR curve, labelled with its average precision, per class.
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve, average_precision_score
import plotly.graph_objects as go
import plotly.express as px

np.random.seed(0)

# Artificially add noise to make task harder: sample 30 rows and shuffle
# their species labels among themselves.
df = px.data.iris()
samples = df.species.sample(n=30, random_state=0)
# permutation() shuffles a copy, so the Series' backing array is never
# mutated in place; for the same seed it yields the same order as shuffle().
df.loc[samples.index, 'species'] = np.random.permutation(samples.values)

# Define the inputs and outputs
X = df.drop(columns=['species', 'species_id'])
y = df['species']

# Fit the model
model = LogisticRegression(max_iter=200)
model.fit(X, y)
y_scores = model.predict_proba(X)

# One-hot encode the labels only after fitting, so that model.classes_ refers
# to the model trained above, and order the columns like predict_proba's.
y_onehot = pd.get_dummies(y)[model.classes_]

# Create an empty figure, and iteratively add new lines
# every time we compute a new class
fig = go.Figure()
fig.add_shape(
    type='line', line=dict(dash='dash'),
    x0=0, x1=1, y0=1, y1=0
)

for i, label in enumerate(model.classes_):
    # Column i of y_scores holds the scores for model.classes_[i].
    y_true = y_onehot[label]
    y_score = y_scores[:, i]

    precision, recall, _ = precision_recall_curve(y_true, y_score)
    auc_score = average_precision_score(y_true, y_score)

    name = f"{label} (AP={auc_score:.2f})"
    fig.add_trace(go.Scatter(x=recall, y=precision, name=name, mode='lines'))

fig.update_layout(
    xaxis_title='Recall',
    yaxis_title='Precision'
)
fig.show()
201+
```

0 commit comments

Comments
 (0)