Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit f097f43

Browse filesBrowse files
authored
ENH Adds isdtype to Array API wrapper (#26029)
1 parent 22ea935 commit f097f43
Copy full SHA for f097f43

File tree

4 files changed

+133
-6
lines changed
Filter options

4 files changed

+133
-6
lines changed

‎sklearn/utils/_array_api.py

Copy file name to clipboardExpand all lines: sklearn/utils/_array_api.py
+84-1Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,58 @@
44
import scipy.special as special
55

66

7+
def _is_numpy_namespace(xp):
8+
"""Return True if xp is backed by NumPy."""
9+
return xp.__name__ in {"numpy", "numpy.array_api"}
10+
11+
12+
def isdtype(dtype, kind, *, xp):
13+
"""Returns a boolean indicating whether a provided dtype is of type "kind".
14+
15+
Included in the v2022.12 of the Array API spec.
16+
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
17+
"""
18+
if isinstance(kind, tuple):
19+
return any(_isdtype_single(dtype, k, xp=xp) for k in kind)
20+
else:
21+
return _isdtype_single(dtype, kind, xp=xp)
22+
23+
24+
def _isdtype_single(dtype, kind, *, xp):
25+
if isinstance(kind, str):
26+
if kind == "bool":
27+
return dtype == xp.bool
28+
elif kind == "signed integer":
29+
return dtype in {xp.int8, xp.int16, xp.int32, xp.int64}
30+
elif kind == "unsigned integer":
31+
return dtype in {xp.uint8, xp.uint16, xp.uint32, xp.uint64}
32+
elif kind == "integral":
33+
return any(
34+
_isdtype_single(dtype, k, xp=xp)
35+
for k in ("signed integer", "unsigned integer")
36+
)
37+
elif kind == "real floating":
38+
return dtype in {xp.float32, xp.float64}
39+
elif kind == "complex floating":
40+
# Some name spaces do not have complex, such as cupy.array_api
41+
# and numpy.array_api
42+
complex_dtypes = set()
43+
if hasattr(xp, "complex64"):
44+
complex_dtypes.add(xp.complex64)
45+
if hasattr(xp, "complex128"):
46+
complex_dtypes.add(xp.complex128)
47+
return dtype in complex_dtypes
48+
elif kind == "numeric":
49+
return any(
50+
_isdtype_single(dtype, k, xp=xp)
51+
for k in ("integral", "real floating", "complex floating")
52+
)
53+
else:
54+
raise ValueError(f"Unrecognized data type kind: {kind!r}")
55+
else:
56+
return dtype == kind
57+
58+
759
class _ArrayAPIWrapper:
860
"""sklearn specific Array API compatibility wrapper
961
@@ -48,6 +100,9 @@ def take(self, X, indices, *, axis):
48100
selected = [X[:, i] for i in indices]
49101
return self._namespace.stack(selected, axis=axis)
50102

103+
def isdtype(self, dtype, kind):
104+
return isdtype(dtype, kind, xp=self._namespace)
105+
51106

52107
class _NumPyAPIWrapper:
53108
"""Array API compat wrapper for any numpy version
@@ -60,8 +115,33 @@ class _NumPyAPIWrapper:
60115
See the `get_namespace()` public function for more details.
61116
"""
62117

118+
# Data types in spec
119+
# https://data-apis.org/array-api/latest/API_specification/data_types.html
120+
_DTYPES = {
121+
"int8",
122+
"int16",
123+
"int32",
124+
"int64",
125+
"uint8",
126+
"uint16",
127+
"uint32",
128+
"uint64",
129+
"float32",
130+
"float64",
131+
"complex64",
132+
"complex128",
133+
}
134+
63135
def __getattr__(self, name):
64-
return getattr(numpy, name)
136+
attr = getattr(numpy, name)
137+
# Convert to dtype objects
138+
if name in self._DTYPES:
139+
return numpy.dtype(attr)
140+
return attr
141+
142+
@property
143+
def bool(self):
144+
return numpy.bool_
65145

66146
def astype(self, x, dtype, *, copy=True, casting="unsafe"):
67147
# astype is not defined in the top level NumPy namespace
@@ -86,6 +166,9 @@ def unique_values(self, x):
86166
def concat(self, arrays, *, axis=None):
87167
return numpy.concatenate(arrays, axis=axis)
88168

169+
def isdtype(self, dtype, kind):
170+
return isdtype(dtype, kind, xp=self)
171+
89172

90173
def get_namespace(*arrays):
91174
"""Get namespace of arrays.

‎sklearn/utils/multiclass.py

Copy file name to clipboardExpand all lines: sklearn/utils/multiclass.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def type_of_target(y, input_name=""):
374374
suffix = "" # [1, 2, 3] or [[1], [2], [3]]
375375

376376
# Check float and contains non-integer float values
377-
if y.dtype.kind == "f":
377+
if xp.isdtype(y.dtype, "real floating"):
378378
# [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.]
379379
data = y.data if issparse(y) else y
380380
if xp.any(data != data.astype(int)):

‎sklearn/utils/tests/test_array_api.py

Copy file name to clipboardExpand all lines: sklearn/utils/tests/test_array_api.py
+39Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,42 @@ def test_convert_estimator_to_array_api():
187187

188188
new_est = _estimator_with_converted_arrays(est, lambda array: xp.asarray(array))
189189
assert hasattr(new_est.X_, "__array_namespace__")
190+
191+
192+
@pytest.mark.parametrize("wrapper", [_ArrayAPIWrapper, _NumPyApiWrapper])
193+
def test_get_namespace_array_api_isdtype(wrapper):
194+
"""Test isdtype implementation from _ArrayAPIWrapper and _NumPyApiWrapper."""
195+
196+
if wrapper == _ArrayAPIWrapper:
197+
xp_ = pytest.importorskip("numpy.array_api")
198+
xp = _ArrayAPIWrapper(xp_)
199+
else:
200+
xp = _NumPyApiWrapper()
201+
202+
assert xp.isdtype(xp.float32, xp.float32)
203+
assert xp.isdtype(xp.float32, "real floating")
204+
assert xp.isdtype(xp.float64, "real floating")
205+
assert not xp.isdtype(xp.int32, "real floating")
206+
207+
assert xp.isdtype(xp.bool, "bool")
208+
assert not xp.isdtype(xp.float32, "bool")
209+
210+
assert xp.isdtype(xp.int16, "signed integer")
211+
assert not xp.isdtype(xp.uint32, "signed integer")
212+
213+
assert xp.isdtype(xp.uint16, "unsigned integer")
214+
assert not xp.isdtype(xp.int64, "unsigned integer")
215+
216+
assert xp.isdtype(xp.int64, "numeric")
217+
assert xp.isdtype(xp.float32, "numeric")
218+
assert xp.isdtype(xp.uint32, "numeric")
219+
220+
assert not xp.isdtype(xp.float32, "complex floating")
221+
222+
if wrapper == _NumPyApiWrapper:
223+
assert not xp.isdtype(xp.int8, "complex floating")
224+
assert xp.isdtype(xp.complex64, "complex floating")
225+
assert xp.isdtype(xp.complex128, "complex floating")
226+
227+
with pytest.raises(ValueError, match="Unrecognized data type"):
228+
assert xp.isdtype(xp.int16, "unknown")

‎sklearn/utils/validation.py

Copy file name to clipboardExpand all lines: sklearn/utils/validation.py
+9-4Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from ..exceptions import DataConversionWarning
3232
from ..utils._array_api import get_namespace
3333
from ..utils._array_api import _asarray_with_order
34+
from ..utils._array_api import _is_numpy_namespace
3435
from ._isfinite import cy_isfinite, FiniteStatus
3536

3637
FLOAT_DTYPES = (np.float64, np.float32, np.float16)
@@ -111,7 +112,7 @@ def _assert_all_finite(
111112
raise ValueError("Input contains NaN")
112113

113114
# We need only consider float arrays, hence can early return for all else.
114-
if X.dtype.kind not in "fc":
115+
if not xp.isdtype(X.dtype, ("real floating", "complex floating")):
115116
return
116117

117118
# First try an O(n) time, O(1) space solution for the common case that
@@ -759,7 +760,7 @@ def check_array(
759760
dtype_numeric = isinstance(dtype, str) and dtype == "numeric"
760761

761762
dtype_orig = getattr(array, "dtype", None)
762-
if not hasattr(dtype_orig, "kind"):
763+
if not is_array_api and not hasattr(dtype_orig, "kind"):
763764
# not a data type (e.g. a column named dtype in a pandas DataFrame)
764765
dtype_orig = None
765766

@@ -832,6 +833,10 @@ def check_array(
832833
)
833834
)
834835

836+
if dtype is not None and _is_numpy_namespace(xp):
837+
# convert to dtype object to conform to Array API to be use `xp.isdtype` later
838+
dtype = np.dtype(dtype)
839+
835840
estimator_name = _check_estimator_name(estimator)
836841
context = " by %s" % estimator_name if estimator is not None else ""
837842

@@ -875,12 +880,12 @@ def check_array(
875880
with warnings.catch_warnings():
876881
try:
877882
warnings.simplefilter("error", ComplexWarning)
878-
if dtype is not None and np.dtype(dtype).kind in "iu":
883+
if dtype is not None and xp.isdtype(dtype, "integral"):
879884
# Conversion float -> int should not contain NaN or
880885
# inf (numpy#14412). We cannot use casting='safe' because
881886
# then conversion float -> int would be disallowed.
882887
array = _asarray_with_order(array, order=order, xp=xp)
883-
if array.dtype.kind == "f":
888+
if xp.isdtype(array.dtype, ("real floating", "complex floating")):
884889
_assert_all_finite(
885890
array,
886891
allow_nan=False,

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.