Commit 8d7c244

Merge branch 'scikit-learn:main' into update-scikit-learn

2 parents: b6b2d72 + ba46b65
30 files changed: +526 -346 lines

‎doc/modules/classes.rst

Lines changed: 1 addition & 1 deletion

@@ -1122,7 +1122,7 @@ See the :ref:`visualizations` section of the user guide for further details.
 
 .. autosummary::
    :toctree: generated/
-   :template: display.rst
+   :template: display_all_class_methods.rst
 
    metrics.ConfusionMatrixDisplay
    metrics.DetCurveDisplay

‎sklearn/calibration.py

Lines changed: 23 additions & 36 deletions

@@ -30,16 +30,16 @@
 from .utils import (
     column_or_1d,
     indexable,
-    check_matplotlib_support,
     _safe_indexing,
 )
-from .utils._response import _get_response_values_binary
 
-from .utils.multiclass import check_classification_targets, type_of_target
+from .utils.multiclass import check_classification_targets
 from .utils.parallel import delayed, Parallel
 from .utils._param_validation import StrOptions, HasMethods, Hidden
+from .utils._plotting import _BinaryClassifierCurveDisplayMixin
 from .utils.validation import (
     _check_fit_params,
+    _check_pos_label_consistency,
     _check_sample_weight,
     _num_samples,
     check_consistent_length,
@@ -48,7 +48,6 @@
 from .isotonic import IsotonicRegression
 from .svm import LinearSVC
 from .model_selection import check_cv, cross_val_predict
-from .metrics._base import _check_pos_label_consistency
 
 
 class CalibratedClassifierCV(ClassifierMixin, MetaEstimatorMixin, BaseEstimator):
@@ -1013,7 +1012,7 @@ def calibration_curve(
     return prob_true, prob_pred
 
 
-class CalibrationDisplay:
+class CalibrationDisplay(_BinaryClassifierCurveDisplayMixin):
     """Calibration curve (also known as reliability diagram) visualization.
 
     It is recommended to use
@@ -1124,13 +1123,8 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
         display : :class:`~sklearn.calibration.CalibrationDisplay`
             Object that stores computed values.
         """
-        check_matplotlib_support("CalibrationDisplay.plot")
-        import matplotlib.pyplot as plt
+        self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name)
 
-        if ax is None:
-            fig, ax = plt.subplots()
-
-        name = self.estimator_name if name is None else name
         info_pos_label = (
             f"(Positive class: {self.pos_label})" if self.pos_label is not None else ""
         )
@@ -1141,20 +1135,20 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
         line_kwargs.update(**kwargs)
 
         ref_line_label = "Perfectly calibrated"
-        existing_ref_line = ref_line_label in ax.get_legend_handles_labels()[1]
+        existing_ref_line = ref_line_label in self.ax_.get_legend_handles_labels()[1]
         if ref_line and not existing_ref_line:
-            ax.plot([0, 1], [0, 1], "k:", label=ref_line_label)
-        self.line_ = ax.plot(self.prob_pred, self.prob_true, "s-", **line_kwargs)[0]
+            self.ax_.plot([0, 1], [0, 1], "k:", label=ref_line_label)
+        self.line_ = self.ax_.plot(self.prob_pred, self.prob_true, "s-", **line_kwargs)[
+            0
+        ]
 
         # We always have to show the legend for at least the reference line
-        ax.legend(loc="lower right")
+        self.ax_.legend(loc="lower right")
 
         xlabel = f"Mean predicted probability {info_pos_label}"
         ylabel = f"Fraction of positives {info_pos_label}"
-        ax.set(xlabel=xlabel, ylabel=ylabel)
+        self.ax_.set(xlabel=xlabel, ylabel=ylabel)
 
-        self.ax_ = ax
-        self.figure_ = ax.figure
         return self
 
     @classmethod
@@ -1260,15 +1254,15 @@ def from_estimator(
         >>> disp = CalibrationDisplay.from_estimator(clf, X_test, y_test)
         >>> plt.show()
         """
-        method_name = f"{cls.__name__}.from_estimator"
-        check_matplotlib_support(method_name)
-
-        check_is_fitted(estimator)
-        y_prob, pos_label = _get_response_values_binary(
-            estimator, X, response_method="predict_proba", pos_label=pos_label
+        y_prob, pos_label, name = cls._validate_and_get_response_values(
+            estimator,
+            X,
+            y,
+            response_method="predict_proba",
+            pos_label=pos_label,
+            name=name,
         )
 
-        name = name if name is not None else estimator.__class__.__name__
        return cls.from_predictions(
             y,
             y_prob,
@@ -1378,26 +1372,19 @@ def from_predictions(
         >>> disp = CalibrationDisplay.from_predictions(y_test, y_prob)
         >>> plt.show()
         """
-        method_name = f"{cls.__name__}.from_predictions"
-        check_matplotlib_support(method_name)
-
-        target_type = type_of_target(y_true)
-        if target_type != "binary":
-            raise ValueError(
-                f"The target y is not binary. Got {target_type} type of target."
-            )
+        pos_label_validated, name = cls._validate_from_predictions_params(
+            y_true, y_prob, sample_weight=None, pos_label=pos_label, name=name
+        )
 
         prob_true, prob_pred = calibration_curve(
             y_true, y_prob, n_bins=n_bins, strategy=strategy, pos_label=pos_label
         )
-        name = "Classifier" if name is None else name
-        pos_label = _check_pos_label_consistency(pos_label, y_true)
 
         disp = cls(
             prob_true=prob_true,
             prob_pred=prob_pred,
             y_prob=y_prob,
             estimator_name=name,
-            pos_label=pos_label,
+            pos_label=pos_label_validated,
         )
         return disp.plot(ax=ax, ref_line=ref_line, **kwargs)
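
The display classes in this commit delegate their boilerplate to the new _BinaryClassifierCurveDisplayMixin from sklearn/utils/_plotting.py. The mixin itself is not part of this diff, but the call sites above pin down its interface: _validate_plot_params returns (ax, figure, name), _validate_and_get_response_values returns (y_pred, pos_label, name), and _validate_from_predictions_params returns (pos_label, name). A rough sketch of what those helpers plausibly do, inferred from the removed code and the new calls rather than copied from the actual implementation:

```python
# Illustrative stand-in only; the real mixin lives in sklearn/utils/_plotting.py.
import matplotlib.pyplot as plt

from sklearn.utils import check_matplotlib_support
from sklearn.utils._response import _get_response_values_binary
from sklearn.utils.multiclass import type_of_target
from sklearn.utils.validation import _check_pos_label_consistency  # new home per this commit


class BinaryCurveDisplayMixinSketch:
    """Hypothetical sketch of the shared helpers used by the display classes."""

    def _validate_plot_params(self, *, ax=None, name=None):
        # Replaces the per-class check_matplotlib_support / plt.subplots boilerplate.
        check_matplotlib_support(f"{self.__class__.__name__}.plot")
        if ax is None:
            _, ax = plt.subplots()
        name = self.estimator_name if name is None else name
        return ax, ax.figure, name

    @classmethod
    def _validate_and_get_response_values(
        cls, estimator, X, y, *, response_method, pos_label=None, name=None
    ):
        # Replaces the name defaulting plus the _get_response_values_binary call.
        check_matplotlib_support(f"{cls.__name__}.from_estimator")
        name = estimator.__class__.__name__ if name is None else name
        y_pred, pos_label = _get_response_values_binary(
            estimator, X, response_method=response_method, pos_label=pos_label
        )
        return y_pred, pos_label, name

    @classmethod
    def _validate_from_predictions_params(
        cls, y_true, y_pred, *, sample_weight=None, pos_label=None, name=None
    ):
        # Replaces the binary-target check and pos_label defaulting.
        check_matplotlib_support(f"{cls.__name__}.from_predictions")
        if type_of_target(y_true) != "binary":
            raise ValueError("Expected binary targets.")
        pos_label = _check_pos_label_consistency(pos_label, y_true)
        name = "Classifier" if name is None else name
        return pos_label, name
```

The public API is unchanged: CalibrationDisplay.from_estimator(clf, X_test, y_test) and CalibrationDisplay.from_predictions(y_test, y_prob) behave as in the docstring examples above.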

‎sklearn/compose/_column_transformer.py

Lines changed: 2 additions & 0 deletions

@@ -936,6 +936,8 @@ def _get_transformer_list(estimators):
     return transformer_list
 
 
+# This function is not validated using validate_params because
+# it's just a factory for ColumnTransformer.
 def make_column_transformer(
     *transformers,
     remainder="drop",
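
For context, make_column_transformer only forwards its arguments to ColumnTransformer, which runs the usual parameter validation at fit time; the comment added above records why the factory itself is left out of validate_params. A small usage sketch (column names are illustrative, not from this commit):

```python
from sklearn.compose import make_column_transformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# The factory builds a ColumnTransformer with auto-generated transformer names;
# its parameters are validated when the resulting ColumnTransformer is fit.
preprocess = make_column_transformer(
    (StandardScaler(), ["age", "income"]),  # numeric columns (illustrative names)
    (OneHotEncoder(), ["city"]),            # categorical column
    remainder="drop",
)
```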

‎sklearn/discriminant_analysis.py

Lines changed: 1 addition & 1 deletion

@@ -640,7 +640,7 @@ def fit(self, X, y):
             intercept_ = xp.asarray(
                 self.intercept_[1] - self.intercept_[0], dtype=X.dtype
             )
-            self.intercept_ = xp.reshape(intercept_, 1)
+            self.intercept_ = xp.reshape(intercept_, (1,))
         self._n_features_out = self._max_components
         return self
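
The change looks cosmetic but it is an array API portability fix: the array API standard's reshape takes the shape as a tuple, whereas NumPy also tolerates a bare integer. The same fix is applied to decision_function in sklearn/linear_model/_base.py further down. A quick illustration using NumPy as the namespace (here xp stands in for whichever array namespace is active):

```python
import numpy as np

xp = np  # in the patched code, xp is the active array-API-compatible namespace
intercept_ = xp.asarray([0.5])

a = xp.reshape(intercept_, (1,))  # tuple shape: valid under the array API standard
b = np.reshape(intercept_, 1)     # bare int: a NumPy-only convenience
assert a.shape == b.shape == (1,)
```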

‎sklearn/impute/_base.py

Lines changed: 3 additions & 3 deletions

@@ -11,7 +11,7 @@
 from scipy import sparse as sp
 
 from ..base import BaseEstimator, TransformerMixin
-from ..utils._param_validation import StrOptions, Hidden
+from ..utils._param_validation import StrOptions, Hidden, MissingValues
 from ..utils.fixes import _mode
 from ..utils.sparsefuncs import _get_median
 from ..utils.validation import check_is_fitted
@@ -78,7 +78,7 @@ class _BaseImputer(TransformerMixin, BaseEstimator):
     """
 
     _parameter_constraints: dict = {
-        "missing_values": ["missing_values"],
+        "missing_values": [MissingValues()],
         "add_indicator": ["boolean"],
         "keep_empty_features": ["boolean"],
     }
@@ -800,7 +800,7 @@ class MissingIndicator(TransformerMixin, BaseEstimator):
     """
 
     _parameter_constraints: dict = {
-        "missing_values": [numbers.Real, numbers.Integral, str, None],
+        "missing_values": [MissingValues()],
         "features": [StrOptions({"missing-only", "all"})],
         "sparse": ["boolean", StrOptions({"auto"})],
         "error_on_new": ["boolean"],

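Both imputation classes now share one MissingValues() constraint instead of the two ad-hoc lists shown above. Assuming the new constraint admits the same sentinels the old lists did, the usual spellings keep working; the values below are illustrative:

```python
import numpy as np
from sklearn.impute import MissingIndicator, SimpleImputer

# Typical `missing_values` sentinels the shared constraint is meant to cover.
SimpleImputer(missing_values=np.nan)                               # default NaN marker
SimpleImputer(missing_values=-1)                                   # integer sentinel
SimpleImputer(missing_values="missing", strategy="most_frequent")  # string sentinel
MissingIndicator(missing_values=None)                              # None, allowed by the old list
```
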
‎sklearn/linear_model/_base.py

Lines changed: 1 addition & 1 deletion

@@ -399,7 +399,7 @@ def decision_function(self, X):
 
         X = self._validate_data(X, accept_sparse="csr", reset=False)
         scores = safe_sparse_dot(X, self.coef_.T, dense_output=True) + self.intercept_
-        return xp.reshape(scores, -1) if scores.shape[1] == 1 else scores
+        return xp.reshape(scores, (-1,)) if scores.shape[1] == 1 else scores
 
     def predict(self, X):
         """

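Same tuple-shape portability fix as in discriminant_analysis.py above, applied to the branch that flattens the single-column score matrix of binary classifiers. For reference, the observable behaviour is unchanged (estimator and data are illustrative):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0, 0, 1, 1])
clf = LogisticRegression().fit(X, y)

# With two classes coef_ has a single row, so the (n_samples, 1) score matrix
# is reshaped with (-1,) into a flat (n_samples,) array.
scores = clf.decision_function(X)
assert scores.shape == (4,)
```
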
‎sklearn/metrics/_base.py

Lines changed: 0 additions & 52 deletions

@@ -197,55 +197,3 @@ def _average_multiclass_ovo_score(binary_metric, y_true, y_score, average="macro
             pair_scores[ix] = (a_true_score + b_true_score) / 2
 
     return np.average(pair_scores, weights=prevalence)
-
-
-def _check_pos_label_consistency(pos_label, y_true):
-    """Check if `pos_label` need to be specified or not.
-
-    In binary classification, we fix `pos_label=1` if the labels are in the set
-    {-1, 1} or {0, 1}. Otherwise, we raise an error asking to specify the
-    `pos_label` parameters.
-
-    Parameters
-    ----------
-    pos_label : int, str or None
-        The positive label.
-    y_true : ndarray of shape (n_samples,)
-        The target vector.
-
-    Returns
-    -------
-    pos_label : int
-        If `pos_label` can be inferred, it will be returned.
-
-    Raises
-    ------
-    ValueError
-        In the case that `y_true` does not have label in {-1, 1} or {0, 1},
-        it will raise a `ValueError`.
-    """
-    # ensure binary classification if pos_label is not specified
-    # classes.dtype.kind in ('O', 'U', 'S') is required to avoid
-    # triggering a FutureWarning by calling np.array_equal(a, b)
-    # when elements in the two arrays are not comparable.
-    classes = np.unique(y_true)
-    if pos_label is None and (
-        classes.dtype.kind in "OUS"
-        or not (
-            np.array_equal(classes, [0, 1])
-            or np.array_equal(classes, [-1, 1])
-            or np.array_equal(classes, [0])
-            or np.array_equal(classes, [-1])
-            or np.array_equal(classes, [1])
-        )
-    ):
-        classes_repr = ", ".join(repr(c) for c in classes)
-        raise ValueError(
-            f"y_true takes value in {{{classes_repr}}} and pos_label is not "
-            "specified: either make y_true take value in {0, 1} or "
-            "{-1, 1} or pass pos_label explicitly."
-        )
-    elif pos_label is None:
-        pos_label = 1
-
-    return pos_label
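
The helper is moved rather than dropped: as the _classification.py hunk below shows, it is now imported from sklearn.utils.validation. Its behaviour, summarised from the docstring above (a private helper, shown here only for illustration):

```python
import numpy as np
from sklearn.utils.validation import _check_pos_label_consistency  # new home after this change

y = np.array([0, 1, 1, 0])
assert _check_pos_label_consistency(None, y) == 1  # labels in {0, 1}: defaults to 1

y_str = np.array(["ham", "spam"])
assert _check_pos_label_consistency("spam", y_str) == "spam"  # explicit label passes through
# _check_pos_label_consistency(None, y_str) raises ValueError: the positive
# class cannot be inferred from string labels.
```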

‎sklearn/metrics/_classification.py

Lines changed: 1 addition & 3 deletions

@@ -40,13 +40,11 @@
 from ..utils.extmath import _nanaverage
 from ..utils.multiclass import unique_labels
 from ..utils.multiclass import type_of_target
-from ..utils.validation import _num_samples
+from ..utils.validation import _check_pos_label_consistency, _num_samples
 from ..utils.sparsefuncs import count_nonzero
 from ..utils._param_validation import StrOptions, Options, Interval, validate_params
 from ..exceptions import UndefinedMetricWarning
 
-from ._base import _check_pos_label_consistency
-
 
 def _check_zero_division(zero_division):
     if isinstance(zero_division, str) and zero_division == "warn":

‎sklearn/metrics/_plot/det_curve.py

Lines changed: 22 additions & 35 deletions

@@ -1,13 +1,10 @@
 import scipy as sp
 
 from .. import det_curve
-from .._base import _check_pos_label_consistency
+from ...utils._plotting import _BinaryClassifierCurveDisplayMixin
 
-from ...utils import check_matplotlib_support
-from ...utils._response import _get_response_values_binary
 
-
-class DetCurveDisplay:
+class DetCurveDisplay(_BinaryClassifierCurveDisplayMixin):
     """DET curve visualization.
 
     It is recommend to use :func:`~sklearn.metrics.DetCurveDisplay.from_estimator`
@@ -163,15 +160,13 @@ from_estimator(
         <...>
         >>> plt.show()
         """
-        check_matplotlib_support(f"{cls.__name__}.from_estimator")
-
-        name = estimator.__class__.__name__ if name is None else name
-
-        y_pred, pos_label = _get_response_values_binary(
+        y_pred, pos_label, name = cls._validate_and_get_response_values(
             estimator,
             X,
-            response_method,
+            y,
+            response_method=response_method,
             pos_label=pos_label,
+            name=name,
         )
 
         return cls.from_predictions(
@@ -259,22 +254,22 @@ from_predictions(
         <...>
         >>> plt.show()
         """
-        check_matplotlib_support(f"{cls.__name__}.from_predictions")
+        pos_label_validated, name = cls._validate_from_predictions_params(
+            y_true, y_pred, sample_weight=sample_weight, pos_label=pos_label, name=name
+        )
+
         fpr, fnr, _ = det_curve(
             y_true,
             y_pred,
             pos_label=pos_label,
             sample_weight=sample_weight,
         )
 
-        pos_label = _check_pos_label_consistency(pos_label, y_true)
-        name = "Classifier" if name is None else name
-
         viz = DetCurveDisplay(
             fpr=fpr,
             fnr=fnr,
             estimator_name=name,
-            pos_label=pos_label,
+            pos_label=pos_label_validated,
         )
 
         return viz.plot(ax=ax, name=name, **kwargs)
@@ -300,18 +295,12 @@ def plot(self, ax=None, *, name=None, **kwargs):
         display : :class:`~sklearn.metrics.plot.DetCurveDisplay`
             Object that stores computed values.
         """
-        check_matplotlib_support("DetCurveDisplay.plot")
+        self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name)
 
-        name = self.estimator_name if name is None else name
         line_kwargs = {} if name is None else {"label": name}
         line_kwargs.update(**kwargs)
 
-        import matplotlib.pyplot as plt
-
-        if ax is None:
-            _, ax = plt.subplots()
-
-        (self.line_,) = ax.plot(
+        (self.line_,) = self.ax_.plot(
             sp.stats.norm.ppf(self.fpr),
             sp.stats.norm.ppf(self.fnr),
             **line_kwargs,
@@ -322,24 +311,22 @@ def plot(self, ax=None, *, name=None, **kwargs):
 
         xlabel = "False Positive Rate" + info_pos_label
         ylabel = "False Negative Rate" + info_pos_label
-        ax.set(xlabel=xlabel, ylabel=ylabel)
+        self.ax_.set(xlabel=xlabel, ylabel=ylabel)
 
         if "label" in line_kwargs:
-            ax.legend(loc="lower right")
+            self.ax_.legend(loc="lower right")
 
         ticks = [0.001, 0.01, 0.05, 0.20, 0.5, 0.80, 0.95, 0.99, 0.999]
         tick_locations = sp.stats.norm.ppf(ticks)
         tick_labels = [
             "{:.0%}".format(s) if (100 * s).is_integer() else "{:.1%}".format(s)
             for s in ticks
         ]
-        ax.set_xticks(tick_locations)
-        ax.set_xticklabels(tick_labels)
-        ax.set_xlim(-3, 3)
-        ax.set_yticks(tick_locations)
-        ax.set_yticklabels(tick_labels)
-        ax.set_ylim(-3, 3)
-
-        self.ax_ = ax
-        self.figure_ = ax.figure
+        self.ax_.set_xticks(tick_locations)
+        self.ax_.set_xticklabels(tick_labels)
+        self.ax_.set_xlim(-3, 3)
+        self.ax_.set_yticks(tick_locations)
+        self.ax_.set_yticklabels(tick_labels)
+        self.ax_.set_ylim(-3, 3)
 
         return self