Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 47da3f0

Browse filesBrowse files
committed
ENH: CuPy creation functions to respect device= parameter
1 parent ecadf5b commit 47da3f0
Copy full SHA for 47da3f0

File tree

3 files changed

+57
-39
lines changed
Filter options

3 files changed

+57
-39
lines changed

‎array_api_compat/common/_aliases.py

Copy file name to clipboardExpand all lines: array_api_compat/common/_aliases.py
+23-23Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import inspect
88
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast
99

10-
from ._helpers import _check_device, array_namespace
10+
from ._helpers import _device_ctx, array_namespace
1111
from ._helpers import device as _get_device
1212
from ._helpers import is_cupy_namespace as _is_cupy_namespace
1313
from ._typing import Array, Device, DType, Namespace
@@ -32,8 +32,8 @@ def arange(
3232
device: Device | None = None,
3333
**kwargs: object,
3434
) -> Array:
35-
_check_device(xp, device)
36-
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
35+
with _device_ctx(xp, device):
36+
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
3737

3838

3939
def empty(
@@ -44,8 +44,8 @@ def empty(
4444
device: Device | None = None,
4545
**kwargs: object,
4646
) -> Array:
47-
_check_device(xp, device)
48-
return xp.empty(shape, dtype=dtype, **kwargs)
47+
with _device_ctx(xp, device):
48+
return xp.empty(shape, dtype=dtype, **kwargs)
4949

5050

5151
def empty_like(
@@ -57,8 +57,8 @@ def empty_like(
5757
device: Device | None = None,
5858
**kwargs: object,
5959
) -> Array:
60-
_check_device(xp, device)
61-
return xp.empty_like(x, dtype=dtype, **kwargs)
60+
with _device_ctx(xp, device, like=x):
61+
return xp.empty_like(x, dtype=dtype, **kwargs)
6262

6363

6464
def eye(
@@ -72,8 +72,8 @@ def eye(
7272
device: Device | None = None,
7373
**kwargs: object,
7474
) -> Array:
75-
_check_device(xp, device)
76-
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
75+
with _device_ctx(xp, device):
76+
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
7777

7878

7979
def full(
@@ -85,8 +85,8 @@ def full(
8585
device: Device | None = None,
8686
**kwargs: object,
8787
) -> Array:
88-
_check_device(xp, device)
89-
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
88+
with _device_ctx(xp, device):
89+
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
9090

9191

9292
def full_like(
@@ -99,8 +99,8 @@ def full_like(
9999
device: Device | None = None,
100100
**kwargs: object,
101101
) -> Array:
102-
_check_device(xp, device)
103-
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
102+
with _device_ctx(xp, device, like=x):
103+
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
104104

105105

106106
def linspace(
@@ -115,8 +115,8 @@ def linspace(
115115
endpoint: bool = True,
116116
**kwargs: object,
117117
) -> Array:
118-
_check_device(xp, device)
119-
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
118+
with _device_ctx(xp, device):
119+
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
120120

121121

122122
def ones(
@@ -127,8 +127,8 @@ def ones(
127127
device: Device | None = None,
128128
**kwargs: object,
129129
) -> Array:
130-
_check_device(xp, device)
131-
return xp.ones(shape, dtype=dtype, **kwargs)
130+
with _device_ctx(xp, device):
131+
return xp.ones(shape, dtype=dtype, **kwargs)
132132

133133

134134
def ones_like(
@@ -140,8 +140,8 @@ def ones_like(
140140
device: Device | None = None,
141141
**kwargs: object,
142142
) -> Array:
143-
_check_device(xp, device)
144-
return xp.ones_like(x, dtype=dtype, **kwargs)
143+
with _device_ctx(xp, device, like=x):
144+
return xp.ones_like(x, dtype=dtype, **kwargs)
145145

146146

147147
def zeros(
@@ -152,8 +152,8 @@ def zeros(
152152
device: Device | None = None,
153153
**kwargs: object,
154154
) -> Array:
155-
_check_device(xp, device)
156-
return xp.zeros(shape, dtype=dtype, **kwargs)
155+
with _device_ctx(xp, device):
156+
return xp.zeros(shape, dtype=dtype, **kwargs)
157157

158158

159159
def zeros_like(
@@ -165,8 +165,8 @@ def zeros_like(
165165
device: Device | None = None,
166166
**kwargs: object,
167167
) -> Array:
168-
_check_device(xp, device)
169-
return xp.zeros_like(x, dtype=dtype, **kwargs)
168+
with _device_ctx(xp, device, like=x):
169+
return xp.zeros_like(x, dtype=dtype, **kwargs)
170170

171171

172172
# np.unique() is split into four functions in the array API:

‎array_api_compat/common/_helpers.py

Copy file name to clipboardExpand all lines: array_api_compat/common/_helpers.py
+32-15Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from __future__ import annotations
1010

11+
import contextlib
1112
import inspect
1213
import math
1314
import sys
@@ -657,26 +658,42 @@ def your_function(x, y):
657658
get_namespace = array_namespace
658659

659660

660-
def _check_device(bare_xp: Namespace, device: Device) -> None: # pyright: ignore[reportUnusedFunction]
661-
"""
662-
Validate dummy device on device-less array backends.
661+
def _device_ctx(
662+
bare_xp: Namespace, device: Device, like: Array | None = None
663+
) -> Generator[None]:
664+
"""Context manager which changes the current device in CuPy.
663665
664-
Notes
665-
-----
666-
This function is also invoked by CuPy, which does have multiple devices
667-
if there are multiple GPUs available.
668-
However, CuPy multi-device support is currently impossible
669-
without using the global device or a context manager:
670-
671-
https://github.com/data-apis/array-api-compat/pull/293
666+
Used internally by array creation functions in common._aliases.
672667
"""
673-
if bare_xp is sys.modules.get("numpy"):
674-
if device not in ("cpu", None):
668+
if device is None:
669+
if like is None:
670+
return contextlib.nullcontext()
671+
device = _device(like)
672+
673+
if bare_xp is sys.modules.get('numpy'):
674+
if device != "cpu":
675675
raise ValueError(f"Unsupported device for NumPy: {device!r}")
676+
return contextlib.nullcontext()
676677

677-
elif bare_xp is sys.modules.get("dask.array"):
678-
if device not in ("cpu", _DASK_DEVICE, None):
678+
if bare_xp is sys.modules.get('dask.array'):
679+
if device not in ("cpu", _DASK_DEVICE):
679680
raise ValueError(f"Unsupported device for Dask: {device!r}")
681+
return contextlib.nullcontext()
682+
683+
if bare_xp is sys.modules.get('cupy'):
684+
if not isinstance(device, bare_xp.cuda.Device):
685+
raise TypeError(f"device is not a cupy.cuda.Device: {device!r}")
686+
return device
687+
688+
# PyTorch doesn't have a "current device" context manager and you
689+
# can't use array creation functions from common._aliases.
690+
raise AssertionError("unreachable") # pragma: nocover
691+
692+
693+
def _check_device(bare_xp: Namespace, device: Device) -> None:
694+
"""Validate dummy device on device-less array backends."""
695+
with _device_ctx(bare_xp, device):
696+
pass
680697

681698

682699
# Placeholder object to represent the dask device

‎array_api_compat/cupy/_aliases.py

Copy file name to clipboardExpand all lines: array_api_compat/cupy/_aliases.py
+2-1Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ def asarray(
8686
See the corresponding documentation in the array library and/or the array API
8787
specification for more details.
8888
"""
89-
with cp.cuda.Device(device):
89+
like = obj if isinstance(obj, cp.ndarray) else None
90+
with _helpers._device_ctx(cp, device, like=like):
9091
if copy is None:
9192
return cp.asarray(obj, dtype=dtype, **kwargs)
9293
else:

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.