Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 0e8defe

Browse filesBrowse files
committed
ENH: CuPy creation functions to respect device= parameter
1 parent 492db39 commit 0e8defe
Copy full SHA for 0e8defe

File tree

Expand file treeCollapse file tree

3 files changed

+61
-43
lines changed
Filter options
Expand file treeCollapse file tree

3 files changed

+61
-43
lines changed

‎array_api_compat/common/_aliases.py

Copy file name to clipboardExpand all lines: array_api_compat/common/_aliases.py
+23-23Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import inspect
88
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast
99

10-
from ._helpers import _check_device, array_namespace
10+
from ._helpers import _device_ctx, array_namespace
1111
from ._helpers import device as _get_device
1212
from ._helpers import is_cupy_namespace as _is_cupy_namespace
1313
from ._typing import Array, Device, DType, Namespace
@@ -32,8 +32,8 @@ def arange(
3232
device: Device | None = None,
3333
**kwargs: object,
3434
) -> Array:
35-
_check_device(xp, device)
36-
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
35+
with _device_ctx(xp, device):
36+
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
3737

3838

3939
def empty(
@@ -44,8 +44,8 @@ def empty(
4444
device: Device | None = None,
4545
**kwargs: object,
4646
) -> Array:
47-
_check_device(xp, device)
48-
return xp.empty(shape, dtype=dtype, **kwargs)
47+
with _device_ctx(xp, device):
48+
return xp.empty(shape, dtype=dtype, **kwargs)
4949

5050

5151
def empty_like(
@@ -57,8 +57,8 @@ def empty_like(
5757
device: Device | None = None,
5858
**kwargs: object,
5959
) -> Array:
60-
_check_device(xp, device)
61-
return xp.empty_like(x, dtype=dtype, **kwargs)
60+
with _device_ctx(xp, device, like=x):
61+
return xp.empty_like(x, dtype=dtype, **kwargs)
6262

6363

6464
def eye(
@@ -72,8 +72,8 @@ def eye(
7272
device: Device | None = None,
7373
**kwargs: object,
7474
) -> Array:
75-
_check_device(xp, device)
76-
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
75+
with _device_ctx(xp, device):
76+
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
7777

7878

7979
def full(
@@ -85,8 +85,8 @@ def full(
8585
device: Device | None = None,
8686
**kwargs: object,
8787
) -> Array:
88-
_check_device(xp, device)
89-
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
88+
with _device_ctx(xp, device):
89+
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
9090

9191

9292
def full_like(
@@ -99,8 +99,8 @@ def full_like(
9999
device: Device | None = None,
100100
**kwargs: object,
101101
) -> Array:
102-
_check_device(xp, device)
103-
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
102+
with _device_ctx(xp, device, like=x):
103+
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
104104

105105

106106
def linspace(
@@ -115,8 +115,8 @@ def linspace(
115115
endpoint: bool = True,
116116
**kwargs: object,
117117
) -> Array:
118-
_check_device(xp, device)
119-
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
118+
with _device_ctx(xp, device):
119+
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
120120

121121

122122
def ones(
@@ -127,8 +127,8 @@ def ones(
127127
device: Device | None = None,
128128
**kwargs: object,
129129
) -> Array:
130-
_check_device(xp, device)
131-
return xp.ones(shape, dtype=dtype, **kwargs)
130+
with _device_ctx(xp, device):
131+
return xp.ones(shape, dtype=dtype, **kwargs)
132132

133133

134134
def ones_like(
@@ -140,8 +140,8 @@ def ones_like(
140140
device: Device | None = None,
141141
**kwargs: object,
142142
) -> Array:
143-
_check_device(xp, device)
144-
return xp.ones_like(x, dtype=dtype, **kwargs)
143+
with _device_ctx(xp, device, like=x):
144+
return xp.ones_like(x, dtype=dtype, **kwargs)
145145

146146

147147
def zeros(
@@ -152,8 +152,8 @@ def zeros(
152152
device: Device | None = None,
153153
**kwargs: object,
154154
) -> Array:
155-
_check_device(xp, device)
156-
return xp.zeros(shape, dtype=dtype, **kwargs)
155+
with _device_ctx(xp, device):
156+
return xp.zeros(shape, dtype=dtype, **kwargs)
157157

158158

159159
def zeros_like(
@@ -165,8 +165,8 @@ def zeros_like(
165165
device: Device | None = None,
166166
**kwargs: object,
167167
) -> Array:
168-
_check_device(xp, device)
169-
return xp.zeros_like(x, dtype=dtype, **kwargs)
168+
with _device_ctx(xp, device, like=x):
169+
return xp.zeros_like(x, dtype=dtype, **kwargs)
170170

171171

172172
# np.unique() is split into four functions in the array API:

‎array_api_compat/common/_helpers.py

Copy file name to clipboardExpand all lines: array_api_compat/common/_helpers.py
+33-16Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88

99
from __future__ import annotations
1010

11+
import contextlib
1112
import inspect
1213
import math
1314
import sys
1415
import warnings
15-
from collections.abc import Collection
16+
from collections.abc import Collection, Generator
1617
from typing import (
1718
TYPE_CHECKING,
1819
Any,
@@ -663,26 +664,42 @@ def your_function(x, y):
663664
get_namespace = array_namespace
664665

665666

666-
def _check_device(bare_xp: Namespace, device: Device) -> None: # pyright: ignore[reportUnusedFunction]
667-
"""
668-
Validate dummy device on device-less array backends.
667+
def _device_ctx(
668+
bare_xp: Namespace, device: Device, like: Array | None = None
669+
) -> Generator[None]:
670+
"""Context manager which changes the current device in CuPy.
669671
670-
Notes
671-
-----
672-
This function is also invoked by CuPy, which does have multiple devices
673-
if there are multiple GPUs available.
674-
However, CuPy multi-device support is currently impossible
675-
without using the global device or a context manager:
676-
677-
https://github.com/data-apis/array-api-compat/pull/293
672+
Used internally by array creation functions in common._aliases.
678673
"""
679-
if bare_xp is sys.modules.get("numpy"):
680-
if device not in ("cpu", None):
674+
if device is None:
675+
if like is None:
676+
return contextlib.nullcontext()
677+
device = _device(like)
678+
679+
if bare_xp is sys.modules.get('numpy'):
680+
if device != "cpu":
681681
raise ValueError(f"Unsupported device for NumPy: {device!r}")
682+
return contextlib.nullcontext()
682683

683-
elif bare_xp is sys.modules.get("dask.array"):
684-
if device not in ("cpu", _DASK_DEVICE, None):
684+
if bare_xp is sys.modules.get('dask.array'):
685+
if device not in ("cpu", _DASK_DEVICE):
685686
raise ValueError(f"Unsupported device for Dask: {device!r}")
687+
return contextlib.nullcontext()
688+
689+
if bare_xp is sys.modules.get('cupy'):
690+
if not isinstance(device, bare_xp.cuda.Device):
691+
raise TypeError(f"device is not a cupy.cuda.Device: {device!r}")
692+
return device
693+
694+
# PyTorch doesn't have a "current device" context manager and you
695+
# can't use array creation functions from common._aliases.
696+
raise AssertionError("unreachable") # pragma: nocover
697+
698+
699+
def _check_device(bare_xp: Namespace, device: Device) -> None:
700+
"""Validate dummy device on device-less array backends."""
701+
with _device_ctx(bare_xp, device):
702+
pass
686703

687704

688705
# Placeholder object to represent the dask device

‎array_api_compat/cupy/_aliases.py

Copy file name to clipboardExpand all lines: array_api_compat/cupy/_aliases.py
+5-4Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,11 @@ def asarray(
9090
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
9191

9292
like = obj if isinstance(obj, cp.ndarray) else None
93-
if copy is None:
94-
return cp.asarray(obj, dtype=dtype, **kwargs)
95-
else:
96-
return cp.array(obj, dtype=dtype, copy=True, **kwargs)
93+
with _helpers._device_ctx(cp, device, like=like):
94+
if copy is None:
95+
return cp.asarray(obj, dtype=dtype, **kwargs)
96+
else:
97+
return cp.array(obj, dtype=dtype, copy=True, **kwargs)
9798

9899

99100
def astype(

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.