Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 6238968

Browse filesBrowse files
ENH PrecisionRecallDisplay add option to plot chance level (#26019)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 51bd562 commit 6238968
Copy full SHA for 6238968

File tree

4 files changed

+221
-6
lines changed
Filter options

4 files changed

+221
-6
lines changed

‎doc/whats_new/v1.3.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.3.rst
+6Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,12 @@ Changelog
411411
level. This line is exposed in the `chance_level_` attribute.
412412
:pr:`25987` by :user:`Yao Xiao <Charlie-XIAO>`.
413413

414+
- |Enhancement| :meth:`metrics.PrecisionRecallDisplay.from_estimator` and
415+
:meth:`metrics.PrecisionRecallDisplay.from_predictions` now accept two new
416+
keywords, `plot_chance_level` and `chance_level_kw` to plot the baseline
417+
chance level. This line is exposed in the `chance_level_` attribute.
418+
:pr:`26019` by :user:`Yao Xiao <Charlie-XIAO>`.
419+
414420
- |Fix| :func:`log_loss` raises a warning if the values of the parameter `y_pred` are
415421
not normalized, instead of actually normalizing them in the metric. Starting from
416422
1.5 this will raise an error. :pr:`25299` by :user:`Omar Salman <OmarManzoor`.

‎examples/model_selection/plot_precision_recall.py

Copy file name to clipboardExpand all lines: examples/model_selection/plot_precision_recall.py
+8-3Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@
142142
from sklearn.metrics import PrecisionRecallDisplay
143143

144144
display = PrecisionRecallDisplay.from_estimator(
145-
classifier, X_test, y_test, name="LinearSVC"
145+
classifier, X_test, y_test, name="LinearSVC", plot_chance_level=True
146146
)
147147
_ = display.ax_.set_title("2-class Precision-Recall curve")
148148

@@ -152,7 +152,9 @@
152152
# :func:`~sklearn.metrics.PrecisionRecallDisplay.from_predictions`.
153153
y_score = classifier.decision_function(X_test)
154154

155-
display = PrecisionRecallDisplay.from_predictions(y_test, y_score, name="LinearSVC")
155+
display = PrecisionRecallDisplay.from_predictions(
156+
y_test, y_score, name="LinearSVC", plot_chance_level=True
157+
)
156158
_ = display.ax_.set_title("2-class Precision-Recall curve")
157159

158160
# %%
@@ -214,12 +216,15 @@
214216
# %%
215217
# Plot the micro-averaged Precision-Recall curve
216218
# ..............................................
219+
from collections import Counter
220+
217221
display = PrecisionRecallDisplay(
218222
recall=recall["micro"],
219223
precision=precision["micro"],
220224
average_precision=average_precision["micro"],
225+
prevalence_pos_label=Counter(Y_test.ravel())[1] / Y_test.size,
221226
)
222-
display.plot()
227+
display.plot(plot_chance_level=True)
223228
_ = display.ax_.set_title("Micro-averaged over all classes")
224229

225230
# %%

‎sklearn/metrics/_plot/precision_recall_curve.py

Copy file name to clipboardExpand all lines: sklearn/metrics/_plot/precision_recall_curve.py
+110-3Lines changed: 110 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from collections import Counter
2+
13
from .. import average_precision_score
24
from .. import precision_recall_curve
35
from ...utils._plotting import _BinaryClassifierCurveDisplayMixin
@@ -34,11 +36,23 @@ class PrecisionRecallDisplay(_BinaryClassifierCurveDisplayMixin):
3436
3537
.. versionadded:: 0.24
3638
39+
prevalence_pos_label : float, default=None
40+
The prevalence of the positive label. It is used for plotting the
41+
chance level line. If None, the chance level line will not be plotted
42+
even if `plot_chance_level` is set to True when plotting.
43+
44+
.. versionadded:: 1.3
45+
3746
Attributes
3847
----------
3948
line_ : matplotlib Artist
4049
Precision recall curve.
4150
51+
chance_level_ : matplotlib Artist or None
52+
The chance level line. It is `None` if the chance level is not plotted.
53+
54+
.. versionadded:: 1.3
55+
4256
ax_ : matplotlib Axes
4357
Axes with precision recall curve.
4458
@@ -96,14 +110,24 @@ def __init__(
96110
average_precision=None,
97111
estimator_name=None,
98112
pos_label=None,
113+
prevalence_pos_label=None,
99114
):
100115
self.estimator_name = estimator_name
101116
self.precision = precision
102117
self.recall = recall
103118
self.average_precision = average_precision
104119
self.pos_label = pos_label
120+
self.prevalence_pos_label = prevalence_pos_label
105121

106-
def plot(self, ax=None, *, name=None, **kwargs):
122+
def plot(
123+
self,
124+
ax=None,
125+
*,
126+
name=None,
127+
plot_chance_level=False,
128+
chance_level_kw=None,
129+
**kwargs,
130+
):
107131
"""Plot visualization.
108132
109133
Extra keyword arguments will be passed to matplotlib's `plot`.
@@ -118,6 +142,19 @@ def plot(self, ax=None, *, name=None, **kwargs):
118142
Name of precision recall curve for labeling. If `None`, use
119143
`estimator_name` if not `None`, otherwise no labeling is shown.
120144
145+
plot_chance_level : bool, default=False
146+
Whether to plot the chance level. The chance level is the prevalence
147+
of the positive label computed from the data passed during
148+
:meth:`from_estimator` or :meth:`from_predictions` call.
149+
150+
.. versionadded:: 1.3
151+
152+
chance_level_kw : dict, default=None
153+
Keyword arguments to be passed to matplotlib's `plot` for rendering
154+
the chance level line.
155+
156+
.. versionadded:: 1.3
157+
121158
**kwargs : dict
122159
Keyword arguments to be passed to matplotlib's `plot`.
123160
@@ -149,6 +186,7 @@ def plot(self, ax=None, *, name=None, **kwargs):
149186
line_kwargs.update(**kwargs)
150187

151188
(self.line_,) = self.ax_.plot(self.recall, self.precision, **line_kwargs)
189+
152190
info_pos_label = (
153191
f" (Positive label: {self.pos_label})" if self.pos_label is not None else ""
154192
)
@@ -157,7 +195,34 @@ def plot(self, ax=None, *, name=None, **kwargs):
157195
ylabel = "Precision" + info_pos_label
158196
self.ax_.set(xlabel=xlabel, ylabel=ylabel)
159197

160-
if "label" in line_kwargs:
198+
if plot_chance_level:
199+
if self.prevalence_pos_label is None:
200+
raise ValueError(
201+
"You must provide prevalence_pos_label when constructing the "
202+
"PrecisionRecallDisplay object in order to plot the chance "
203+
"level line. Alternatively, you may use "
204+
"PrecisionRecallDisplay.from_estimator or "
205+
"PrecisionRecallDisplay.from_predictions "
206+
"to automatically set prevalence_pos_label"
207+
)
208+
209+
chance_level_line_kw = {
210+
"label": f"Chance level (AP = {self.prevalence_pos_label:0.2f})",
211+
"color": "k",
212+
"linestyle": "--",
213+
}
214+
if chance_level_kw is not None:
215+
chance_level_line_kw.update(chance_level_kw)
216+
217+
(self.chance_level_,) = self.ax_.plot(
218+
(0, 1),
219+
(self.prevalence_pos_label, self.prevalence_pos_label),
220+
**chance_level_line_kw,
221+
)
222+
else:
223+
self.chance_level_ = None
224+
225+
if "label" in line_kwargs or plot_chance_level:
161226
self.ax_.legend(loc="lower left")
162227

163228
return self
@@ -175,6 +240,8 @@ def from_estimator(
175240
response_method="auto",
176241
name=None,
177242
ax=None,
243+
plot_chance_level=False,
244+
chance_level_kw=None,
178245
**kwargs,
179246
):
180247
"""Plot precision-recall curve given an estimator and some data.
@@ -219,6 +286,19 @@ def from_estimator(
219286
ax : matplotlib axes, default=None
220287
Axes object to plot on. If `None`, a new figure and axes is created.
221288
289+
plot_chance_level : bool, default=False
290+
Whether to plot the chance level. The chance level is the prevalence
291+
of the positive label computed from the data passed during
292+
:meth:`from_estimator` or :meth:`from_predictions` call.
293+
294+
.. versionadded:: 1.3
295+
296+
chance_level_kw : dict, default=None
297+
Keyword arguments to be passed to matplotlib's `plot` for rendering
298+
the chance level line.
299+
300+
.. versionadded:: 1.3
301+
222302
**kwargs : dict
223303
Keyword arguments to be passed to matplotlib's `plot`.
224304
@@ -277,6 +357,8 @@ def from_estimator(
277357
pos_label=pos_label,
278358
drop_intermediate=drop_intermediate,
279359
ax=ax,
360+
plot_chance_level=plot_chance_level,
361+
chance_level_kw=chance_level_kw,
280362
**kwargs,
281363
)
282364

@@ -291,6 +373,8 @@ def from_predictions(
291373
drop_intermediate=False,
292374
name=None,
293375
ax=None,
376+
plot_chance_level=False,
377+
chance_level_kw=None,
294378
**kwargs,
295379
):
296380
"""Plot precision-recall curve given binary class predictions.
@@ -324,6 +408,19 @@ def from_predictions(
324408
ax : matplotlib axes, default=None
325409
Axes object to plot on. If `None`, a new figure and axes is created.
326410
411+
plot_chance_level : bool, default=False
412+
Whether to plot the chance level. The chance level is the prevalence
413+
of the positive label computed from the data passed during
414+
:meth:`from_estimator` or :meth:`from_predictions` call.
415+
416+
.. versionadded:: 1.3
417+
418+
chance_level_kw : dict, default=None
419+
Keyword arguments to be passed to matplotlib's `plot` for rendering
420+
the chance level line.
421+
422+
.. versionadded:: 1.3
423+
327424
**kwargs : dict
328425
Keyword arguments to be passed to matplotlib's `plot`.
329426
@@ -381,12 +478,22 @@ def from_predictions(
381478
y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight
382479
)
383480

481+
class_count = Counter(y_true)
482+
prevalence_pos_label = class_count[pos_label] / sum(class_count.values())
483+
384484
viz = PrecisionRecallDisplay(
385485
precision=precision,
386486
recall=recall,
387487
average_precision=average_precision,
388488
estimator_name=name,
389489
pos_label=pos_label,
490+
prevalence_pos_label=prevalence_pos_label,
390491
)
391492

392-
return viz.plot(ax=ax, name=name, **kwargs)
493+
return viz.plot(
494+
ax=ax,
495+
name=name,
496+
plot_chance_level=plot_chance_level,
497+
chance_level_kw=chance_level_kw,
498+
**kwargs,
499+
)

‎sklearn/metrics/_plot/tests/test_precision_recall_display.py

Copy file name to clipboardExpand all lines: sklearn/metrics/_plot/tests/test_precision_recall_display.py
+97Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from collections import Counter
2+
13
import numpy as np
24
import pytest
35

@@ -76,6 +78,52 @@ def test_precision_recall_display_plotting(
7678
assert display.line_.get_label() == expected_label
7779
assert display.line_.get_alpha() == pytest.approx(0.8)
7880

81+
# Check that the chance level line is not plotted by default
82+
assert display.chance_level_ is None
83+
84+
85+
@pytest.mark.parametrize("chance_level_kw", [None, {"color": "r"}])
86+
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
87+
def test_precision_recall_chance_level_line(
88+
pyplot,
89+
chance_level_kw,
90+
constructor_name,
91+
):
92+
"""Check the chance level line plotting behavior."""
93+
X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
94+
pos_prevalence = Counter(y)[1] / len(y)
95+
96+
lr = LogisticRegression()
97+
y_pred = lr.fit(X, y).predict_proba(X)[:, 1]
98+
99+
if constructor_name == "from_estimator":
100+
display = PrecisionRecallDisplay.from_estimator(
101+
lr,
102+
X,
103+
y,
104+
plot_chance_level=True,
105+
chance_level_kw=chance_level_kw,
106+
)
107+
else:
108+
display = PrecisionRecallDisplay.from_predictions(
109+
y,
110+
y_pred,
111+
plot_chance_level=True,
112+
chance_level_kw=chance_level_kw,
113+
)
114+
115+
import matplotlib as mpl # noqa
116+
117+
assert isinstance(display.chance_level_, mpl.lines.Line2D)
118+
assert tuple(display.chance_level_.get_xdata()) == (0, 1)
119+
assert tuple(display.chance_level_.get_ydata()) == (pos_prevalence, pos_prevalence)
120+
121+
# Checking for chance level line styles
122+
if chance_level_kw is None:
123+
assert display.chance_level_.get_color() == "k"
124+
else:
125+
assert display.chance_level_.get_color() == "r"
126+
79127

80128
@pytest.mark.parametrize(
81129
"constructor_name, default_label",
@@ -256,3 +304,52 @@ def test_plot_precision_recall_pos_label(pyplot, constructor_name, response_meth
256304
avg_prec_limit = 0.95
257305
assert display.average_precision > avg_prec_limit
258306
assert -np.trapz(display.precision, display.recall) > avg_prec_limit
307+
308+
309+
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
310+
def test_precision_recall_prevalence_pos_label_reusable(pyplot, constructor_name):
311+
# Check that even if one passes plot_chance_level=False the first time
312+
# one can still call disp.plot with plot_chance_level=True and get the
313+
# chance level line
314+
X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
315+
316+
lr = LogisticRegression()
317+
y_pred = lr.fit(X, y).predict_proba(X)[:, 1]
318+
319+
if constructor_name == "from_estimator":
320+
display = PrecisionRecallDisplay.from_estimator(
321+
lr, X, y, plot_chance_level=False
322+
)
323+
else:
324+
display = PrecisionRecallDisplay.from_predictions(
325+
y, y_pred, plot_chance_level=False
326+
)
327+
assert display.chance_level_ is None
328+
329+
import matplotlib as mpl # noqa
330+
331+
# When calling from_estimator or from_predictions,
332+
# prevalence_pos_label should have been set, so that directly
333+
# calling plot_chance_level=True should plot the chance level line
334+
display.plot(plot_chance_level=True)
335+
assert isinstance(display.chance_level_, mpl.lines.Line2D)
336+
337+
338+
def test_precision_recall_raise_no_prevalence(pyplot):
339+
# Check that raises correctly when plotting chance level with
340+
# no prvelance_pos_label is provided
341+
precision = np.array([1, 0.5, 0])
342+
recall = np.array([0, 0.5, 1])
343+
display = PrecisionRecallDisplay(precision, recall)
344+
345+
msg = (
346+
"You must provide prevalence_pos_label when constructing the "
347+
"PrecisionRecallDisplay object in order to plot the chance "
348+
"level line. Alternatively, you may use "
349+
"PrecisionRecallDisplay.from_estimator or "
350+
"PrecisionRecallDisplay.from_predictions "
351+
"to automatically set prevalence_pos_label"
352+
)
353+
354+
with pytest.raises(ValueError, match=msg):
355+
display.plot(plot_chance_level=True)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.