-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
ENH Improve ROC curves visualization and add option to plot chance level #25972
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,6 +43,9 @@ class RocCurveDisplay: | |
line_ : matplotlib Artist | ||
ROC Curve. | ||
|
||
chance_level_ : matplotlib Artist | ||
The chance level line or None if the chance level is not plotted. | ||
Charlie-XIAO marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
ax_ : matplotlib Axes | ||
Axes with ROC Curve. | ||
|
||
|
@@ -81,7 +84,7 @@ def __init__(self, *, fpr, tpr, roc_auc=None, estimator_name=None, pos_label=Non | |
self.roc_auc = roc_auc | ||
self.pos_label = pos_label | ||
|
||
def plot(self, ax=None, *, name=None, **kwargs): | ||
def plot(self, ax=None, *, name=None, plot_chance_level=True, **kwargs): | ||
Charlie-XIAO marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Plot visualization. | ||
|
||
Extra keyword arguments will be passed to matplotlib's ``plot``. | ||
|
@@ -96,6 +99,9 @@ def plot(self, ax=None, *, name=None, **kwargs): | |
Name of ROC Curve for labeling. If `None`, use `estimator_name` if | ||
not `None`, otherwise no labeling is shown. | ||
|
||
plot_chance_level : bool, default=True | ||
Whether to plot the chance level. | ||
|
||
Charlie-XIAO marked this conversation as resolved.
Show resolved
Hide resolved
|
||
**kwargs : dict | ||
Keyword arguments to be passed to matplotlib's `plot`. | ||
|
||
|
@@ -123,6 +129,24 @@ def plot(self, ax=None, *, name=None, **kwargs): | |
if ax is None: | ||
fig, ax = plt.subplots() | ||
|
||
# Set limits of axes to [0, 1] and fix aspect ratio to squared | ||
ax.set_xlim((0, 1)) | ||
ax.set_ylim((0, 1)) | ||
ax.set_aspect(1) | ||
Comment on lines
+132
to
+135
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's move this change in another PR. We will need an additional entry in the changelog since we are fixing/improving the rendering. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It should be shared with PR and ROC curve |
||
|
||
# Plot the frame in dotted line, so that the curve can be | ||
# seen better when values are close to 0 or 1 | ||
for s in ["right", "left", "top", "bottom"]: | ||
ax.spines[s].set_linestyle((0, (1, 5))) | ||
ax.spines[s].set_linewidth(0.5) | ||
Comment on lines
+137
to
+141
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here. We can postpone the despining. It could also be shared between different type of plots. And we should be able to control it via some keywords. |
||
|
||
if plot_chance_level: | ||
(self.chance_level_,) = ax.plot( | ||
(0, 1), (0, 1), linestyle="dotted", label="Chance level" | ||
Charlie-XIAO marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
else: | ||
self.chance_level_ = None | ||
|
||
(self.line_,) = ax.plot(self.fpr, self.tpr, **line_kwargs) | ||
info_pos_label = ( | ||
f" (Positive label: {self.pos_label})" if self.pos_label is not None else "" | ||
|
@@ -152,6 +176,7 @@ def from_estimator( | |
pos_label=None, | ||
name=None, | ||
ax=None, | ||
plot_chance_level=True, | ||
Charlie-XIAO marked this conversation as resolved.
Show resolved
Hide resolved
|
||
**kwargs, | ||
): | ||
"""Create a ROC Curve display from an estimator. | ||
|
@@ -195,6 +220,9 @@ def from_estimator( | |
ax : matplotlib axes, default=None | ||
Axes object to plot on. If `None`, a new figure and axes is created. | ||
|
||
plot_chance_level : bool, default=True | ||
Whether to plot the chance level. | ||
Charlie-XIAO marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
**kwargs : dict | ||
Comment on lines
+225
to
226
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also add the documnetation for |
||
Keyword arguments to be passed to matplotlib's `plot`. | ||
|
||
|
@@ -245,6 +273,7 @@ def from_estimator( | |
name=name, | ||
ax=ax, | ||
pos_label=pos_label, | ||
plot_chance_level=plot_chance_level, | ||
**kwargs, | ||
) | ||
|
||
|
@@ -259,6 +288,7 @@ def from_predictions( | |
pos_label=None, | ||
name=None, | ||
ax=None, | ||
plot_chance_level=True, | ||
Charlie-XIAO marked this conversation as resolved.
Show resolved
Hide resolved
|
||
**kwargs, | ||
): | ||
"""Plot ROC curve given the true and predicted values. | ||
|
@@ -298,6 +328,9 @@ def from_predictions( | |
Axes object to plot on. If `None`, a new figure and axes is | ||
created. | ||
|
||
plot_chance_level : bool, default=True | ||
Whether to plot the chance level. | ||
Charlie-XIAO marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
**kwargs : dict | ||
Additional keywords arguments passed to matplotlib `plot` function. | ||
|
||
|
@@ -348,4 +381,4 @@ def from_predictions( | |
fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name=name, pos_label=pos_label | ||
) | ||
|
||
return viz.plot(ax=ax, name=name, **kwargs) | ||
return viz.plot(ax=ax, name=name, plot_chance_level=plot_chance_level, **kwargs) |
Uh oh!
There was an error while loading. Please reload this page.