Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 9b40cbc

Browse filesBrowse files
authored
MNT Update array-api-compat to 1.12 (scikit-learn#31388)
1 parent ff6bf36 commit 9b40cbc
Copy full SHA for 9b40cbc

31 files changed

+1823
-1103
lines changed

‎maint_tools/vendor_array_api_compat.sh

Copy file name to clipboardExpand all lines: maint_tools/vendor_array_api_compat.sh
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ set -o nounset
66
set -o errexit
77

88
URL="https://github.com/data-apis/array-api-compat.git"
9-
VERSION="1.11.2"
9+
VERSION="1.12"
1010

1111
ROOT_DIR=sklearn/externals/array_api_compat
1212

‎sklearn/externals/array_api_compat/__init__.py

Copy file name to clipboardExpand all lines: sklearn/externals/array_api_compat/__init__.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,6 @@
1717
this implementation for the default when working with NumPy arrays.
1818
1919
"""
20-
__version__ = '1.11.2'
20+
__version__ = '1.12.0'
2121

2222
from .common import * # noqa: F401, F403

‎sklearn/externals/array_api_compat/_internal.py

Copy file name to clipboardExpand all lines: sklearn/externals/array_api_compat/_internal.py
+19-6Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,16 @@
22
Internal helpers
33
"""
44

5+
from collections.abc import Callable
56
from functools import wraps
67
from inspect import signature
8+
from types import ModuleType
9+
from typing import TypeVar
710

8-
def get_xp(xp):
11+
_T = TypeVar("_T")
12+
13+
14+
def get_xp(xp: ModuleType) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
915
"""
1016
Decorator to automatically replace xp with the corresponding array module.
1117
@@ -22,14 +28,14 @@ def func(x, /, xp, kwarg=None):
2228
2329
"""
2430

25-
def inner(f):
31+
def inner(f: Callable[..., _T], /) -> Callable[..., _T]:
2632
@wraps(f)
27-
def wrapped_f(*args, **kwargs):
33+
def wrapped_f(*args: object, **kwargs: object) -> object:
2834
return f(*args, xp=xp, **kwargs)
2935

3036
sig = signature(f)
3137
new_sig = sig.replace(
32-
parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"]
38+
parameters=[par for i, par in sig.parameters.items() if i != "xp"]
3339
)
3440

3541
if wrapped_f.__doc__ is None:
@@ -40,7 +46,14 @@ def wrapped_f(*args, **kwargs):
4046
specification for more details.
4147
4248
"""
43-
wrapped_f.__signature__ = new_sig
44-
return wrapped_f
49+
wrapped_f.__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue]
50+
return wrapped_f # pyright: ignore[reportReturnType]
4551

4652
return inner
53+
54+
55+
__all__ = ["get_xp"]
56+
57+
58+
def __dir__() -> list[str]:
59+
return __all__
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from ._helpers import * # noqa: F403
1+
from ._helpers import * # noqa: F403

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.