Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 59dd128

Browse filesBrowse files
ENH despine keyword for ROC and PR curves (#26367)
Co-authored-by: Guillaume Lemaitre <guillaume@probabl.ai>
1 parent 1177cad commit 59dd128
Copy full SHA for 59dd128

File tree

Expand file treeCollapse file tree

9 files changed

+142
-4
lines changed
Filter options
Expand file treeCollapse file tree

9 files changed

+142
-4
lines changed
+6Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
- :meth:`metrics.RocCurveDisplay.from_estimator`,
2+
:meth:`metrics.RocCurveDisplay.from_predictions`,
3+
:meth:`metrics.PrecisionRecallDisplay.from_estimator`, and
4+
:meth:`metrics.PrecisionRecallDisplay.from_predictions` now accept a new keyword
5+
`despine` to remove the top and right spines of the plot in order to make it clearer.
6+
By :user:`Yao Xiao <Charlie-XIAO>`.

‎examples/model_selection/plot_precision_recall.py

Copy file name to clipboardExpand all lines: examples/model_selection/plot_precision_recall.py
+6-4Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@
147147
from sklearn.metrics import PrecisionRecallDisplay
148148

149149
display = PrecisionRecallDisplay.from_estimator(
150-
classifier, X_test, y_test, name="LinearSVC", plot_chance_level=True
150+
classifier, X_test, y_test, name="LinearSVC", plot_chance_level=True, despine=True
151151
)
152152
_ = display.ax_.set_title("2-class Precision-Recall curve")
153153

@@ -158,7 +158,7 @@
158158
y_score = classifier.decision_function(X_test)
159159

160160
display = PrecisionRecallDisplay.from_predictions(
161-
y_test, y_score, name="LinearSVC", plot_chance_level=True
161+
y_test, y_score, name="LinearSVC", plot_chance_level=True, despine=True
162162
)
163163
_ = display.ax_.set_title("2-class Precision-Recall curve")
164164

@@ -228,7 +228,7 @@
228228
average_precision=average_precision["micro"],
229229
prevalence_pos_label=Counter(Y_test.ravel())[1] / Y_test.size,
230230
)
231-
display.plot(plot_chance_level=True)
231+
display.plot(plot_chance_level=True, despine=True)
232232
_ = display.ax_.set_title("Micro-averaged over all classes")
233233

234234
# %%
@@ -264,7 +264,9 @@
264264
precision=precision[i],
265265
average_precision=average_precision[i],
266266
)
267-
display.plot(ax=ax, name=f"Precision-recall for class {i}", color=color)
267+
display.plot(
268+
ax=ax, name=f"Precision-recall for class {i}", color=color, despine=True
269+
)
268270

269271
# add the legend for the iso-f1 curves
270272
handles, labels = display.ax_.get_legend_handles_labels()

‎examples/model_selection/plot_roc.py

Copy file name to clipboardExpand all lines: examples/model_selection/plot_roc.py
+4Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@
131131
name=f"{class_of_interest} vs the rest",
132132
color="darkorange",
133133
plot_chance_level=True,
134+
despine=True,
134135
)
135136
_ = display.ax_.set(
136137
xlabel="False Positive Rate",
@@ -166,6 +167,7 @@
166167
name="micro-average OvR",
167168
color="darkorange",
168169
plot_chance_level=True,
170+
despine=True,
169171
)
170172
_ = display.ax_.set(
171173
xlabel="False Positive Rate",
@@ -285,6 +287,7 @@
285287
color=color,
286288
ax=ax,
287289
plot_chance_level=(class_id == 2),
290+
despine=True,
288291
)
289292

