diff --git a/sklearn/calibration.py b/sklearn/calibration.py
index 31f8b67458f78..e4d46555f3761 100644
--- a/sklearn/calibration.py
+++ b/sklearn/calibration.py
@@ -30,16 +30,16 @@
 from .utils import (
     column_or_1d,
     indexable,
-    check_matplotlib_support,
     _safe_indexing,
 )
-from .utils._response import _get_response_values_binary
-from .utils.multiclass import check_classification_targets, type_of_target
+from .utils.multiclass import check_classification_targets
 from .utils.parallel import delayed, Parallel
 from .utils._param_validation import StrOptions, HasMethods, Hidden
+from .utils._plotting import _BinaryClassifierCurveDisplayMixin
 from .utils.validation import (
     _check_fit_params,
+    _check_pos_label_consistency,
     _check_sample_weight,
     _num_samples,
     check_consistent_length,
@@ -48,7 +48,6 @@
 from .isotonic import IsotonicRegression
 from .svm import LinearSVC
 from .model_selection import check_cv, cross_val_predict
-from .metrics._base import _check_pos_label_consistency
 
 
 class CalibratedClassifierCV(ClassifierMixin, MetaEstimatorMixin, BaseEstimator):
@@ -1013,7 +1012,7 @@ def calibration_curve(
     return prob_true, prob_pred
 
 
-class CalibrationDisplay:
+class CalibrationDisplay(_BinaryClassifierCurveDisplayMixin):
     """Calibration curve (also known as reliability diagram) visualization.
 
     It is recommended to use
@@ -1124,13 +1123,8 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
         display : :class:`~sklearn.calibration.CalibrationDisplay`
             Object that stores computed values.
         """
-        check_matplotlib_support("CalibrationDisplay.plot")
-        import matplotlib.pyplot as plt
+        self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name)
 
-        if ax is None:
-            fig, ax = plt.subplots()
-
-        name = self.estimator_name if name is None else name
         info_pos_label = (
             f"(Positive class: {self.pos_label})" if self.pos_label is not None else ""
         )
@@ -1141,20 +1135,20 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
         line_kwargs.update(**kwargs)
 
         ref_line_label = "Perfectly calibrated"
-        existing_ref_line = ref_line_label in ax.get_legend_handles_labels()[1]
+        existing_ref_line = ref_line_label in self.ax_.get_legend_handles_labels()[1]
         if ref_line and not existing_ref_line:
-            ax.plot([0, 1], [0, 1], "k:", label=ref_line_label)
-        self.line_ = ax.plot(self.prob_pred, self.prob_true, "s-", **line_kwargs)[0]
+            self.ax_.plot([0, 1], [0, 1], "k:", label=ref_line_label)
+        self.line_ = self.ax_.plot(self.prob_pred, self.prob_true, "s-", **line_kwargs)[
+            0
+        ]
 
         # We always have to show the legend for at least the reference line
-        ax.legend(loc="lower right")
+        self.ax_.legend(loc="lower right")
 
         xlabel = f"Mean predicted probability {info_pos_label}"
         ylabel = f"Fraction of positives {info_pos_label}"
-        ax.set(xlabel=xlabel, ylabel=ylabel)
+        self.ax_.set(xlabel=xlabel, ylabel=ylabel)
 
-        self.ax_ = ax
-        self.figure_ = ax.figure
         return self
 
     @classmethod
@@ -1260,15 +1254,15 @@ def from_estimator(
         >>> disp = CalibrationDisplay.from_estimator(clf, X_test, y_test)
         >>> plt.show()
         """
-        method_name = f"{cls.__name__}.from_estimator"
-        check_matplotlib_support(method_name)
-
-        check_is_fitted(estimator)
-        y_prob, pos_label = _get_response_values_binary(
-            estimator, X, response_method="predict_proba", pos_label=pos_label
+        y_prob, pos_label, name = cls._validate_and_get_response_values(
+            estimator,
+            X,
+            y,
+            response_method="predict_proba",
+            pos_label=pos_label,
+            name=name,
         )
 
-        name = name if name is not None else estimator.__class__.__name__
         return cls.from_predictions(
             y,
             y_prob,
@@ -1378,26 +1372,19 @@ def from_predictions(
         >>> disp = CalibrationDisplay.from_predictions(y_test, y_prob)
         >>> plt.show()
         """
-        method_name = f"{cls.__name__}.from_predictions"
-        check_matplotlib_support(method_name)
-
-        target_type = type_of_target(y_true)
-        if target_type != "binary":
-            raise ValueError(
-                f"The target y is not binary. Got {target_type} type of target."
-            )
+        pos_label_validated, name = cls._validate_from_predictions_params(
+            y_true, y_prob, sample_weight=None, pos_label=pos_label, name=name
+        )
 
         prob_true, prob_pred = calibration_curve(
             y_true, y_prob, n_bins=n_bins, strategy=strategy, pos_label=pos_label
         )
-        name = "Classifier" if name is None else name
-        pos_label = _check_pos_label_consistency(pos_label, y_true)
 
         disp = cls(
             prob_true=prob_true,
             prob_pred=prob_pred,
             y_prob=y_prob,
             estimator_name=name,
-            pos_label=pos_label,
+            pos_label=pos_label_validated,
         )
         return disp.plot(ax=ax, ref_line=ref_line, **kwargs)
diff --git a/sklearn/metrics/_base.py b/sklearn/metrics/_base.py
index dd0258f600ccc..53ff14b039e0c 100644
--- a/sklearn/metrics/_base.py
+++ b/sklearn/metrics/_base.py
@@ -197,55 +197,3 @@ def _average_multiclass_ovo_score(binary_metric, y_true, y_score, average="macro"):
         pair_scores[ix] = (a_true_score + b_true_score) / 2
 
     return np.average(pair_scores, weights=prevalence)
-
-
-def _check_pos_label_consistency(pos_label, y_true):
-    """Check if `pos_label` need to be specified or not.
-
-    In binary classification, we fix `pos_label=1` if the labels are in the set
-    {-1, 1} or {0, 1}. Otherwise, we raise an error asking to specify the
-    `pos_label` parameters.
-
-    Parameters
-    ----------
-    pos_label : int, str or None
-        The positive label.
-    y_true : ndarray of shape (n_samples,)
-        The target vector.
-
-    Returns
-    -------
-    pos_label : int
-        If `pos_label` can be inferred, it will be returned.
-
-    Raises
-    ------
-    ValueError
-        In the case that `y_true` does not have label in {-1, 1} or {0, 1},
-        it will raise a `ValueError`.
-    """
-    # ensure binary classification if pos_label is not specified
-    # classes.dtype.kind in ('O', 'U', 'S') is required to avoid
-    # triggering a FutureWarning by calling np.array_equal(a, b)
-    # when elements in the two arrays are not comparable.
-    classes = np.unique(y_true)
-    if pos_label is None and (
-        classes.dtype.kind in "OUS"
-        or not (
-            np.array_equal(classes, [0, 1])
-            or np.array_equal(classes, [-1, 1])
-            or np.array_equal(classes, [0])
-            or np.array_equal(classes, [-1])
-            or np.array_equal(classes, [1])
-        )
-    ):
-        classes_repr = ", ".join(repr(c) for c in classes)
-        raise ValueError(
-            f"y_true takes value in {{{classes_repr}}} and pos_label is not "
-            "specified: either make y_true take value in {0, 1} or "
-            "{-1, 1} or pass pos_label explicitly."
-        )
-    elif pos_label is None:
-        pos_label = 1
-
-    return pos_label
diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index 67c34e92cf8f3..8e203a48c967b 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -40,13 +40,11 @@
 from ..utils.extmath import _nanaverage
 from ..utils.multiclass import unique_labels
 from ..utils.multiclass import type_of_target
-from ..utils.validation import _num_samples
+from ..utils.validation import _check_pos_label_consistency, _num_samples
 from ..utils.sparsefuncs import count_nonzero
 from ..utils._param_validation import StrOptions, Options, Interval, validate_params
 from ..exceptions import UndefinedMetricWarning
 
-from ._base import _check_pos_label_consistency
-
 
 def _check_zero_division(zero_division):
     if isinstance(zero_division, str) and zero_division == "warn":
diff --git a/sklearn/metrics/_plot/det_curve.py b/sklearn/metrics/_plot/det_curve.py
index f9832fed41847..4db7b10e0d8ac 100644
--- a/sklearn/metrics/_plot/det_curve.py
+++ b/sklearn/metrics/_plot/det_curve.py
@@ -1,13 +1,10 @@
 import scipy as sp
 
 from .. import det_curve
-from .._base import _check_pos_label_consistency
+from ...utils._plotting import _BinaryClassifierCurveDisplayMixin
 
-from ...utils import check_matplotlib_support
-from ...utils._response import _get_response_values_binary
 
-
-class DetCurveDisplay:
+class DetCurveDisplay(_BinaryClassifierCurveDisplayMixin):
     """DET curve visualization.
 
     It is recommend to use :func:`~sklearn.metrics.DetCurveDisplay.from_estimator`
@@ -163,15 +160,13 @@ def from_estimator(
         <...>
         >>> plt.show()
         """
-        check_matplotlib_support(f"{cls.__name__}.from_estimator")
-
-        name = estimator.__class__.__name__ if name is None else name
-
-        y_pred, pos_label = _get_response_values_binary(
+        y_pred, pos_label, name = cls._validate_and_get_response_values(
             estimator,
             X,
-            response_method,
+            y,
+            response_method=response_method,
             pos_label=pos_label,
+            name=name,
         )
 
         return cls.from_predictions(
@@ -259,7 +254,10 @@ def from_predictions(
         <...>
         >>> plt.show()
         """
-        check_matplotlib_support(f"{cls.__name__}.from_predictions")
+        pos_label_validated, name = cls._validate_from_predictions_params(
+            y_true, y_pred, sample_weight=sample_weight, pos_label=pos_label, name=name
+        )
+
         fpr, fnr, _ = det_curve(
             y_true,
             y_pred,
@@ -267,14 +265,11 @@ def from_predictions(
             sample_weight=sample_weight,
         )
 
-        pos_label = _check_pos_label_consistency(pos_label, y_true)
-        name = "Classifier" if name is None else name
-
         viz = DetCurveDisplay(
             fpr=fpr,
             fnr=fnr,
             estimator_name=name,
-            pos_label=pos_label,
+            pos_label=pos_label_validated,
         )
 
         return viz.plot(ax=ax, name=name, **kwargs)
@@ -300,18 +295,12 @@ def plot(self, ax=None, *, name=None, **kwargs):
         display : :class:`~sklearn.metrics.plot.DetCurveDisplay`
             Object that stores computed values.
""" - check_matplotlib_support("DetCurveDisplay.plot") + self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name) - name = self.estimator_name if name is None else name line_kwargs = {} if name is None else {"label": name} line_kwargs.update(**kwargs) - import matplotlib.pyplot as plt - - if ax is None: - _, ax = plt.subplots() - - (self.line_,) = ax.plot( + (self.line_,) = self.ax_.plot( sp.stats.norm.ppf(self.fpr), sp.stats.norm.ppf(self.fnr), **line_kwargs, @@ -322,10 +311,10 @@ def plot(self, ax=None, *, name=None, **kwargs): xlabel = "False Positive Rate" + info_pos_label ylabel = "False Negative Rate" + info_pos_label - ax.set(xlabel=xlabel, ylabel=ylabel) + self.ax_.set(xlabel=xlabel, ylabel=ylabel) if "label" in line_kwargs: - ax.legend(loc="lower right") + self.ax_.legend(loc="lower right") ticks = [0.001, 0.01, 0.05, 0.20, 0.5, 0.80, 0.95, 0.99, 0.999] tick_locations = sp.stats.norm.ppf(ticks) @@ -333,13 +322,11 @@ def plot(self, ax=None, *, name=None, **kwargs): "{:.0%}".format(s) if (100 * s).is_integer() else "{:.1%}".format(s) for s in ticks ] - ax.set_xticks(tick_locations) - ax.set_xticklabels(tick_labels) - ax.set_xlim(-3, 3) - ax.set_yticks(tick_locations) - ax.set_yticklabels(tick_labels) - ax.set_ylim(-3, 3) - - self.ax_ = ax - self.figure_ = ax.figure + self.ax_.set_xticks(tick_locations) + self.ax_.set_xticklabels(tick_labels) + self.ax_.set_xlim(-3, 3) + self.ax_.set_yticks(tick_locations) + self.ax_.set_yticklabels(tick_labels) + self.ax_.set_ylim(-3, 3) + return self diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index 209f4dd0c3862..f99001d3dce9c 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -1,13 +1,9 @@ from .. import average_precision_score from .. import precision_recall_curve -from .._base import _check_pos_label_consistency -from .._classification import check_consistent_length +from ...utils._plotting import _BinaryClassifierCurveDisplayMixin -from ...utils import check_matplotlib_support -from ...utils._response import _get_response_values_binary - -class PrecisionRecallDisplay: +class PrecisionRecallDisplay(_BinaryClassifierCurveDisplayMixin): """Precision Recall visualization. It is recommend to use @@ -141,9 +137,7 @@ def plot(self, ax=None, *, name=None, **kwargs): `drawstyle="default"`. However, the curve will not be strictly consistent with the reported average precision. 
""" - check_matplotlib_support("PrecisionRecallDisplay.plot") - - name = self.estimator_name if name is None else name + self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name) line_kwargs = {"drawstyle": "steps-post"} if self.average_precision is not None and name is not None: @@ -154,25 +148,18 @@ def plot(self, ax=None, *, name=None, **kwargs): line_kwargs["label"] = name line_kwargs.update(**kwargs) - import matplotlib.pyplot as plt - - if ax is None: - fig, ax = plt.subplots() - - (self.line_,) = ax.plot(self.recall, self.precision, **line_kwargs) + (self.line_,) = self.ax_.plot(self.recall, self.precision, **line_kwargs) info_pos_label = ( f" (Positive label: {self.pos_label})" if self.pos_label is not None else "" ) xlabel = "Recall" + info_pos_label ylabel = "Precision" + info_pos_label - ax.set(xlabel=xlabel, ylabel=ylabel) + self.ax_.set(xlabel=xlabel, ylabel=ylabel) if "label" in line_kwargs: - ax.legend(loc="lower left") + self.ax_.legend(loc="lower left") - self.ax_ = ax - self.figure_ = ax.figure return self @classmethod @@ -273,18 +260,15 @@ def from_estimator( <...> >>> plt.show() """ - method_name = f"{cls.__name__}.from_estimator" - check_matplotlib_support(method_name) - - y_pred, pos_label = _get_response_values_binary( + y_pred, pos_label, name = cls._validate_and_get_response_values( estimator, X, - response_method, + y, + response_method=response_method, pos_label=pos_label, + name=name, ) - name = name if name is not None else estimator.__class__.__name__ - return cls.from_predictions( y, y_pred, @@ -382,10 +366,9 @@ def from_predictions( <...> >>> plt.show() """ - check_matplotlib_support(f"{cls.__name__}.from_predictions") - - check_consistent_length(y_true, y_pred, sample_weight) - pos_label = _check_pos_label_consistency(pos_label, y_true) + pos_label, name = cls._validate_from_predictions_params( + y_true, y_pred, sample_weight=sample_weight, pos_label=pos_label, name=name + ) precision, recall, _ = precision_recall_curve( y_true, @@ -398,8 +381,6 @@ def from_predictions( y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight ) - name = name if name is not None else "Classifier" - viz = PrecisionRecallDisplay( precision=precision, recall=recall, diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index e7158855cdcb4..8765561d1e477 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -1,12 +1,9 @@ from .. import auc from .. import roc_curve -from .._base import _check_pos_label_consistency +from ...utils._plotting import _BinaryClassifierCurveDisplayMixin -from ...utils import check_matplotlib_support -from ...utils._response import _get_response_values_binary - -class RocCurveDisplay: +class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): """ROC Curve visualization. It is recommend to use @@ -128,9 +125,7 @@ def plot( display : :class:`~sklearn.metrics.plot.RocCurveDisplay` Object that stores computed values. 
""" - check_matplotlib_support("RocCurveDisplay.plot") - - name = self.estimator_name if name is None else name + self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name) line_kwargs = {} if self.roc_auc is not None and name is not None: @@ -151,30 +146,25 @@ def plot( if chance_level_kw is not None: chance_level_line_kw.update(**chance_level_kw) - import matplotlib.pyplot as plt - - if ax is None: - fig, ax = plt.subplots() - - (self.line_,) = ax.plot(self.fpr, self.tpr, **line_kwargs) + (self.line_,) = self.ax_.plot(self.fpr, self.tpr, **line_kwargs) info_pos_label = ( f" (Positive label: {self.pos_label})" if self.pos_label is not None else "" ) xlabel = "False Positive Rate" + info_pos_label ylabel = "True Positive Rate" + info_pos_label - ax.set(xlabel=xlabel, ylabel=ylabel) + self.ax_.set(xlabel=xlabel, ylabel=ylabel) if plot_chance_level: - (self.chance_level_,) = ax.plot((0, 1), (0, 1), **chance_level_line_kw) + (self.chance_level_,) = self.ax_.plot( + (0, 1), (0, 1), **chance_level_line_kw + ) else: self.chance_level_ = None - if "label" in line_kwargs: - ax.legend(loc="lower right") + if "label" in line_kwargs or "label" in chance_level_line_kw: + self.ax_.legend(loc="lower right") - self.ax_ = ax - self.figure_ = ax.figure return self @classmethod @@ -277,15 +267,13 @@ def from_estimator( <...> >>> plt.show() """ - check_matplotlib_support(f"{cls.__name__}.from_estimator") - - name = estimator.__class__.__name__ if name is None else name - - y_pred, pos_label = _get_response_values_binary( + y_pred, pos_label, name = cls._validate_and_get_response_values( estimator, X, + y, response_method=response_method, pos_label=pos_label, + name=name, ) return cls.from_predictions( @@ -396,7 +384,9 @@ def from_predictions( <...> >>> plt.show() """ - check_matplotlib_support(f"{cls.__name__}.from_predictions") + pos_label_validated, name = cls._validate_from_predictions_params( + y_true, y_pred, sample_weight=sample_weight, pos_label=pos_label, name=name + ) fpr, tpr, _ = roc_curve( y_true, @@ -407,11 +397,12 @@ def from_predictions( ) roc_auc = auc(fpr, tpr) - name = "Classifier" if name is None else name - pos_label = _check_pos_label_consistency(pos_label, y_true) - viz = RocCurveDisplay( - fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name=name, pos_label=pos_label + fpr=fpr, + tpr=tpr, + roc_auc=roc_auc, + estimator_name=name, + pos_label=pos_label_validated, ) return viz.plot( diff --git a/sklearn/metrics/_plot/tests/test_common_curve_display.py b/sklearn/metrics/_plot/tests/test_common_curve_display.py index 27730893bb05c..fde87e2949d0b 100644 --- a/sklearn/metrics/_plot/tests/test_common_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_common_curve_display.py @@ -1,3 +1,4 @@ +import numpy as np import pytest from sklearn.base import ClassifierMixin, clone @@ -7,8 +8,9 @@ from sklearn.linear_model import LogisticRegression from sklearn.pipeline import make_pipeline from sklearn.preprocessing import StandardScaler -from sklearn.tree import DecisionTreeClassifier +from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor +from sklearn.calibration import CalibrationDisplay from sklearn.metrics import ( DetCurveDisplay, PrecisionRecallDisplay, @@ -28,18 +30,57 @@ def data_binary(data): @pytest.mark.parametrize( - "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay] + "Display", + [CalibrationDisplay, DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay], ) -def test_display_curve_error_non_binary(pyplot, data, Display): +def 
test_display_curve_error_classifier(pyplot, data, data_binary, Display): """Check that a proper error is raised when only binary classification is supported.""" X, y = data + X_binary, y_binary = data_binary clf = DecisionTreeClassifier().fit(X, y) + # Case 1: multiclass classifier with multiclass target msg = "Expected 'estimator' to be a binary classifier. Got 3 classes instead." with pytest.raises(ValueError, match=msg): Display.from_estimator(clf, X, y) + # Case 2: multiclass classifier with binary target + with pytest.raises(ValueError, match=msg): + Display.from_estimator(clf, X_binary, y_binary) + + # Case 3: binary classifier with multiclass target + clf = DecisionTreeClassifier().fit(X_binary, y_binary) + msg = "The target y is not binary. Got multiclass type of target." + with pytest.raises(ValueError, match=msg): + Display.from_estimator(clf, X, y) + + +@pytest.mark.parametrize( + "Display", + [CalibrationDisplay, DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay], +) +def test_display_curve_error_regression(pyplot, data_binary, Display): + """Check that we raise an error with regressor.""" + + # Case 1: regressor + X, y = data_binary + regressor = DecisionTreeRegressor().fit(X, y) + + msg = "Expected 'estimator' to be a binary classifier. Got DecisionTreeRegressor" + with pytest.raises(ValueError, match=msg): + Display.from_estimator(regressor, X, y) + + # Case 2: regression target + classifier = DecisionTreeClassifier().fit(X, y) + # Force `y_true` to be seen as a regression problem + y = y + 0.5 + msg = "The target y is not binary. Got continuous type of target." + with pytest.raises(ValueError, match=msg): + Display.from_estimator(classifier, X, y) + with pytest.raises(ValueError, match=msg): + Display.from_predictions(y, regressor.fit(X, y).predict(X)) + @pytest.mark.parametrize( "response_method, msg", @@ -148,3 +189,36 @@ def test_display_curve_not_fitted_errors(pyplot, data_binary, clf, Display): disp = Display.from_estimator(model, X, y) assert model.__class__.__name__ in disp.line_.get_label() assert disp.estimator_name == model.__class__.__name__ + + +@pytest.mark.parametrize( + "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay] +) +def test_display_curve_n_samples_consistency(pyplot, data_binary, Display): + """Check the error raised when `y_pred` or `sample_weight` have inconsistent + length.""" + X, y = data_binary + classifier = DecisionTreeClassifier().fit(X, y) + + msg = "Found input variables with inconsistent numbers of samples" + with pytest.raises(ValueError, match=msg): + Display.from_estimator(classifier, X[:-2], y) + with pytest.raises(ValueError, match=msg): + Display.from_estimator(classifier, X, y[:-2]) + with pytest.raises(ValueError, match=msg): + Display.from_estimator(classifier, X, y, sample_weight=np.ones(X.shape[0] - 2)) + + +@pytest.mark.parametrize( + "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay] +) +def test_display_curve_error_pos_label(pyplot, data_binary, Display): + """Check consistence of error message when `pos_label` should be specified.""" + X, y = data_binary + y = y + 10 + + classifier = DecisionTreeClassifier().fit(X, y) + y_pred = classifier.predict_proba(X)[:, -1] + msg = r"y_true takes value in {10, 11} and pos_label is not specified" + with pytest.raises(ValueError, match=msg): + Display.from_predictions(y, y_pred) diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py index 
e7e1917c79776..7d963d8b87f0a 100644 --- a/sklearn/metrics/_plot/tests/test_precision_recall_display.py +++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py @@ -9,7 +9,6 @@ from sklearn.model_selection import train_test_split from sklearn.pipeline import make_pipeline from sklearn.preprocessing import StandardScaler -from sklearn.svm import SVC, SVR from sklearn.utils import shuffle from sklearn.metrics import PrecisionRecallDisplay @@ -21,48 +20,6 @@ ) -def test_precision_recall_display_validation(pyplot): - """Check that we raise the proper error when validating parameters.""" - X, y = make_classification( - n_samples=100, n_informative=5, n_classes=5, random_state=0 - ) - - with pytest.raises(NotFittedError): - PrecisionRecallDisplay.from_estimator(SVC(), X, y) - - regressor = SVR().fit(X, y) - y_pred_regressor = regressor.predict(X) - classifier = SVC(probability=True).fit(X, y) - y_pred_classifier = classifier.predict_proba(X)[:, -1] - - err_msg = "Expected 'estimator' to be a binary classifier. Got SVR instead." - with pytest.raises(ValueError, match=err_msg): - PrecisionRecallDisplay.from_estimator(regressor, X, y) - - err_msg = "Expected 'estimator' to be a binary classifier." - with pytest.raises(ValueError, match=err_msg): - PrecisionRecallDisplay.from_estimator(classifier, X, y) - - err_msg = "{} format is not supported" - with pytest.raises(ValueError, match=err_msg.format("continuous")): - # Force `y_true` to be seen as a regression problem - PrecisionRecallDisplay.from_predictions(y + 0.5, y_pred_classifier, pos_label=1) - with pytest.raises(ValueError, match=err_msg.format("multiclass")): - PrecisionRecallDisplay.from_predictions(y, y_pred_regressor, pos_label=1) - - err_msg = "Found input variables with inconsistent numbers of samples" - with pytest.raises(ValueError, match=err_msg): - PrecisionRecallDisplay.from_predictions(y, y_pred_classifier[::2]) - - X, y = make_classification(n_classes=2, n_samples=50, random_state=0) - y += 10 - classifier.fit(X, y) - y_pred_classifier = classifier.predict_proba(X)[:, -1] - err_msg = r"y_true takes value in {10, 11} and pos_label is not specified" - with pytest.raises(ValueError, match=err_msg): - PrecisionRecallDisplay.from_predictions(y, y_pred_classifier) - - @pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) @pytest.mark.parametrize("drop_intermediate", [True, False]) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index 6693a6391919e..c45994d110b97 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -29,7 +29,7 @@ from ..utils import assert_all_finite from ..utils import check_consistent_length -from ..utils.validation import _check_sample_weight +from ..utils.validation import _check_pos_label_consistency, _check_sample_weight from ..utils import column_or_1d, check_array from ..utils.multiclass import type_of_target from ..utils.extmath import stable_cumsum @@ -39,11 +39,7 @@ from ..preprocessing import label_binarize from ..utils._encode import _encode, _unique -from ._base import ( - _average_binary_score, - _average_multiclass_ovo_score, - _check_pos_label_consistency, -) +from ._base import _average_binary_score, _average_multiclass_ovo_score @validate_params({"x": ["array-like"], "y": ["array-like"]}) diff --git a/sklearn/tests/test_calibration.py b/sklearn/tests/test_calibration.py index fff774c3fc490..679ec6ed52c15 100644 --- 
a/sklearn/tests/test_calibration.py +++ b/sklearn/tests/test_calibration.py @@ -25,7 +25,7 @@ RandomForestClassifier, VotingClassifier, ) -from sklearn.linear_model import LogisticRegression, LinearRegression +from sklearn.linear_model import LogisticRegression from sklearn.tree import DecisionTreeClassifier from sklearn.svm import LinearSVC from sklearn.pipeline import Pipeline, make_pipeline @@ -595,42 +595,6 @@ def iris_data_binary(iris_data): return X[y < 2], y[y < 2] -def test_calibration_display_validation(pyplot, iris_data, iris_data_binary): - X, y = iris_data - X_binary, y_binary = iris_data_binary - - reg = LinearRegression().fit(X, y) - msg = "Expected 'estimator' to be a binary classifier. Got LinearRegression" - with pytest.raises(ValueError, match=msg): - CalibrationDisplay.from_estimator(reg, X, y) - - clf = LinearSVC().fit(X_binary, y_binary) - msg = "has none of the following attributes: predict_proba." - with pytest.raises(AttributeError, match=msg): - CalibrationDisplay.from_estimator(clf, X, y) - - clf = LogisticRegression() - with pytest.raises(NotFittedError): - CalibrationDisplay.from_estimator(clf, X, y) - - -@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) -def test_calibration_display_non_binary(pyplot, iris_data, constructor_name): - X, y = iris_data - clf = DecisionTreeClassifier() - clf.fit(X, y) - y_prob = clf.predict_proba(X) - - if constructor_name == "from_estimator": - msg = "to be a binary classifier. Got 3 classes instead." - with pytest.raises(ValueError, match=msg): - CalibrationDisplay.from_estimator(clf, X, y) - else: - msg = "The target y is not binary. Got multiclass type of target." - with pytest.raises(ValueError, match=msg): - CalibrationDisplay.from_predictions(y, y_prob) - - @pytest.mark.parametrize("n_bins", [5, 10]) @pytest.mark.parametrize("strategy", ["uniform", "quantile"]) def test_calibration_display_compute(pyplot, iris_data_binary, n_bins, strategy): diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py new file mode 100644 index 0000000000000..cc301b509e386 --- /dev/null +++ b/sklearn/utils/_plotting.py @@ -0,0 +1,58 @@ +from . import check_consistent_length, check_matplotlib_support +from .multiclass import type_of_target +from .validation import _check_pos_label_consistency +from ._response import _get_response_values_binary + + +class _BinaryClassifierCurveDisplayMixin: + """Mixin class to be used in Displays requiring a binary classifier. + + The aim of this class is to centralize some validations regarding the estimator and + the target and gather the response of the estimator. 
+    """
+
+    def _validate_plot_params(self, *, ax=None, name=None):
+        check_matplotlib_support(f"{self.__class__.__name__}.plot")
+        import matplotlib.pyplot as plt
+
+        if ax is None:
+            _, ax = plt.subplots()
+
+        name = self.estimator_name if name is None else name
+        return ax, ax.figure, name
+
+    @classmethod
+    def _validate_and_get_response_values(
+        cls, estimator, X, y, *, response_method="auto", pos_label=None, name=None
+    ):
+        check_matplotlib_support(f"{cls.__name__}.from_estimator")
+
+        name = estimator.__class__.__name__ if name is None else name
+
+        y_pred, pos_label = _get_response_values_binary(
+            estimator,
+            X,
+            response_method=response_method,
+            pos_label=pos_label,
+        )
+
+        return y_pred, pos_label, name
+
+    @classmethod
+    def _validate_from_predictions_params(
+        cls, y_true, y_pred, *, sample_weight=None, pos_label=None, name=None
+    ):
+        check_matplotlib_support(f"{cls.__name__}.from_predictions")
+
+        if type_of_target(y_true) != "binary":
+            raise ValueError(
+                f"The target y is not binary. Got {type_of_target(y_true)} type of"
+                " target."
+            )
+
+        check_consistent_length(y_true, y_pred, sample_weight)
+        pos_label = _check_pos_label_consistency(pos_label, y_true)
+
+        name = name if name is not None else "Classifier"
+
+        return pos_label, name
diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py
index 2ae5c6a42d172..dd9693ed9e1ae 100644
--- a/sklearn/utils/validation.py
+++ b/sklearn/utils/validation.py
@@ -2145,3 +2145,55 @@ def _check_monotonic_cst(estimator, monotonic_cst=None):
                 f"X has {estimator.n_features_in_} features."
             )
     return monotonic_cst
+
+
+def _check_pos_label_consistency(pos_label, y_true):
+    """Check if `pos_label` need to be specified or not.
+
+    In binary classification, we fix `pos_label=1` if the labels are in the set
+    {-1, 1} or {0, 1}. Otherwise, we raise an error asking to specify the
+    `pos_label` parameters.
+
+    Parameters
+    ----------
+    pos_label : int, str or None
+        The positive label.
+    y_true : ndarray of shape (n_samples,)
+        The target vector.
+
+    Returns
+    -------
+    pos_label : int
+        If `pos_label` can be inferred, it will be returned.
+
+    Raises
+    ------
+    ValueError
+        In the case that `y_true` does not have label in {-1, 1} or {0, 1},
+        it will raise a `ValueError`.
+    """
+    # ensure binary classification if pos_label is not specified
+    # classes.dtype.kind in ('O', 'U', 'S') is required to avoid
+    # triggering a FutureWarning by calling np.array_equal(a, b)
+    # when elements in the two arrays are not comparable.
+    classes = np.unique(y_true)
+    if pos_label is None and (
+        classes.dtype.kind in "OUS"
+        or not (
+            np.array_equal(classes, [0, 1])
+            or np.array_equal(classes, [-1, 1])
+            or np.array_equal(classes, [0])
+            or np.array_equal(classes, [-1])
+            or np.array_equal(classes, [1])
+        )
+    ):
+        classes_repr = ", ".join(repr(c) for c in classes)
+        raise ValueError(
+            f"y_true takes value in {{{classes_repr}}} and pos_label is not "
+            "specified: either make y_true take value in {0, 1} or "
+            "{-1, 1} or pass pos_label explicitly."
+        )
+    elif pos_label is None:
+        pos_label = 1
+
+    return pos_label
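
Note (not part of the patch): a minimal sketch of how a Display class is expected to consume the new `_BinaryClassifierCurveDisplayMixin` helpers, under the assumption that the import path `sklearn.utils._plotting` from this diff is available. `_MyCurveDisplay`, its attributes, and the plotted values are hypothetical and only illustrate the call pattern; the helper names and return values follow the mixin added above.

```python
# Illustrative sketch only -- `_MyCurveDisplay` is hypothetical; the helpers come
# from the _BinaryClassifierCurveDisplayMixin introduced in sklearn/utils/_plotting.py.
from sklearn.utils._plotting import _BinaryClassifierCurveDisplayMixin


class _MyCurveDisplay(_BinaryClassifierCurveDisplayMixin):
    def __init__(self, *, curve_values, estimator_name=None, pos_label=None):
        self.curve_values = curve_values
        self.estimator_name = estimator_name
        self.pos_label = pos_label

    def plot(self, ax=None, *, name=None, **kwargs):
        # Checks matplotlib availability, creates an Axes if needed, resolves the
        # legend name, and stores ax_/figure_ on the instance.
        self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name)
        (self.line_,) = self.ax_.plot(self.curve_values, label=name, **kwargs)
        return self

    @classmethod
    def from_predictions(
        cls, y_true, y_pred, *, sample_weight=None, pos_label=None, name=None, ax=None
    ):
        # Verifies that y_true is binary, that lengths are consistent, and that
        # pos_label can be inferred; falls back to name="Classifier".
        pos_label, name = cls._validate_from_predictions_params(
            y_true, y_pred, sample_weight=sample_weight, pos_label=pos_label, name=name
        )
        viz = cls(curve_values=y_pred, estimator_name=name, pos_label=pos_label)
        return viz.plot(ax=ax, name=name)
```

`from_estimator` would follow the same shape, calling `cls._validate_and_get_response_values(estimator, X, y, response_method=..., pos_label=..., name=...)` and then delegating to `from_predictions`, as the four refactored displays above do.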
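A small hedged example of the `_check_pos_label_consistency` behavior that the patch relocates to `sklearn/utils/validation.py` (same logic, new import path; before this change the helper lived in `sklearn.metrics._base`):

```python
import numpy as np
from sklearn.utils.validation import _check_pos_label_consistency

# Labels in {0, 1} (or {-1, 1}): pos_label can be inferred and defaults to 1.
assert _check_pos_label_consistency(None, np.array([0, 1, 1, 0])) == 1

# Any other label set requires an explicit pos_label, otherwise a ValueError is raised.
try:
    _check_pos_label_consistency(None, np.array([10, 11, 10]))
except ValueError as exc:
    print(exc)  # "y_true takes value in {10, 11} and pos_label is not specified: ..."

# Passing pos_label explicitly is always accepted and returned unchanged.
assert _check_pos_label_consistency(11, np.array([10, 11, 10])) == 11
```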