MAINT Introduce BinaryClassifierCurveDisplayMixin #25969

Merged
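The new `sklearn/utils/_plotting.py` module that hosts the mixin is not part of this diff view. Below is a minimal sketch of what the mixin appears to provide, reconstructed purely from the call sites in the hunks that follow; the method bodies are assumptions, not the merged implementation.

    # Hypothetical reconstruction of _BinaryClassifierCurveDisplayMixin from its
    # call sites in this PR; the merged code in sklearn/utils/_plotting.py may
    # differ in details.
    from sklearn.utils import check_matplotlib_support
    from sklearn.utils._response import _get_response_values_binary
    from sklearn.utils.multiclass import type_of_target
    from sklearn.utils.validation import _check_pos_label_consistency


    class _BinaryClassifierCurveDisplayMixin:
        """Validation shared by binary-classifier curve displays."""

        def _validate_plot_params(self, *, ax=None, name=None):
            # Replaces the per-display check_matplotlib_support/plt.subplots dance.
            check_matplotlib_support(f"{self.__class__.__name__}.plot")
            import matplotlib.pyplot as plt

            if ax is None:
                _, ax = plt.subplots()
            name = self.estimator_name if name is None else name
            return ax, ax.figure, name

        @classmethod
        def _validate_and_get_response_values(
            cls, estimator, X, y, *, response_method="auto", pos_label=None, name=None
        ):
            # Backs `from_estimator`: resolve the display name and fetch the
            # positive-class scores from the fitted estimator.
            check_matplotlib_support(f"{cls.__name__}.from_estimator")
            name = estimator.__class__.__name__ if name is None else name
            y_pred, pos_label = _get_response_values_binary(
                estimator, X, response_method, pos_label=pos_label
            )
            return y_pred, pos_label, name

        @classmethod
        def _validate_from_predictions_params(
            cls, y_true, y_pred, *, sample_weight=None, pos_label=None, name=None
        ):
            # Backs `from_predictions`: require a binary target and resolve the
            # positive label before the curve is computed.
            check_matplotlib_support(f"{cls.__name__}.from_predictions")
            target_type = type_of_target(y_true)
            if target_type != "binary":
                raise ValueError(
                    f"The target y is not binary. Got {target_type} type of target."
                )
            pos_label = _check_pos_label_consistency(pos_label, y_true)
            name = "Classifier" if name is None else name
            return pos_label, name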
sklearn/calibration.py (59 changes: 23 additions & 36 deletions)
@@ -30,16 +30,16 @@
 from .utils import (
     column_or_1d,
     indexable,
-    check_matplotlib_support,
     _safe_indexing,
 )
-from .utils._response import _get_response_values_binary

-from .utils.multiclass import check_classification_targets, type_of_target
+from .utils.multiclass import check_classification_targets
 from .utils.parallel import delayed, Parallel
 from .utils._param_validation import StrOptions, HasMethods, Hidden
+from .utils._plotting import _BinaryClassifierCurveDisplayMixin
 from .utils.validation import (
     _check_fit_params,
+    _check_pos_label_consistency,
     _check_sample_weight,
     _num_samples,
     check_consistent_length,
@@ -48,7 +48,6 @@
 from .isotonic import IsotonicRegression
 from .svm import LinearSVC
 from .model_selection import check_cv, cross_val_predict
-from .metrics._base import _check_pos_label_consistency


 class CalibratedClassifierCV(ClassifierMixin, MetaEstimatorMixin, BaseEstimator):
@@ -1013,7 +1012,7 @@ def calibration_curve(
     return prob_true, prob_pred


-class CalibrationDisplay:
+class CalibrationDisplay(_BinaryClassifierCurveDisplayMixin):
     """Calibration curve (also known as reliability diagram) visualization.

     It is recommended to use
@@ -1124,13 +1123,8 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
         display : :class:`~sklearn.calibration.CalibrationDisplay`
             Object that stores computed values.
         """
-        check_matplotlib_support("CalibrationDisplay.plot")
-        import matplotlib.pyplot as plt
+        self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name)

-        if ax is None:
-            fig, ax = plt.subplots()
-
-        name = self.estimator_name if name is None else name
         info_pos_label = (
             f"(Positive class: {self.pos_label})" if self.pos_label is not None else ""
         )
@@ -1141,20 +1135,20 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
         line_kwargs.update(**kwargs)

         ref_line_label = "Perfectly calibrated"
-        existing_ref_line = ref_line_label in ax.get_legend_handles_labels()[1]
+        existing_ref_line = ref_line_label in self.ax_.get_legend_handles_labels()[1]
         if ref_line and not existing_ref_line:
-            ax.plot([0, 1], [0, 1], "k:", label=ref_line_label)
-        self.line_ = ax.plot(self.prob_pred, self.prob_true, "s-", **line_kwargs)[0]
+            self.ax_.plot([0, 1], [0, 1], "k:", label=ref_line_label)
+        self.line_ = self.ax_.plot(self.prob_pred, self.prob_true, "s-", **line_kwargs)[
+            0
+        ]

         # We always have to show the legend for at least the reference line
-        ax.legend(loc="lower right")
+        self.ax_.legend(loc="lower right")

         xlabel = f"Mean predicted probability {info_pos_label}"
         ylabel = f"Fraction of positives {info_pos_label}"
-        ax.set(xlabel=xlabel, ylabel=ylabel)
+        self.ax_.set(xlabel=xlabel, ylabel=ylabel)

-        self.ax_ = ax
-        self.figure_ = ax.figure
         return self

     @classmethod
@@ -1260,15 +1254,15 @@ def from_estimator(
         >>> disp = CalibrationDisplay.from_estimator(clf, X_test, y_test)
         >>> plt.show()
         """
-        method_name = f"{cls.__name__}.from_estimator"
-        check_matplotlib_support(method_name)
-
-        check_is_fitted(estimator)
-        y_prob, pos_label = _get_response_values_binary(
-            estimator, X, response_method="predict_proba", pos_label=pos_label
+        y_prob, pos_label, name = cls._validate_and_get_response_values(
+            estimator,
+            X,
+            y,
+            response_method="predict_proba",
+            pos_label=pos_label,
+            name=name,
         )

-        name = name if name is not None else estimator.__class__.__name__
         return cls.from_predictions(
             y,
             y_prob,
@@ -1378,26 +1372,19 @@ def from_predictions(
         >>> disp = CalibrationDisplay.from_predictions(y_test, y_prob)
         >>> plt.show()
         """
-        method_name = f"{cls.__name__}.from_predictions"
-        check_matplotlib_support(method_name)
-
-        target_type = type_of_target(y_true)
-        if target_type != "binary":
-            raise ValueError(
-                f"The target y is not binary. Got {target_type} type of target."
-            )
+        pos_label_validated, name = cls._validate_from_predictions_params(
+            y_true, y_prob, sample_weight=None, pos_label=pos_label, name=name
+        )

         prob_true, prob_pred = calibration_curve(
             y_true, y_prob, n_bins=n_bins, strategy=strategy, pos_label=pos_label
         )
-        name = "Classifier" if name is None else name
-        pos_label = _check_pos_label_consistency(pos_label, y_true)

         disp = cls(
             prob_true=prob_true,
             prob_pred=prob_pred,
             y_prob=y_prob,
             estimator_name=name,
-            pos_label=pos_label,
+            pos_label=pos_label_validated,
         )
         return disp.plot(ax=ax, ref_line=ref_line, **kwargs)
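The public API of CalibrationDisplay is unchanged by this refactor; only the shared validation moved into the mixin. A minimal usage sketch of both entry points (the estimator and dataset are illustrative; assumes matplotlib is installed):

    from sklearn.calibration import CalibrationDisplay
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split

    X, y = make_classification(random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    clf = LogisticRegression().fit(X_train, y_train)

    # Both constructors now funnel through the mixin's validation helpers.
    CalibrationDisplay.from_estimator(clf, X_test, y_test)
    CalibrationDisplay.from_predictions(y_test, clf.predict_proba(X_test)[:, 1])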
sklearn/metrics/_base.py (52 changes: 0 additions & 52 deletions)
@@ -197,55 +197,3 @@ def _average_multiclass_ovo_score(binary_metric, y_true, y_score, average="macro"):
             pair_scores[ix] = (a_true_score + b_true_score) / 2

     return np.average(pair_scores, weights=prevalence)
-
-
-def _check_pos_label_consistency(pos_label, y_true):
-    """Check if `pos_label` need to be specified or not.
-
-    In binary classification, we fix `pos_label=1` if the labels are in the set
-    {-1, 1} or {0, 1}. Otherwise, we raise an error asking to specify the
-    `pos_label` parameters.
-
-    Parameters
-    ----------
-    pos_label : int, str or None
-        The positive label.
-    y_true : ndarray of shape (n_samples,)
-        The target vector.
-
-    Returns
-    -------
-    pos_label : int
-        If `pos_label` can be inferred, it will be returned.
-
-    Raises
-    ------
-    ValueError
-        In the case that `y_true` does not have label in {-1, 1} or {0, 1},
-        it will raise a `ValueError`.
-    """
-    # ensure binary classification if pos_label is not specified
-    # classes.dtype.kind in ('O', 'U', 'S') is required to avoid
-    # triggering a FutureWarning by calling np.array_equal(a, b)
-    # when elements in the two arrays are not comparable.
-    classes = np.unique(y_true)
-    if pos_label is None and (
-        classes.dtype.kind in "OUS"
-        or not (
-            np.array_equal(classes, [0, 1])
-            or np.array_equal(classes, [-1, 1])
-            or np.array_equal(classes, [0])
-            or np.array_equal(classes, [-1])
-            or np.array_equal(classes, [1])
-        )
-    ):
-        classes_repr = ", ".join(repr(c) for c in classes)
-        raise ValueError(
-            f"y_true takes value in {{{classes_repr}}} and pos_label is not "
-            "specified: either make y_true take value in {0, 1} or "
-            "{-1, 1} or pass pos_label explicitly."
-        )
-    elif pos_label is None:
-        pos_label = 1
-
-    return pos_label
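The helper is deleted here because it moves to sklearn.utils.validation (see the import changes in sklearn/calibration.py above and sklearn/metrics/_classification.py below). A small illustration of the rule its docstring describes, using the new (still private) location:

    import numpy as np
    from sklearn.utils.validation import _check_pos_label_consistency  # new home

    # Labels in {0, 1} or {-1, 1}: pos_label can be inferred and defaults to 1.
    assert _check_pos_label_consistency(None, np.array([0, 1, 1])) == 1
    assert _check_pos_label_consistency(None, np.array([-1, 1, -1])) == 1

    # Any other label set requires an explicit pos_label.
    assert _check_pos_label_consistency("yes", np.array(["no", "yes"])) == "yes"
    try:
        _check_pos_label_consistency(None, np.array(["no", "yes"]))
    except ValueError as exc:
        print(exc)  # asks the caller to pass pos_label explicitly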
sklearn/metrics/_classification.py (4 changes: 1 addition & 3 deletions)
@@ -40,13 +40,11 @@
 from ..utils.extmath import _nanaverage
 from ..utils.multiclass import unique_labels
 from ..utils.multiclass import type_of_target
-from ..utils.validation import _num_samples
+from ..utils.validation import _check_pos_label_consistency, _num_samples
 from ..utils.sparsefuncs import count_nonzero
 from ..utils._param_validation import StrOptions, Options, Interval, validate_params
 from ..exceptions import UndefinedMetricWarning

-from ._base import _check_pos_label_consistency
-

 def _check_zero_division(zero_division):
     if isinstance(zero_division, str) and zero_division == "warn":
sklearn/metrics/_plot/det_curve.py (57 changes: 22 additions & 35 deletions)
@@ -1,13 +1,10 @@
 import scipy as sp

 from .. import det_curve
-from .._base import _check_pos_label_consistency
+from ...utils._plotting import _BinaryClassifierCurveDisplayMixin
-
-from ...utils import check_matplotlib_support
-from ...utils._response import _get_response_values_binary


-class DetCurveDisplay:
+class DetCurveDisplay(_BinaryClassifierCurveDisplayMixin):
     """DET curve visualization.

     It is recommend to use :func:`~sklearn.metrics.DetCurveDisplay.from_estimator`
@@ -163,15 +160,13 @@ def from_estimator(
         <...>
         >>> plt.show()
         """
-        check_matplotlib_support(f"{cls.__name__}.from_estimator")
-
-        name = estimator.__class__.__name__ if name is None else name
-
-        y_pred, pos_label = _get_response_values_binary(
+        y_pred, pos_label, name = cls._validate_and_get_response_values(
             estimator,
             X,
-            response_method,
+            y,
+            response_method=response_method,
             pos_label=pos_label,
+            name=name,
         )

         return cls.from_predictions(
@@ -259,22 +254,22 @@ def from_predictions(
         <...>
         >>> plt.show()
         """
-        check_matplotlib_support(f"{cls.__name__}.from_predictions")
+        pos_label_validated, name = cls._validate_from_predictions_params(
+            y_true, y_pred, sample_weight=sample_weight, pos_label=pos_label, name=name
+        )

         fpr, fnr, _ = det_curve(
             y_true,
             y_pred,
             pos_label=pos_label,
             sample_weight=sample_weight,
         )

-        pos_label = _check_pos_label_consistency(pos_label, y_true)
-        name = "Classifier" if name is None else name
-
         viz = DetCurveDisplay(
             fpr=fpr,
             fnr=fnr,
             estimator_name=name,
-            pos_label=pos_label,
+            pos_label=pos_label_validated,
         )

         return viz.plot(ax=ax, name=name, **kwargs)
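A short usage sketch of the refactored entry point (the arrays are illustrative; assumes matplotlib is installed). With the shared validation, a non-binary target now fails fast; the error text below is taken from the validation code shown in sklearn/calibration.py before this PR, so the exact wording in the mixin is an assumption:

    import numpy as np
    from sklearn.metrics import DetCurveDisplay

    y_true = np.array([0, 0, 1, 1])
    y_score = np.array([0.1, 0.4, 0.35, 0.8])
    DetCurveDisplay.from_predictions(y_true, y_score)  # binary target: fine

    try:
        DetCurveDisplay.from_predictions(np.array([0, 1, 2, 2]), y_score)
    except ValueError as exc:
        print(exc)  # e.g. "The target y is not binary. Got multiclass type of target."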
@@ -300,18 +295,12 @@ def plot(self, ax=None, *, name=None, **kwargs):
         display : :class:`~sklearn.metrics.plot.DetCurveDisplay`
             Object that stores computed values.
         """
-        check_matplotlib_support("DetCurveDisplay.plot")
+        self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name)

-        name = self.estimator_name if name is None else name
         line_kwargs = {} if name is None else {"label": name}
         line_kwargs.update(**kwargs)

-        import matplotlib.pyplot as plt
-
-        if ax is None:
-            _, ax = plt.subplots()
-
-        (self.line_,) = ax.plot(
+        (self.line_,) = self.ax_.plot(
             sp.stats.norm.ppf(self.fpr),
             sp.stats.norm.ppf(self.fnr),
             **line_kwargs,
@@ -322,24 +311,22 @@

         xlabel = "False Positive Rate" + info_pos_label
         ylabel = "False Negative Rate" + info_pos_label
-        ax.set(xlabel=xlabel, ylabel=ylabel)
+        self.ax_.set(xlabel=xlabel, ylabel=ylabel)

         if "label" in line_kwargs:
-            ax.legend(loc="lower right")
+            self.ax_.legend(loc="lower right")

         ticks = [0.001, 0.01, 0.05, 0.20, 0.5, 0.80, 0.95, 0.99, 0.999]
         tick_locations = sp.stats.norm.ppf(ticks)
         tick_labels = [
             "{:.0%}".format(s) if (100 * s).is_integer() else "{:.1%}".format(s)
             for s in ticks
         ]
-        ax.set_xticks(tick_locations)
-        ax.set_xticklabels(tick_labels)
-        ax.set_xlim(-3, 3)
-        ax.set_yticks(tick_locations)
-        ax.set_yticklabels(tick_labels)
-        ax.set_ylim(-3, 3)
-
-        self.ax_ = ax
-        self.figure_ = ax.figure
+        self.ax_.set_xticks(tick_locations)
+        self.ax_.set_xticklabels(tick_labels)
+        self.ax_.set_xlim(-3, 3)
+        self.ax_.set_yticks(tick_locations)
+        self.ax_.set_yticklabels(tick_labels)
+        self.ax_.set_ylim(-3, 3)

         return self
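A note on the sp.stats.norm.ppf calls this method keeps: DET curves are conventionally drawn on a normal-deviate (probit) scale, so the tick positions are standard-normal quantiles of the labelled probabilities, which is also why the axis limits are set to (-3, 3). A quick check of the tick math:

    from scipy.stats import norm

    ticks = [0.001, 0.01, 0.05, 0.20, 0.5, 0.80, 0.95, 0.99, 0.999]
    print(norm.ppf(ticks).round(2))
    # [-3.09 -2.33 -1.64 -0.84  0.    0.84  1.64  2.33  3.09]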