Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 45f5ced

Browse filesBrowse files
betatimthomasjpfan
authored andcommitted
ENH Add Array API compatibility to MinMaxScaler (scikit-learn#26243)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent 8665663 commit 45f5ced
Copy full SHA for 45f5ced

File tree

Expand file treeCollapse file tree

7 files changed

+174
-18
lines changed
Filter options
Expand file treeCollapse file tree

7 files changed

+174
-18
lines changed

‎doc/modules/array_api.rst

Copy file name to clipboardExpand all lines: doc/modules/array_api.rst
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ Estimators
9696
- :class:`decomposition.PCA` (with `svd_solver="full"`,
9797
`svd_solver="randomized"` and `power_iteration_normalizer="QR"`)
9898
- :class:`discriminant_analysis.LinearDiscriminantAnalysis` (with `solver="svd"`)
99+
- :class:`preprocessing.MinMaxScaler`
99100

100101
Tools
101102
-----

‎doc/whats_new/v1.3.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.3.rst
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ Changelog
398398
- |Feature| Compute a custom out-of-bag score by passing a callable to
399399
:class:`ensemble.RandomForestClassifier`, :class:`ensemble.RandomForestRegressor`,
400400
:class:`ensemble.ExtraTreesClassifier` and :class:`ensemble.ExtraTreesRegressor`.
401-
:pr:`25177` by :user:`Tim Head <betatim>`.
401+
:pr:`25177` by `Tim Head`_.
402402

403403
- |Feature| :class:`ensemble.GradientBoostingClassifier` now exposes
404404
out-of-bag scores via the `oob_scores_` or `oob_score_` attributes.

‎doc/whats_new/v1.4.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.4.rst
+8-2Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ Changelog
174174
is enabled and should be passed via the `params` parameter. :pr:`26896` by
175175
`Adrin Jalali`_.
176176

177+
- |Enhancement| :func:`sklearn.model_selection.train_test_split` now supports
178+
Array API compatible inputs. :pr:`26855` by `Tim Head`_.
179+
177180
:mod:`sklearn.neighbors`
178181
........................
179182

@@ -197,8 +200,11 @@ Changelog
197200
when `sparse_output=True` and the output is configured to be pandas.
198201
:pr:`26931` by `Thomas Fan`_.
199202

200-
- |Enhancement| :func:`sklearn.model_selection.train_test_split` now supports
201-
Array API compatible inputs. :pr:`26855` by `Tim Head`_.
203+
- |MajorFeature| :class:`preprocessing.MinMaxScaler` now
204+
supports the `Array API <https://data-apis.org/array-api/latest/>`_. Array API
205+
support is considered experimental and might evolve without being subject to
206+
our usual rolling deprecation cycle policy. See
207+
:ref:`array_api` for more details. :pr:`26243` by `Tim Head`_.
202208

203209
:mod:`sklearn.tree`
204210
...................

‎sklearn/preprocessing/_data.py

Copy file name to clipboardExpand all lines: sklearn/preprocessing/_data.py
+24-12Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
TransformerMixin,
2323
_fit_context,
2424
)
25-
from ..utils import check_array
25+
from ..utils import _array_api, check_array
26+
from ..utils._array_api import get_namespace
2627
from ..utils._param_validation import Interval, Options, StrOptions, validate_params
2728
from ..utils.extmath import _incremental_mean_and_var, row_norms
2829
from ..utils.sparsefuncs import (
@@ -103,16 +104,18 @@ def _handle_zeros_in_scale(scale, copy=True, constant_mask=None):
103104
if scale == 0.0:
104105
scale = 1.0
105106
return scale
106-
elif isinstance(scale, np.ndarray):
107+
# scale is an array
108+
else:
109+
xp, _ = get_namespace(scale)
107110
if constant_mask is None:
108111
# Detect near constant values to avoid dividing by a very small
109112
# value that could lead to surprising results and numerical
110113
# stability issues.
111-
constant_mask = scale < 10 * np.finfo(scale.dtype).eps
114+
constant_mask = scale < 10 * xp.finfo(scale.dtype).eps
112115

113116
if copy:
114117
# New array to avoid side-effects
115-
scale = scale.copy()
118+
scale = xp.asarray(scale, copy=True)
116119
scale[constant_mask] = 1.0
117120
return scale
118121

@@ -468,22 +471,24 @@ def partial_fit(self, X, y=None):
468471
"Consider using MaxAbsScaler instead."
469472
)
470473

