Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 5692e59

Browse filesBrowse files
betatimogriselOmarManzoor
authored
ENH Add Array API compatibility tests for *SearchCV classes (#27096)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
1 parent 4ee1e14 commit 5692e59
Copy full SHA for 5692e59

File tree

8 files changed

+69
-4
lines changed
Filter options

8 files changed

+69
-4
lines changed

‎doc/modules/array_api.rst

Copy file name to clipboardExpand all lines: doc/modules/array_api.rst
+11Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,17 @@ Estimators
9898
- :class:`preprocessing.MinMaxScaler`
9999
- :class:`preprocessing.Normalizer`
100100

101+
Meta-estimators
102+
---------------
103+
104+
Meta-estimators that accept Array API inputs conditioned on the fact that the
105+
base estimator also does:
106+
107+
- :class:`model_selection.GridSearchCV`
108+
- :class:`model_selection.RandomizedSearchCV`
109+
- :class:`model_selection.HalvingGridSearchCV`
110+
- :class:`model_selection.HalvingRandomSearchCV`
111+
101112
Metrics
102113
-------
103114

‎doc/whats_new/v1.6.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.6.rst
+6Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ See :ref:`array_api` for more details.
4343

4444
- :class:`preprocessing.LabelEncoder` now supports Array API compatible inputs.
4545
:pr:`27381` by :user:`Omar Salman <OmarManzoor>`.
46+
- :class:`model_selection.GridSearchCV`,
47+
:class:`model_selection.RandomizedSearchCV`,
48+
:class:`model_selection.HalvingGridSearchCV` and
49+
:class:`model_selection.HalvingRandomSearchCV` now support Array API
50+
compatible inputs when their base estimators do. :pr:`27096` by :user:`Tim
51+
Head <betatim>` and :user:`Olivier Grisel <ogrisel>`.
4652

4753
Metadata Routing
4854
----------------

‎sklearn/model_selection/_search.py

Copy file name to clipboardExpand all lines: sklearn/model_selection/_search.py
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,7 @@ def _more_tags(self):
440440
"_xfail_checks": {
441441
"check_supervised_y_2d": "DataConversionWarning not caught"
442442
},
443+
"array_api_support": _safe_tags(self.estimator, "array_api_support"),
443444
}
444445

445446
def score(self, X, y=None, **params):

‎sklearn/model_selection/_split.py

Copy file name to clipboardExpand all lines: sklearn/model_selection/_split.py
+9-1Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,15 @@ def __init__(self, n_splits=5, *, shuffle=False, random_state=None):
745745

746746
def _make_test_folds(self, X, y=None):
747747
rng = check_random_state(self.random_state)
748-
y = np.asarray(y)
748+
# XXX: as of now, cross-validation splitters only operate in NumPy-land
749+
# without attempting to leverage array API namespace features. However
750+
# they might be fed by array API inputs, e.g. in CV-enabled estimators so
751+
# we need the following explicit conversion:
752+
xp, is_array_api = get_namespace(y)
753+
if is_array_api:
754+
y = _convert_to_numpy(y, xp)
755+
else:
756+
y = np.asarray(y)
749757
type_of_target_y = type_of_target(y)
750758
allowed_target_types = ("binary", "multiclass")
751759
if type_of_target_y not in allowed_target_types:

‎sklearn/model_selection/_validation.py

