Commit d29f78e

ENH Add common Array API tests and estimator tag (#26372)

Authored by betatim, ogrisel, and thomasjpfan
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>

Parent: 1e8a5b8

File tree

7 files changed: +181 −123 lines

doc/developers/develop.rst (+3, −0)

@@ -535,6 +535,9 @@ The current set of estimator tags are:
 allow_nan (default=False)
     whether the estimator supports data with missing values encoded as np.nan
 
+array_api_support (default=False)
+    whether the estimator supports Array API compatible inputs.
+
 binary_only (default=False)
     whether estimator supports binary classification but lacks multi-class
     classification support.

doc/modules/array_api.rst (+15, −0)

@@ -93,3 +93,18 @@ Estimators with support for `Array API`-compatible inputs
 Coverage for more estimators is expected to grow over time. Please follow the
 dedicated `meta-issue on GitHub
 <https://github.com/scikit-learn/scikit-learn/issues/22352>`_ to track progress.
+
+Common estimator checks
+=======================
+
+Add the `array_api_support` tag to an estimator's set of tags to indicate that
+it supports the Array API. This will enable dedicated checks as part of the
+common tests to verify that the estimators' results are the same when using
+vanilla NumPy and Array API inputs.
+
+To run these checks you need to install
+`array_api_compat <https://github.com/data-apis/array-api-compat>`_ in your
+test environment. To run the full set of checks you need to install both
+`PyTorch <https://pytorch.org/>`_ and `CuPy <https://cupy.dev/>`_ and have
+a GPU. Checks that cannot be executed or have missing dependencies will be
+automatically skipped.
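
The documentation added above describes the opt-in mechanism but does not show it end to end. As a rough sketch (not part of this commit), a third-party estimator whose computations are already Array API compatible could opt into the new common checks as follows; the estimator name and its internals are hypothetical:

from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.estimator_checks import check_estimator


class MyArrayAPIClassifier(ClassifierMixin, BaseEstimator):
    # Hypothetical estimator: fit/predict are assumed to be written with
    # Array API compatible operations, otherwise the new checks will fail.
    ...

    def _more_tags(self):
        # Opting in enables check_array_api_input in the common test suite.
        return {"array_api_support": True}


# check_estimator runs the common checks; the Array API ones are skipped
# automatically if array_api_compat, CuPy, PyTorch or a GPU is missing.
check_estimator(MyArrayAPIClassifier())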

sklearn/discriminant_analysis.py (+3, −0)

@@ -745,6 +745,9 @@ def decision_function(self, X):
         # Only override for the doc
         return super().decision_function(X)
 
+    def _more_tags(self):
+        return {"array_api_support": True}
+
 
 class QuadraticDiscriminantAnalysis(ClassifierMixin, BaseEstimator):
     """Quadratic Discriminant Analysis.

sklearn/tests/test_discriminant_analysis.py (+0, −122)

@@ -4,16 +4,12 @@
 
 from scipy import linalg
 
-from sklearn.base import clone
-from sklearn._config import config_context
 from sklearn.utils import check_random_state
 from sklearn.utils._testing import assert_array_equal
 from sklearn.utils._testing import assert_array_almost_equal
 from sklearn.utils._testing import assert_allclose
 from sklearn.utils._testing import assert_almost_equal
-from sklearn.utils._array_api import _convert_to_numpy
 from sklearn.utils._testing import _convert_container
-from sklearn.utils._testing import skip_if_array_api_compat_not_configured
 
 from sklearn.datasets import make_blobs
 from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

@@ -675,121 +671,3 @@ def test_get_feature_names_out():
         dtype=object,
     )
     assert_array_equal(names_out, expected_names_out)
-
-
-@skip_if_array_api_compat_not_configured
-@pytest.mark.parametrize("array_namespace", ["numpy.array_api", "cupy.array_api"])
-def test_lda_array_api(array_namespace):
-    """Check that the array_api Array gives the same results as ndarrays."""
-    xp = pytest.importorskip(array_namespace)
-
-    X_xp = xp.asarray(X)
-    y_xp = xp.asarray(y3)
-
-    lda = LinearDiscriminantAnalysis()
-    lda.fit(X, y3)
-
-    array_attributes = {
-        key: value for key, value in vars(lda).items() if isinstance(value, np.ndarray)
-    }
-
-    lda_xp = clone(lda)
-    with config_context(array_api_dispatch=True):
-        lda_xp.fit(X_xp, y_xp)
-
-    # Fitted-attributes which are arrays must have the same
-    # namespace than the one of the training data.
-    for key, attribute in array_attributes.items():
-        lda_xp_param = getattr(lda_xp, key)
-        assert hasattr(lda_xp_param, "__array_namespace__")
-
-        lda_xp_param_np = _convert_to_numpy(lda_xp_param, xp=xp)
-        assert_allclose(
-            attribute, lda_xp_param_np, err_msg=f"{key} not the same", atol=1e-3
-        )
-
-    # Check predictions are the same
-    methods = (
-        "decision_function",
-        "predict",
-        "predict_log_proba",
-        "predict_proba",
-        "transform",
-    )
-
-    for method in methods:
-        result = getattr(lda, method)(X)
-        with config_context(array_api_dispatch=True):
-            result_xp = getattr(lda_xp, method)(X_xp)
-        assert hasattr(
-            result_xp, "__array_namespace__"
-        ), f"{method} did not output an array_namespace"
-
-        result_xp_np = _convert_to_numpy(result_xp, xp=xp)
-
-        assert_allclose(
-            result,
-            result_xp_np,
-            err_msg=f"{method} did not the return the same result",
-            atol=1e-5,
-        )
-
-
-@skip_if_array_api_compat_not_configured
-@pytest.mark.parametrize("device", ["cuda", "cpu"])
-@pytest.mark.parametrize("dtype", ["float32", "float64"])
-def test_lda_array_torch(device, dtype):
-    """Check running on PyTorch Tensors gives the same results as NumPy"""
-    torch = pytest.importorskip("torch")
-    if device == "cuda" and not torch.has_cuda:
-        pytest.skip("test requires cuda")
-
-    lda = LinearDiscriminantAnalysis()
-    X_np = X6.astype(dtype)
-    y_np = y6.astype(dtype)
-    lda.fit(X_np, y_np)
-
-    X_torch = torch.asarray(X_np, device=device)
-    y_torch = torch.asarray(y_np, device=device)
-    lda_xp = clone(lda)
-    with config_context(array_api_dispatch=True):
-        lda_xp.fit(X_torch, y_torch)
-
-    array_attributes = {
-        key: value for key, value in vars(lda).items() if isinstance(value, np.ndarray)
-    }
-
-    for key, attribute in array_attributes.items():
-        lda_xp_param = getattr(lda_xp, key)
-        assert isinstance(lda_xp_param, torch.Tensor)
-        assert lda_xp_param.device.type == device
-
-        lda_xp_param_np = _convert_to_numpy(lda_xp_param, xp=torch)
-        assert_allclose(
-            attribute, lda_xp_param_np, err_msg=f"{key} not the same", atol=1e-3
-        )
-
-    # Check predictions are the same
-    methods = (
-        "decision_function",
-        "predict",
-        "predict_log_proba",
-        "predict_proba",
-        "transform",
-    )
-    for method in methods:
-        result = getattr(lda, method)(X_np)
-        with config_context(array_api_dispatch=True):
-            result_xp = getattr(lda_xp, method)(X_torch)
-
-        assert isinstance(result_xp, torch.Tensor)
-        assert result_xp.device.type == device
-
-        result_xp_np = _convert_to_numpy(result_xp, xp=torch)
-
-        assert_allclose(
-            result,
-            result_xp_np,
-            err_msg=f"{method} did not the return the same result",
-            atol=1e-6,
-        )

sklearn/utils/_tags.py (+1, −0)

@@ -1,6 +1,7 @@
 import numpy as np
 
 _DEFAULT_TAGS = {
+    "array_api_support": False,
     "non_deterministic": False,
    "requires_positive_X": False,
    "requires_positive_y": False,

sklearn/utils/estimator_checks.py (+124, −0)

@@ -1,4 +1,6 @@
 import warnings
+import importlib
+import itertools
 import pickle
 import re
 from copy import deepcopy

@@ -58,6 +60,7 @@
 from ..utils.fixes import sp_version
 from ..utils.fixes import parse_version
 from ..utils.validation import check_is_fitted
+from ..utils._array_api import _convert_to_numpy, get_namespace, device as array_device
 from ..utils._param_validation import make_constraint
 from ..utils._param_validation import generate_invalid_param_val
 from ..utils._param_validation import InvalidParameterError

@@ -73,6 +76,7 @@
 from ..datasets import (
     load_iris,
     make_blobs,
+    make_classification,
     make_multilabel_classification,
     make_regression,
 )

@@ -133,6 +137,21 @@ def _yield_checks(estimator):
 
     yield check_estimator_get_tags_default_keys
 
+    if tags["array_api_support"]:
+        for array_namespace in ["numpy.array_api", "cupy.array_api", "cupy", "torch"]:
+            if array_namespace == "torch":
+                for device, dtype in itertools.product(
+                    ("cpu", "cuda"), ("float64", "float32")
+                ):
+                    yield partial(
+                        check_array_api_input,
+                        array_namespace=array_namespace,
+                        dtype=dtype,
+                        device=device,
+                    )
+            else:
+                yield partial(check_array_api_input, array_namespace=array_namespace)
+
 
 def _yield_classifier_checks(classifier):
     tags = _safe_tags(classifier)

@@ -831,6 +850,111 @@ def _generate_sparse_matrix(X_csr):
         yield sparse_format + "_64", X
 
 
+def check_array_api_input(
+    name, estimator_orig, *, array_namespace, device=None, dtype="float64"
+):
+    """Check that the array_api Array gives the same results as ndarrays."""
+    try:
+        array_mod = importlib.import_module(array_namespace)
+    except ModuleNotFoundError:
+        raise SkipTest(
+            f"{array_namespace} is not installed: not checking array_api input"
+        )
+    try:
+        import array_api_compat  # noqa
+    except ImportError:
+        raise SkipTest(
+            "array_api_compat is not installed: not checking array_api input"
+        )
+
+    # First create an array using the chosen array module and then get the
+    # corresponding (compatibility wrapped) array namespace based on it.
+    # This is because `cupy` is not the same as the compatibility wrapped
+    # namespace of a CuPy array.
+    xp = array_api_compat.get_namespace(array_mod.asarray(1))
+
+    if array_namespace == "torch" and device == "cuda" and not xp.has_cuda:
+        raise SkipTest("PyTorch test requires cuda, which is not available")
+    elif array_namespace in {"cupy", "cupy.array_api"}:  # pragma: nocover
+        import cupy
+
+        if cupy.cuda.runtime.getDeviceCount() == 0:
+            raise SkipTest("CuPy test requires cuda, which is not available")
+
+    X, y = make_classification(random_state=42)
+    X = X.astype(dtype, copy=False)
+
+    X = _enforce_estimator_tags_X(estimator_orig, X)
+    y = _enforce_estimator_tags_y(estimator_orig, y)
+
+    est = clone(estimator_orig)
+
+    X_xp = xp.asarray(X, device=device)
+    y_xp = xp.asarray(y, device=device)
+
+    est.fit(X, y)
+
+    array_attributes = {
+        key: value for key, value in vars(est).items() if isinstance(value, np.ndarray)
+    }
+
+    est_xp = clone(est)
+    with config_context(array_api_dispatch=True):
+        est_xp.fit(X_xp, y_xp)
+
+    # Fitted attributes which are arrays must have the same
+    # namespace as the one of the training data.
+    for key, attribute in array_attributes.items():
+        est_xp_param = getattr(est_xp, key)
+        assert (
+            get_namespace(est_xp_param)[0] == get_namespace(X_xp)[0]
+        ), f"'{key}' attribute is in wrong namespace"
+
+        assert array_device(est_xp_param) == array_device(X_xp)
+
+        est_xp_param_np = _convert_to_numpy(est_xp_param, xp=xp)
+        assert_allclose(
+            attribute,
+            est_xp_param_np,
+            err_msg=f"{key} not the same",
+            atol=np.finfo(X.dtype).eps * 100,
+        )
+
+    # Check estimator methods, if supported, give the same results
+    methods = (
+        "decision_function",
+        "predict",
+        "predict_log_proba",
+        "predict_proba",
+        "transform",
+        "inverse_transform",
+    )
+
+    for method_name in methods:
+        method = getattr(est, method_name, None)
+        if method is None:
+            continue
+
+        result = method(X)
+        with config_context(array_api_dispatch=True):
+            result_xp = getattr(est_xp, method_name)(X_xp)
+
+        assert (
+            get_namespace(result_xp)[0] == get_namespace(X_xp)[0]
+        ), f"'{method}' output is in wrong namespace"
+
+        assert array_device(result_xp) == array_device(X_xp)
+
+        result_xp_np = _convert_to_numpy(result_xp, xp=xp)
+
+        assert_allclose(
+            result,
+            result_xp_np,
+            err_msg=f"{method} did not the return the same result",
+            atol=np.finfo(X.dtype).eps * 100,
+        )
+
+
 def check_estimator_sparse_data(name, estimator_orig):
     rng = np.random.RandomState(0)
     X = rng.uniform(size=(40, 3))
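
For downstream projects, the new check is picked up through the existing entry points of the common test machinery rather than called directly. A sketch of how a test suite would exercise it via parametrize_with_checks, using LinearDiscriminantAnalysis as the tagged estimator:

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.utils.estimator_checks import parametrize_with_checks


# One test is generated per yielded check, so estimators tagged with
# array_api_support get check_array_api_input for every configured
# namespace/device/dtype combination; checks whose dependencies are
# missing raise SkipTest and show up as skipped.
@parametrize_with_checks([LinearDiscriminantAnalysis()])
def test_sklearn_compatible_estimator(estimator, check):
    check(estimator)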