474+
xp, _ = get_namespace(X)
475+
471476
first_pass = not hasattr(self, "n_samples_seen_")
472477
X = self._validate_data(
473478
X,
474479
reset=first_pass,
475-
dtype=FLOAT_DTYPES,
480+
dtype=_array_api.supported_float_dtypes(xp),
476481
force_all_finite="allow-nan",
477482
)
478483

479-
data_min = np.nanmin(X, axis=0)
480-
data_max = np.nanmax(X, axis=0)
484+
data_min = _array_api._nanmin(X, axis=0)
485+
data_max = _array_api._nanmax(X, axis=0)
481486

482487
if first_pass:
483488
self.n_samples_seen_ = X.shape[0]
484489
else:
485-
data_min = np.minimum(self.data_min_, data_min)
486-
data_max = np.maximum(self.data_max_, data_max)
490+
data_min = xp.minimum(self.data_min_, data_min)
491+
data_max = xp.maximum(self.data_max_, data_max)
487492
self.n_samples_seen_ += X.shape[0]
488493

489494
data_range = data_max - data_min
@@ -511,18 +516,20 @@ def transform(self, X):
511516
"""
512517
check_is_fitted(self)
513518

519+
xp, _ = get_namespace(X)
520+
514521
X = self._validate_data(
515522
X,
516523
copy=self.copy,
517-
dtype=FLOAT_DTYPES,
524+
dtype=_array_api.supported_float_dtypes(xp),
518525
force_all_finite="allow-nan",
519526
reset=False,
520527
)
521528

522529
X *= self.scale_
523530
X += self.min_
524531
if self.clip:
525-
np.clip(X, self.feature_range[0], self.feature_range[1], out=X)
532+
xp.clip(X, self.feature_range[0], self.feature_range[1], out=X)
526533
return X
527534

528535
def inverse_transform(self, X):
@@ -540,8 +547,13 @@ def inverse_transform(self, X):
540547
"""
541548
check_is_fitted(self)
542549

550+
xp, _ = get_namespace(X)
551+
543552
X = check_array(
544-
X, copy=self.copy, dtype=FLOAT_DTYPES, force_all_finite="allow-nan"
553+
X,
554+
copy=self.copy,
555+
dtype=_array_api.supported_float_dtypes(xp),
556+
force_all_finite="allow-nan",
545557
)
546558

547559
X -= self.min_

‎sklearn/preprocessing/tests/test_data.py

Copy file name to clipboardExpand all lines: sklearn/preprocessing/tests/test_data.py
+27Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@
4141
from sklearn.preprocessing._data import BOUNDS_THRESHOLD, _handle_zeros_in_scale
4242
from sklearn.svm import SVR
4343
from sklearn.utils import gen_batches, shuffle
44+
from sklearn.utils._array_api import (
45+
yield_namespace_device_dtype_combinations,
46+
)
4447
from sklearn.utils._testing import (
4548
_convert_container,
4649
assert_allclose,
@@ -51,6 +54,10 @@
5154
assert_array_less,
5255
skip_if_32bit,
5356
)
57+
from sklearn.utils.estimator_checks import (
58+
_get_check_estimator_ids,
59+
check_array_api_input_and_values,
60+
)
5461
from sklearn.utils.sparsefuncs import mean_variance_axis
5562

5663
iris = datasets.load_iris()
@@ -684,6 +691,26 @@ def test_standard_check_array_of_inverse_transform():
684691
scaler.inverse_transform(x)
685692

686693

