Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 245ac79

Browse filesBrowse files
authored
FIX Fixes check_array nonfinite checks with ArrayAPI specification (#25619)
* FIX Fixes check_array nonfinite checks with ArrayAPI specification * DOC Adds PR number * FIX Test on both cupy and numpy
1 parent ee7dd36 commit 245ac79
Copy full SHA for 245ac79

File tree

3 files changed

+26
-2
lines changed
Filter options

3 files changed

+26
-2
lines changed

‎doc/whats_new/v1.2.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.2.rst
+7Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,13 @@ Changelog
6565
when the global configuration sets `transform_output="pandas"`.
6666
:pr:`25500` by :user:`Guillaume Lemaitre <glemaitre>`.
6767

68+
:mod:`sklearn.utils`
69+
....................
70+
71+
- |Fix| Fixes a bug in :func:`utils.check_array` which now correctly performs
72+
non-finite validation with the Array API specification. :pr:`25619` by
73+
`Thomas Fan`_.
74+
6875
.. _changes_1_2_1:
6976

7077
Version 1.2.1

‎sklearn/utils/tests/test_validation.py

Copy file name to clipboardExpand all lines: sklearn/utils/tests/test_validation.py
+17Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import numpy as np
1414
import scipy.sparse as sp
1515

16+
from sklearn._config import config_context
1617
from sklearn.utils._testing import assert_no_warnings
1718
from sklearn.utils._testing import ignore_warnings
1819
from sklearn.utils._testing import SkipTest
@@ -1759,3 +1760,19 @@ def test_boolean_series_remains_boolean():
17591760

17601761
assert res.dtype == expected.dtype
17611762
assert_array_equal(res, expected)
1763+
1764+
1765+
@pytest.mark.parametrize("array_namespace", ["numpy.array_api", "cupy.array_api"])
1766+
def test_check_array_array_api_has_non_finite(array_namespace):
1767+
"""Checks that Array API arrays checks non-finite correctly."""
1768+
xp = pytest.importorskip(array_namespace)
1769+
1770+
X_nan = xp.asarray([[xp.nan, 1, 0], [0, xp.nan, 3]], dtype=xp.float32)
1771+
with config_context(array_api_dispatch=True):
1772+
with pytest.raises(ValueError, match="Input contains NaN."):
1773+
check_array(X_nan)
1774+
1775+
X_inf = xp.asarray([[xp.inf, 1, 0], [0, xp.inf, 3]], dtype=xp.float32)
1776+
with config_context(array_api_dispatch=True):
1777+
with pytest.raises(ValueError, match="infinity or a value too large"):
1778+
check_array(X_inf)

‎sklearn/utils/validation.py

Copy file name to clipboardExpand all lines: sklearn/utils/validation.py
+2-2Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ def _assert_all_finite(
131131
has_nan_error = False if allow_nan else out == FiniteStatus.has_nan
132132
has_inf = out == FiniteStatus.has_infinite
133133
else:
134-
has_inf = np.isinf(X).any()
135-
has_nan_error = False if allow_nan else xp.isnan(X).any()
134+
has_inf = xp.any(xp.isinf(X))
135+
has_nan_error = False if allow_nan else xp.any(xp.isnan(X))
136136
if has_inf or has_nan_error:
137137
if has_nan_error:
138138
type_err = "NaN"

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.