Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 78102fd

Browse filesBrowse files
authored
FIX always return None as device when array API dispatch is disabled (scikit-learn#29119)
1 parent 8798dfe commit 78102fd
Copy full SHA for 78102fd

File tree

Expand file treeCollapse file tree

3 files changed

+44
-5
lines changed
Filter options
Expand file treeCollapse file tree

3 files changed

+44
-5
lines changed

‎doc/whats_new/v1.5.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.5.rst
+9-1Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,20 @@ Version 1.5.1
2323
Changelog
2424
---------
2525

26+
:mod:`sklearn.metrics`
27+
......................
28+
29+
- |Fix| Fix a regression in :func:`metrics.r2_score`. Passing torch CPU tensors
30+
with array API dispatched disabled would complain about non-CPU devices
31+
instead of implicitly converting those inputs as regular NumPy arrays.
32+
:pr:`29119` by :user:`Olivier Grisel`.
33+
2634
:mod:`sklearn.model_selection`
2735
..............................
2836

2937
- |Fix| Fix a regression in :class:`model_selection.GridSearchCV` for parameter
3038
grids that have heterogeneous parameter values.
31-
:pr:`29078` by :user:`Loïc Estève <lesteve>`
39+
:pr:`29078` by :user:`Loïc Estève <lesteve>`.
3240

3341

3442
.. _changes_1_5:

‎sklearn/utils/_array_api.py

Copy file name to clipboardExpand all lines: sklearn/utils/_array_api.py
+9-4Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -568,10 +568,15 @@ def get_namespace_and_device(*array_list, remove_none=True, remove_types=(str,))
568568

569569
skip_remove_kwargs = dict(remove_none=False, remove_types=[])
570570

571-
return (
572-
*get_namespace(*array_list, **skip_remove_kwargs),
573-
device(*array_list, **skip_remove_kwargs),
574-
)
571+
xp, is_array_api = get_namespace(*array_list, **skip_remove_kwargs)
572+
if is_array_api:
573+
return (
574+
xp,
575+
is_array_api,
576+
device(*array_list, **skip_remove_kwargs),
577+
)
578+
else:
579+
return xp, False, None
575580

576581

577582
def _expit(X, xp=None):

‎sklearn/utils/tests/test_array_api.py

Copy file name to clipboardExpand all lines: sklearn/utils/tests/test_array_api.py
+26Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
_ravel,
2323
device,
2424
get_namespace,
25+
get_namespace_and_device,
2526
indexing_dtype,
2627
supported_float_dtypes,
2728
yield_namespace_device_dtype_combinations,
@@ -540,3 +541,28 @@ def test_isin(
540541
)
541542

542543
assert_array_equal(_convert_to_numpy(result, xp=xp), expected)
544+
545+
546+
def test_get_namespace_and_device():
547+
# Use torch as a library with custom Device objects:
548+
torch = pytest.importorskip("torch")
549+
xp_torch = pytest.importorskip("array_api_compat.torch")
550+
some_torch_tensor = torch.arange(3, device="cpu")
551+
some_numpy_array = numpy.arange(3)
552+
553+
# When dispatch is disabled, get_namespace_and_device should return the
554+
# default NumPy wrapper namespace and no device. Our code will handle such
555+
# inputs via the usual __array__ interface without attempting to dispatch
556+
# via the array API.
557+
namespace, is_array_api, device = get_namespace_and_device(some_torch_tensor)
558+
assert namespace is get_namespace(some_numpy_array)[0]
559+
assert not is_array_api
560+
assert device is None
561+
562+
# Otherwise, expose the torch namespace and device via array API compat
563+
# wrapper.
564+
with config_context(array_api_dispatch=True):
565+
namespace, is_array_api, device = get_namespace_and_device(some_torch_tensor)
566+
assert namespace is xp_torch
567+
assert is_array_api
568+
assert device == some_torch_tensor.device

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.