694+
@pytest.mark.parametrize(
695+
"array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
696+
)
697+
@pytest.mark.parametrize(
698+
"check",
699+
[check_array_api_input_and_values],
700+
ids=_get_check_estimator_ids,
701+
)
702+
@pytest.mark.parametrize(
703+
"estimator",
704+
[MinMaxScaler()],
705+
ids=_get_check_estimator_ids,
706+
)
707+
def test_minmaxscaler_array_api_compliance(
708+
estimator, check, array_namespace, device, dtype
709+
):
710+
name = estimator.__class__.__name__
711+
check(name, estimator, array_namespace, device=device, dtype=dtype)
712+
713+
687714
def test_min_max_scaler_iris():
688715
X = iris.data
689716
scaler = MinMaxScaler()

‎sklearn/utils/_array_api.py

Copy file name to clipboardExpand all lines: sklearn/utils/_array_api.py
+57-3Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def _isdtype_single(dtype, kind, *, xp):
146146
for k in ("signed integer", "unsigned integer")
147147
)
148148
elif kind == "real floating":
149-
return dtype in {xp.float32, xp.float64}
149+
return dtype in supported_float_dtypes(xp)
150150
elif kind == "complex floating":
151151
# Some name spaces do not have complex, such as cupy.array_api
152152
# and numpy.array_api
@@ -167,14 +167,29 @@ def _isdtype_single(dtype, kind, *, xp):
167167
return dtype == kind
168168

169169

170+
def supported_float_dtypes(xp):
171+
"""Supported floating point types for the namespace
172+
173+
Note: float16 is not officially part of the Array API spec at the
174+
time of writing but scikit-learn estimators and functions can choose
175+
to accept it when xp.float16 is defined.
176+
177+
https://data-apis.org/array-api/latest/API_specification/data_types.html
178+
"""
179+
if hasattr(xp, "float16"):
180+
return (xp.float64, xp.float32, xp.float16)
181+
else:
182+
return (xp.float64, xp.float32)
183+
184+
170185
class _ArrayAPIWrapper:
171186
"""sklearn specific Array API compatibility wrapper
172187
173188
This wrapper makes it possible for scikit-learn maintainers to
174189
deal with discrepancies between different implementations of the
175-
Python array API standard and its evolution over time.
190+
Python Array API standard and its evolution over time.
176191
177-
The Python array API standard specification:
192+
The Python Array API standard specification:
178193
https://data-apis.org/array-api/latest/
179194
180195
Documentation of the NumPy implementation:
@@ -269,6 +284,9 @@ class _NumPyAPIWrapper:
269284
"uint16",
270285
"uint32",
271286
"uint64",
287+
# XXX: float16 is not part of the Array API spec but exposed by
288+
# some namespaces.
289+
"float16",
272290
"float32",
273291
"float64",
274292
"complex64",
@@ -394,6 +412,8 @@ def get_namespace(*arrays):
394412

395413
namespace, is_array_api_compliant = array_api_compat.get_namespace(*arrays), True
396414

415+
# These namespaces need additional wrapping to smooth out small differences
416+
# between implementations
397417
if namespace.__name__ in {"numpy.array_api", "cupy.array_api"}:
398418
namespace = _ArrayAPIWrapper(namespace)
399419

@@ -466,6 +486,40 @@ def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None):
466486
return float(xp.sum(sample_score))
467487

468488

489+
def _nanmin(X, axis=None):
490+
# TODO: refactor once nan-aware reductions are standardized:
491+
# https://github.com/data-apis/array-api/issues/621
492+
xp, _ = get_namespace(X)
493+
if _is_numpy_namespace(xp):
494+
return xp.asarray(numpy.nanmin(X, axis=axis))
495+
496+
else:
497+
mask = xp.isnan(X)
498+
X = xp.min(xp.where(mask, xp.asarray(+xp.inf), X), axis=axis)
499+
# Replace Infs from all NaN slices with NaN again
500+
mask = xp.all(mask, axis=axis)
501+
if xp.any(mask):
502+
X = xp.where(mask, xp.asarray(xp.nan), X)
503+
return X
504+
505+
506+
def _nanmax(X, axis=None):
507+
# TODO: refactor once nan-aware reductions are standardized:
508+
# https://github.com/data-apis/array-api/issues/621
509+
xp, _ = get_namespace(X)
510+
if _is_numpy_namespace(xp):
511+
return xp.asarray(numpy.nanmax(X, axis=axis))
512+
513+
else:
514+
mask = xp.isnan(X)
515+
X = xp.max(xp.where(mask, xp.asarray(-xp.inf), X), axis=axis)
516+
# Replace Infs from all NaN slices with NaN again
517+
mask = xp.all(mask, axis=axis)
518+
if xp.any(mask):
519+
X = xp.where(mask, xp.asarray(xp.nan), X)
520+
return X
521+
522+
469523
def _asarray_with_order(array, dtype=None, order=None, copy=None, *, xp=None):
470524
"""Helper to support the order kwarg only for NumPy-backed arrays
471525

