Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

ENH Improve ROC curves visualization and add option to plot chance level #25972

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions 5 doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,11 @@ Changelog
curves.
:pr:`24668` by :user:`dberenbaum`.

- |Enhancement| :class:`RocCurveDisplay` now plots the ROC curve with both axes
limited to [0, 1] and a loosely dotted frame. There is also an additional
parameter `plot_chance_level` to determine whether to plot the chance level.
:pr:`25972` by :user:`Yao Xiao <Charlie-XIAO>`.

- |Fix| :func:`log_loss` raises a warning if the values of the parameter `y_pred` are
not normalized, instead of actually normalizing them in the metric. Starting from
1.5 this will raise an error. :pr:`25299` by :user:`Omar Salman <OmarManzoor>`.
Expand Down
37 changes: 35 additions & 2 deletions 37 sklearn/metrics/_plot/roc_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ class RocCurveDisplay:
line_ : matplotlib Artist
ROC Curve.

chance_level_ : matplotlib Artist
Charlie-XIAO marked this conversation as resolved.
Show resolved Hide resolved
The chance level line or None if the chance level is not plotted.
Charlie-XIAO marked this conversation as resolved.
Show resolved Hide resolved

ax_ : matplotlib Axes
Axes with ROC Curve.

Expand Down Expand Up @@ -81,7 +84,7 @@ def __init__(self, *, fpr, tpr, roc_auc=None, estimator_name=None, pos_label=Non
self.roc_auc = roc_auc
self.pos_label = pos_label

def plot(self, ax=None, *, name=None, **kwargs):
def plot(self, ax=None, *, name=None, plot_chance_level=True, **kwargs):
Charlie-XIAO marked this conversation as resolved.
Show resolved Hide resolved
"""Plot visualization.

Extra keyword arguments will be passed to matplotlib's ``plot``.
Expand All @@ -96,6 +99,9 @@ def plot(self, ax=None, *, name=None, **kwargs):
Name of ROC Curve for labeling. If `None`, use `estimator_name` if
not `None`, otherwise no labeling is shown.

plot_chance_level : bool, default=True
Whether to plot the chance level.

Charlie-XIAO marked this conversation as resolved.
Show resolved Hide resolved
**kwargs : dict
Keyword arguments to be passed to matplotlib's `plot`.

Expand Down Expand Up @@ -123,6 +129,24 @@ def plot(self, ax=None, *, name=None, **kwargs):
if ax is None:
fig, ax = plt.subplots()

# Set limits of axes to [0, 1] and fix aspect ratio to squared
ax.set_xlim((0, 1))
ax.set_ylim((0, 1))
ax.set_aspect(1)
Comment on lines +132 to +135
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's move this change in another PR. We will need an additional entry in the changelog since we are fixing/improving the rendering.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be shared between the PR (precision-recall) curve and ROC curve displays.


# Plot the frame in dotted line, so that the curve can be
# seen better when values are close to 0 or 1
for s in ["right", "left", "top", "bottom"]:
ax.spines[s].set_linestyle((0, (1, 5)))
ax.spines[s].set_linewidth(0.5)
Comment on lines +137 to +141
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here. We can postpone the despining. It could also be shared between different types of plots, and we should be able to control it via some keywords.


if plot_chance_level:
(self.chance_level_,) = ax.plot(
(0, 1), (0, 1), linestyle="dotted", label="Chance level"
Charlie-XIAO marked this conversation as resolved.
Show resolved Hide resolved
)
else:
self.chance_level_ = None

(self.line_,) = ax.plot(self.fpr, self.tpr, **line_kwargs)
info_pos_label = (
f" (Positive label: {self.pos_label})" if self.pos_label is not None else ""
Expand Down Expand Up @@ -152,6 +176,7 @@ def from_estimator(
pos_label=None,
name=None,
ax=None,
plot_chance_level=True,
Charlie-XIAO marked this conversation as resolved.
Show resolved Hide resolved
**kwargs,
):
"""Create a ROC Curve display from an estimator.
Expand Down Expand Up @@ -195,6 +220,9 @@ def from_estimator(
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is created.

plot_chance_level : bool, default=True
Whether to plot the chance level.
Charlie-XIAO marked this conversation as resolved.
Show resolved Hide resolved

**kwargs : dict
Comment on lines +225 to 226
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also add the documentation for chance_level_kwargs

Keyword arguments to be passed to matplotlib's `plot`.

Expand Down Expand Up @@ -245,6 +273,7 @@ def from_estimator(
name=name,
ax=ax,
pos_label=pos_label,
plot_chance_level=plot_chance_level,
**kwargs,
)

Expand All @@ -259,6 +288,7 @@ def from_predictions(
pos_label=None,
name=None,
ax=None,
plot_chance_level=True,
Charlie-XIAO marked this conversation as resolved.
Show resolved Hide resolved
**kwargs,
):
"""Plot ROC curve given the true and predicted values.
Expand Down Expand Up @@ -298,6 +328,9 @@ def from_predictions(
Axes object to plot on. If `None`, a new figure and axes is
created.

plot_chance_level : bool, default=True
Whether to plot the chance level.
Charlie-XIAO marked this conversation as resolved.
Show resolved Hide resolved

**kwargs : dict
Additional keywords arguments passed to matplotlib `plot` function.

Expand Down Expand Up @@ -348,4 +381,4 @@ def from_predictions(
fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name=name, pos_label=pos_label
)

return viz.plot(ax=ax, name=name, **kwargs)
return viz.plot(ax=ax, name=name, plot_chance_level=plot_chance_level, **kwargs)
20 changes: 20 additions & 0 deletions 20 sklearn/metrics/_plot/tests/test_roc_curve_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def data_binary(data):
@pytest.mark.parametrize("with_sample_weight", [True, False])
@pytest.mark.parametrize("drop_intermediate", [True, False])
@pytest.mark.parametrize("with_strings", [True, False])
@pytest.mark.parametrize("plot_chance_level", [True, False])
@pytest.mark.parametrize(
"constructor_name, default_name",
[
Expand All @@ -50,6 +51,7 @@ def test_roc_curve_display_plotting(
with_sample_weight,
drop_intermediate,
with_strings,
plot_chance_level,
constructor_name,
default_name,
):
Expand Down Expand Up @@ -82,6 +84,7 @@ def test_roc_curve_display_plotting(
drop_intermediate=drop_intermediate,
pos_label=pos_label,
alpha=0.8,
plot_chance_level=plot_chance_level,
)
else:
display = RocCurveDisplay.from_predictions(
Expand All @@ -91,6 +94,7 @@ def test_roc_curve_display_plotting(
drop_intermediate=drop_intermediate,
pos_label=pos_label,
alpha=0.8,
plot_chance_level=plot_chance_level,
)

fpr, tpr, _ = roc_curve(
Expand All @@ -114,6 +118,13 @@ def test_roc_curve_display_plotting(
assert isinstance(display.ax_, mpl.axes.Axes)
assert isinstance(display.figure_, mpl.figure.Figure)

if plot_chance_level:
assert isinstance(display.chance_level_, mpl.lines.Line2D)
assert tuple(display.chance_level_.get_xdata()) == (0, 1)
assert tuple(display.chance_level_.get_ydata()) == (0, 1)
else:
assert display.chance_level_ is None

expected_label = f"{default_name} (AUC = {display.roc_auc:.2f})"
assert display.line_.get_label() == expected_label

Expand All @@ -124,6 +135,15 @@ def test_roc_curve_display_plotting(
assert display.ax_.get_ylabel() == expected_ylabel
assert display.ax_.get_xlabel() == expected_xlabel

assert display.ax_.get_xlim() == (0, 1)
assert display.ax_.get_ylim() == (0, 1)
assert display.ax_.get_aspect() == 1

# Check frame styles
for s in ["right", "left", "top", "bottom"]:
assert display.ax_.spines[s].get_linestyle() == (0, (1, 5))
assert display.ax_.spines[s].get_linewidth() <= 0.5

Charlie-XIAO marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.parametrize(
"clf",
Expand Down
Morty Proxy This is a proxified and sanitized view of the page, visit original site.