290293
_ = ax.set(
@@ -366,6 +369,7 @@
366369
ax=ax,
367370
name=f"{label_b} as positive class",
368371
plot_chance_level=True,
372+
despine=True,
369373
)
370374
ax.set(
371375
xlabel="False Positive Rate",

‎sklearn/metrics/_plot/precision_recall_curve.py

Copy file name to clipboardExpand all lines: sklearn/metrics/_plot/precision_recall_curve.py
+24Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from ...utils._plotting import (
77
_BinaryClassifierCurveDisplayMixin,
8+
_despine,
89
_validate_style_kwargs,
910
)
1011
from .._ranking import average_precision_score, precision_recall_curve
@@ -131,6 +132,7 @@ def plot(
131132
name=None,
132133
plot_chance_level=False,
133134
chance_level_kw=None,
135+
despine=False,
134136
**kwargs,
135137
):
136138
"""Plot visualization.
@@ -160,6 +162,11 @@ def plot(
160162
161163
.. versionadded:: 1.3
162164
165+
despine : bool, default=False
166+
Whether to remove the top and right spines from the plot.
167+
168+
.. versionadded:: 1.6
169+
163170
**kwargs : dict
164171
Keyword arguments to be passed to matplotlib's `plot`.
165172
@@ -241,6 +248,9 @@ def plot(
241248
else:
242249
self.chance_level_ = None
243250

251+
if despine:
252+
_despine(self.ax_)
253+
244254
if "label" in line_kwargs or plot_chance_level:
245255
self.ax_.legend(loc="lower left")
246256

@@ -261,6 +271,7 @@ def from_estimator(
261271
ax=None,
262272
plot_chance_level=False,
263273
chance_level_kw=None,
274+
despine=False,
264275
**kwargs,
265276
):
266277
"""Plot precision-recall curve given an estimator and some data.
@@ -318,6 +329,11 @@ def from_estimator(
318329
319330
.. versionadded:: 1.3
320331
332+
despine : bool, default=False
333+
Whether to remove the top and right spines from the plot.
334+
335+
.. versionadded:: 1.6
336+
321337
**kwargs : dict
322338
Keyword arguments to be passed to matplotlib's `plot`.
323339
@@ -378,6 +394,7 @@ def from_estimator(
378394
ax=ax,
379395
plot_chance_level=plot_chance_level,
380396
chance_level_kw=chance_level_kw,
397+
despine=despine,
381398
**kwargs,
382399
)
383400

@@ -394,6 +411,7 @@ def from_predictions(
394411
ax=None,
395412
plot_chance_level=False,
396413
chance_level_kw=None,
414+
despine=False,
397415
**kwargs,
398416
):
399417
"""Plot precision-recall curve given binary class predictions.
@@ -440,6 +458,11 @@ def from_predictions(
440458
441459
.. versionadded:: 1.3
442460
461+
despine : bool, default=False
462+
Whether to remove the top and right spines from the plot.
463+
464+
.. versionadded:: 1.6
465+
443466
**kwargs : dict
444467
Keyword arguments to be passed to matplotlib's `plot`.
445468
@@ -514,5 +537,6 @@ def from_predictions(
514537
name=name,
515538
plot_chance_level=plot_chance_level,
516539
chance_level_kw=chance_level_kw,
540+
despine=despine,
517541
**kwargs,
518542
)

‎sklearn/metrics/_plot/roc_curve.py

Copy file name to clipboardExpand all lines: sklearn/metrics/_plot/roc_curve.py
+24Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from ...utils._plotting import (
55
_BinaryClassifierCurveDisplayMixin,
6+
_despine,
67
_validate_style_kwargs,
78
)
89
from .._ranking import auc, roc_curve
@@ -95,6 +96,7 @@ def plot(
9596
name=None,
9697
plot_chance_level=False,
9798
chance_level_kw=None,
99+
despine=False,
98100
**kwargs,
99101
):
100102
"""Plot visualization.
@@ -122,6 +124,11 @@ def plot(
122124
123125
.. versionadded:: 1.3
124126
127+
despine : bool, default=False
128+
Whether to remove the top and right spines from the plot.
129+
130+
.. versionadded:: 1.6
131+
125132
**kwargs : dict
126133
Keyword arguments to be passed to matplotlib's `plot`.
127134
@@ -175,6 +182,9 @@ def plot(
175182
else:
176183
self.chance_level_ = None
177184

185+
if despine:
186+
_despine(self.ax_)
187+
178188
if "label" in line_kwargs or "label" in chance_level_kw:
179189
self.ax_.legend(loc="lower right")
180190

@@ -195,6 +205,7 @@ def from_estimator(
195205
ax=None,
196206
plot_chance_level=False,
197207
chance_level_kw=None,
208+
despine=False,
198209
**kwargs,
199210
):
200211
"""Create a ROC Curve display from an estimator.
@@ -249,6 +260,11 @@ def from_estimator(
249260
250261
.. versionadded:: 1.3
251262
263+
despine : bool, default=False
264+
Whether to remove the top and right spines from the plot.
265+
266+
.. versionadded:: 1.6
267+
252268
**kwargs : dict
253269
Keyword arguments to be passed to matplotlib's `plot`.
254270
@@ -299,6 +315,7 @@ def from_estimator(
299315
pos_label=pos_label,
300316
plot_chance_level=plot_chance_level,
301317
chance_level_kw=chance_level_kw,
318+
despine=despine,
302319
**kwargs,
303320
)
304321

@@ -315,6 +332,7 @@ def from_predictions(
315332
ax=None,
316333
plot_chance_level=False,
317334
chance_level_kw=None,
335+
despine=False,
318336
**kwargs,
319337
):
320338
"""Plot ROC curve given the true and predicted values.
@@ -365,6 +383,11 @@ def from_predictions(
365383
366384
.. versionadded:: 1.3
367385
386+
despine : bool, default=False
387+
Whether to remove the top and right spines from the plot.
388+
389+
.. versionadded:: 1.6
390+
368391
**kwargs : dict
369392
Additional keywords arguments passed to matplotlib `plot` function.
370393
@@ -423,5 +446,6 @@ def from_predictions(
423446
name=name,
424447
plot_chance_level=plot_chance_level,
425448
chance_level_kw=chance_level_kw,
449+
despine=despine,
426450
**kwargs,
427451
)

‎sklearn/metrics/_plot/tests/test_precision_recall_display.py

Copy file name to clipboardExpand all lines: sklearn/metrics/_plot/tests/test_precision_recall_display.py
+27Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,3 +353,30 @@ def test_precision_recall_raise_no_prevalence(pyplot):
353353

354354
with pytest.raises(ValueError, match=msg):
355355
display.plot(plot_chance_level=True)
356+
357+
358+
@pytest.mark.parametrize("despine", [True, False])
359+
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
360+
def test_plot_precision_recall_despine(pyplot, despine, constructor_name):
361+
# Check that the despine keyword is working correctly
362+
X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
363+
364+
clf = LogisticRegression().fit(X, y)
365+
clf.fit(X, y)
366+
367+
y_pred = clf.decision_function(X)
368+
369+
# safe guard for the binary if/else construction
370+
assert constructor_name in ("from_estimator", "from_predictions")
371+
372+
if constructor_name == "from_estimator":
373+
display = PrecisionRecallDisplay.from_estimator(clf, X, y, despine=despine)
374+
else:
375+
display = PrecisionRecallDisplay.from_predictions(y, y_pred, despine=despine)
376+
377+
for s in ["top", "right"]:
378+
assert display.ax_.spines[s].get_visible() is not despine
379+
380+
if despine:
381+
for s in ["bottom", "left"]:
382+
assert display.ax_.spines[s].get_bounds() == (0, 1)

‎sklearn/metrics/_plot/tests/test_roc_curve_display.py

Copy file name to clipboardExpand all lines: sklearn/metrics/_plot/tests/test_roc_curve_display.py
+27Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,3 +334,30 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name):
334334

335335
assert display.roc_auc == pytest.approx(roc_auc_limit)
336336
assert trapezoid(display.tpr, display.fpr) == pytest.approx(roc_auc_limit)
337+
338+
339+
@pytest.mark.parametrize("despine", [True, False])
340+
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
341+
def test_plot_roc_curve_despine(pyplot, data_binary, despine, constructor_name):
342+
# Check that the despine keyword is working correctly
343+
X, y = data_binary
344+
345+
lr = LogisticRegression().fit(X, y)
346+
lr.fit(X, y)
347+
348+
y_pred = lr.decision_function(X)
349+
350+
# safe guard for the binary if/else construction
351+
assert constructor_name in ("from_estimator", "from_predictions")
352+
353+
if constructor_name == "from_estimator":
354+
display = RocCurveDisplay.from_estimator(lr, X, y, despine=despine)
355+
else:
356+
display = RocCurveDisplay.from_predictions(y, y_pred, despine=despine)
357+
358+
for s in ["top", "right"]:
359+
assert display.ax_.spines[s].get_visible() is not despine
360+
361+
if despine:
362+
for s in ["bottom", "left"]:
363+
assert display.ax_.spines[s].get_bounds() == (0, 1)

‎sklearn/utils/_plotting.py

Copy file name to clipboardExpand all lines: sklearn/utils/_plotting.py
+14Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,17 @@ def _validate_style_kwargs(default_style_kwargs, user_style_kwargs):
163163
valid_style_kwargs[key] = user_style_kwargs[key]
164164

165165
return valid_style_kwargs
166+
167+
168+
def _despine(ax):
169+
"""Remove the top and right spines of the plot.
170+
171+
Parameters
172+
----------
173+
ax : matplotlib.axes.Axes
174+
The axes of the plot to despine.
175+
"""
176+
for s in ["top", "right"]:
177+
ax.spines[s].set_visible(False)
178+
for s in ["bottom", "left"]:
179+
ax.spines[s].set_bounds(0, 1)

‎sklearn/utils/tests/test_plotting.py

Copy file name to clipboardExpand all lines: sklearn/utils/tests/test_plotting.py
+10Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33

44
from sklearn.utils._plotting import (
5+
_despine,
56
_interval_max_min_ratio,
67
_validate_score_name,
78
_validate_style_kwargs,
@@ -128,3 +129,12 @@ def test_validate_style_kwargs_error(default_kwargs, user_kwargs):
128129
"""Check that `validate_style_kwargs` raises TypeError"""
129130
with pytest.raises(TypeError):
130131
_validate_style_kwargs(default_kwargs, user_kwargs)
132+
133+
134+
def test_despine(pyplot):
135+
ax = pyplot.gca()
136+
_despine(ax)
137+
assert ax.spines["top"].get_visible() is False
138+
assert ax.spines["right"].get_visible() is False
139+
assert ax.spines["bottom"].get_bounds() == (0, 1)
140+
assert ax.spines["left"].get_bounds() == (0, 1)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.