‎sklearn/utils/tests/test_array_api.py

Copy file name to clipboardExpand all lines: sklearn/utils/tests/test_array_api.py
+56Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from functools import partial
2+
13
import numpy
24
import pytest
35
from numpy.testing import assert_allclose, assert_array_equal
@@ -9,8 +11,11 @@
911
_asarray_with_order,
1012
_convert_to_numpy,
1113
_estimator_with_converted_arrays,
14+
_nanmax,
15+
_nanmin,
1216
_NumPyAPIWrapper,
1317
get_namespace,
18+
supported_float_dtypes,
1419
)
1520
from sklearn.utils._testing import skip_if_array_api_compat_not_configured
1621

@@ -159,6 +164,54 @@ def test_asarray_with_order_ignored():
159164
assert not X_new_np.flags["F_CONTIGUOUS"]
160165

161166

167+
@skip_if_array_api_compat_not_configured
168+
@pytest.mark.parametrize(
169+
"library", ["numpy", "numpy.array_api", "cupy", "cupy.array_api", "torch"]
170+
)
171+
@pytest.mark.parametrize(
172+
"X,reduction,expected",
173+
[
174+
([1, 2, numpy.nan], _nanmin, 1),
175+
([1, -2, -numpy.nan], _nanmin, -2),
176+
([numpy.inf, numpy.inf], _nanmin, numpy.inf),
177+
(
178+
[[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
179+
partial(_nanmin, axis=0),
180+
[1.0, 2.0, 3.0],
181+
),
182+
(
183+
[[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
184+
partial(_nanmin, axis=1),
185+
[1.0, numpy.nan, 4.0],
186+
),
187+
([1, 2, numpy.nan], _nanmax, 2),
188+
([1, 2, numpy.nan], _nanmax, 2),
189+
([-numpy.inf, -numpy.inf], _nanmax, -numpy.inf),
190+
(
191+
[[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
192+
partial(_nanmax, axis=0),
193+
[4.0, 5.0, 6.0],
194+
),
195+
(
196+
[[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
197+
partial(_nanmax, axis=1),
198+
[3.0, numpy.nan, 6.0],
199+
),
200+
],
201+
)
202+
def test_nan_reductions(library, X, reduction, expected):
203+
"""Check NaN reductions like _nanmin and _nanmax"""
204+
xp = pytest.importorskip(library)
205+
206+
if isinstance(expected, list):
207+
expected = xp.asarray(expected)
208+
209+
with config_context(array_api_dispatch=True):
210+
result = reduction(xp.asarray(X))
211+
212+
assert_allclose(result, expected)
213+
214+
162215
@skip_if_array_api_compat_not_configured
163216
@pytest.mark.parametrize("library", ["cupy", "torch", "cupy.array_api"])
164217
def test_convert_to_numpy_gpu(library): # pragma: nocover
@@ -256,6 +309,9 @@ def test_get_namespace_array_api_isdtype(wrapper):
256309
assert xp.isdtype(xp.float64, "real floating")
257310
assert not xp.isdtype(xp.int32, "real floating")
258311

312+
for dtype in supported_float_dtypes(xp):
313+
assert xp.isdtype(dtype, "real floating")
314+
259315
assert xp.isdtype(xp.bool, "bool")
260316
assert not xp.isdtype(xp.float32, "bool")
261317

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.