Copy file name to clipboardExpand all lines: sklearn/model_selection/_validation.py
+8Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from ..metrics._scorer import _MultimetricScorer
3131
from ..preprocessing import LabelEncoder
3232
from ..utils import Bunch, _safe_indexing, check_random_state, indexable
33+
from ..utils._array_api import device, get_namespace
3334
from ..utils._param_validation import (
3435
HasMethods,
3536
Integral,
@@ -830,6 +831,13 @@ def _fit_and_score(
830831
fit_error : str or None
831832
Traceback str if the fit failed, None if the fit succeeded.
832833
"""
834+
xp, _ = get_namespace(X)
835+
X_device = device(X)
836+
837+
# Make sure that we can fancy index X even if train and test are provided
838+
# as NumPy arrays by NumPy only cross-validation splitters.
839+
train, test = xp.asarray(train, device=X_device), xp.asarray(test, device=X_device)
840+
833841
if not isinstance(error_score, numbers.Number) and error_score != "raise":
834842
raise ValueError(
835843
"error_score must be the string 'raise' or a numeric value. "

‎sklearn/model_selection/tests/test_search.py

Copy file name to clipboardExpand all lines: sklearn/model_selection/tests/test_search.py
+28Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
make_classification,
2424
make_multilabel_classification,
2525
)
26+
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
2627
from sklearn.dummy import DummyClassifier
2728
from sklearn.ensemble import HistGradientBoostingClassifier
2829
from sklearn.exceptions import FitFailedWarning
@@ -73,11 +74,13 @@
7374
check_recorded_metadata,
7475
)
7576
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
77+
from sklearn.utils._array_api import yield_namespace_device_dtype_combinations
7678
from sklearn.utils._mocking import CheckingClassifier, MockDataFrame
7779
from sklearn.utils._testing import (
7880
MinimalClassifier,
7981
MinimalRegressor,
8082
MinimalTransformer,
83+
_array_api_for_tests,
8184
assert_allclose,
8285
assert_almost_equal,
8386
assert_array_almost_equal,
@@ -2718,3 +2721,28 @@ def test_search_with_estimators_issue_29157():
27182721
grid_search = GridSearchCV(pipe, grid_params, cv=2)
27192722
grid_search.fit(X, y)
27202723
assert grid_search.cv_results_["param_enc__enc"].dtype == object
2724+
2725+
2726+
@pytest.mark.parametrize(
2727+
"array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
2728+
)
2729+
@pytest.mark.parametrize("SearchCV", [GridSearchCV, RandomizedSearchCV])
2730+
def test_array_api_search_cv_classifier(SearchCV, array_namespace, device, dtype):
2731+
xp = _array_api_for_tests(array_namespace, device)
2732+
2733+
X = np.arange(100).reshape((10, 10))
2734+
X_np = X.astype(dtype)
2735+
X_xp = xp.asarray(X_np, device=device)
2736+
2737+
# y should always be an integer, no matter what `dtype` is
2738+
y_np = np.array([0] * 5 + [1] * 5)
2739+
y_xp = xp.asarray(y_np, device=device)
2740+
2741+
with config_context(array_api_dispatch=True):
2742+
searcher = SearchCV(
2743+
LinearDiscriminantAnalysis(),
2744+
{"tol": [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7]},
2745+
cv=2,
2746+
)
2747+
searcher.fit(X_xp, y_xp)
2748+
searcher.score(X_xp, y_xp)

‎sklearn/model_selection/tests/test_validation.py

Copy file name to clipboardExpand all lines: sklearn/model_selection/tests/test_validation.py
+3-2Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2100,13 +2100,14 @@ def test_fit_and_score_failing():
21002100
failing_clf = FailingClassifier(FailingClassifier.FAILING_PARAMETER)
21012101
# dummy X data
21022102
X = np.arange(1, 10)
2103+
train, test = np.arange(0, 5), np.arange(5, 9)
21032104
fit_and_score_args = dict(
21042105
estimator=failing_clf,
21052106
X=X,
21062107
y=None,
21072108
scorer=dict(),
2108-
train=None,
2109-
test=None,
2109+
train=train,
2110+
test=test,
21102111
verbose=0,
21112112
parameters=None,
21122113
fit_params=None,

‎sklearn/tests/test_common.py

Copy file name to clipboardExpand all lines: sklearn/tests/test_common.py
+3-1Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,9 @@ def _generate_search_cv_instances():
330330
extra_params = (
331331
{"min_resources": "smallest"} if "min_resources" in init_params else {}
332332
)
333-
search_cv = SearchCV(Estimator(), param_grid, cv=2, **extra_params)
333+
search_cv = SearchCV(
334+
Estimator(), param_grid, cv=2, error_score="raise", **extra_params
335+
)
334336
set_random_state(search_cv)
335337
yield search_cv
336338

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.