Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit d5673a1

Browse filesBrowse files
committed
ENH: CuPy creation functions to respect device= parameter
1 parent 205c967 commit d5673a1
Copy full SHA for d5673a1

File tree

Expand file treeCollapse file tree

3 files changed

+81
-93
lines changed
Filter options
Expand file treeCollapse file tree

3 files changed

+81
-93
lines changed

‎array_api_compat/common/_aliases.py

Copy file name to clipboardExpand all lines: array_api_compat/common/_aliases.py
+23-23Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import inspect
88
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast
99

10-
from ._helpers import _check_device, array_namespace
10+
from ._helpers import _device_ctx, array_namespace
1111
from ._helpers import device as _get_device
1212
from ._helpers import is_cupy_namespace as _is_cupy_namespace
1313
from ._typing import Array, Device, DType, Namespace
@@ -32,8 +32,8 @@ def arange(
3232
device: Device | None = None,
3333
**kwargs: object,
3434
) -> Array:
35-
_check_device(xp, device)
36-
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
35+
with _device_ctx(xp, device):
36+
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
3737

3838

3939
def empty(
@@ -44,8 +44,8 @@ def empty(
4444
device: Device | None = None,
4545
**kwargs: object,
4646
) -> Array:
47-
_check_device(xp, device)
48-
return xp.empty(shape, dtype=dtype, **kwargs)
47+
with _device_ctx(xp, device):
48+
return xp.empty(shape, dtype=dtype, **kwargs)
4949

5050

5151
def empty_like(
@@ -57,8 +57,8 @@ def empty_like(
5757
device: Device | None = None,
5858
**kwargs: object,
5959
) -> Array:
60-
_check_device(xp, device)
61-
return xp.empty_like(x, dtype=dtype, **kwargs)
60+
with _device_ctx(xp, device, like=x):
61+
return xp.empty_like(x, dtype=dtype, **kwargs)
6262

6363

6464
def eye(
@@ -72,8 +72,8 @@ def eye(
7272
device: Device | None = None,
7373
**kwargs: object,
7474
) -> Array:
75-
_check_device(xp, device)
76-
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
75+
with _device_ctx(xp, device):
76+
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
7777

7878

7979
def full(
@@ -85,8 +85,8 @@ def full(
8585
device: Device | None = None,
8686
**kwargs: object,
8787
) -> Array:
88-
_check_device(xp, device)
89-
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
88+
with _device_ctx(xp, device):
89+
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
9090

9191

9292
def full_like(
@@ -99,8 +99,8 @@ def full_like(
9999
device: Device | None = None,
100100
**kwargs: object,
101101
) -> Array:
102-
_check_device(xp, device)
103-
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
102+
with _device_ctx(xp, device, like=x):
103+
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
104104

105105

106106
def linspace(
@@ -115,8 +115,8 @@ def linspace(
115115
endpoint: bool = True,
116116
**kwargs: object,
117117
) -> Array:
118-
_check_device(xp, device)
119-
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
118+
with _device_ctx(xp, device):
119+
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
120120

121121

122122
def ones(
@@ -127,8 +127,8 @@ def ones(
127127
device: Device | None = None,
128128
**kwargs: object,
129129
) -> Array:
130-
_check_device(xp, device)
131-
return xp.ones(shape, dtype=dtype, **kwargs)
130+
with _device_ctx(xp, device):
131+
return xp.ones(shape, dtype=dtype, **kwargs)
132132

133133

134134
def ones_like(
@@ -140,8 +140,8 @@ def ones_like(
140140
device: Device | None = None,
141141
**kwargs: object,
142142
) -> Array:
143-
_check_device(xp, device)
144-
return xp.ones_like(x, dtype=dtype, **kwargs)
143+
with _device_ctx(xp, device, like=x):
144+
return xp.ones_like(x, dtype=dtype, **kwargs)
145145

146146

147147
def zeros(
@@ -152,8 +152,8 @@ def zeros(
152152
device: Device | None = None,
153153
**kwargs: object,
154154
) -> Array:
155-
_check_device(xp, device)
156-
return xp.zeros(shape, dtype=dtype, **kwargs)
155+
with _device_ctx(xp, device):
156+
return xp.zeros(shape, dtype=dtype, **kwargs)
157157

158158

159159
def zeros_like(
@@ -165,8 +165,8 @@ def zeros_like(
165165
device: Device | None = None,
166166
**kwargs: object,
167167
) -> Array:
168-
_check_device(xp, device)
169-
return xp.zeros_like(x, dtype=dtype, **kwargs)
168+
with _device_ctx(xp, device, like=x):
169+
return xp.zeros_like(x, dtype=dtype, **kwargs)
170170

171171

172172
# np.unique() is split into four functions in the array API:

‎array_api_compat/common/_helpers.py

Copy file name to clipboardExpand all lines: array_api_compat/common/_helpers.py
+48-47Lines changed: 48 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88

99
from __future__ import annotations
1010

11+
import contextlib
1112
import inspect
1213
import math
1314
import sys
1415
import warnings
15-
from collections.abc import Collection
16+
from collections.abc import Collection, Generator
1617
from typing import (
1718
TYPE_CHECKING,
1819
Any,
@@ -663,26 +664,42 @@ def your_function(x, y):
663664
get_namespace = array_namespace
664665

665666

666-
def _check_device(bare_xp: Namespace, device: Device) -> None: # pyright: ignore[reportUnusedFunction]
667-
"""
668-
Validate dummy device on device-less array backends.
669-
670-
Notes
671-
-----
672-
This function is also invoked by CuPy, which does have multiple devices
673-
if there are multiple GPUs available.
674-
However, CuPy multi-device support is currently impossible
675-
without using the global device or a context manager:
667+
def _device_ctx(
668+
bare_xp: Namespace, device: Device, like: Array | None = None
669+
) -> Generator[None]:
670+
"""Context manager which changes the current device in CuPy.
676671
677-
https://github.com/data-apis/array-api-compat/pull/293
672+
Used internally by array creation functions in common._aliases.
678673
"""
679-
if bare_xp is sys.modules.get("numpy"):
680-
if device not in ("cpu", None):
674+
if device is None:
675+
if like is None:
676+
return contextlib.nullcontext()
677+
device = _device(like)
678+
679+
if bare_xp is sys.modules.get('numpy'):
680+
if device != "cpu":
681681
raise ValueError(f"Unsupported device for NumPy: {device!r}")
682+
return contextlib.nullcontext()
682683

683-
elif bare_xp is sys.modules.get("dask.array"):
684-
if device not in ("cpu", _DASK_DEVICE, None):
684+
if bare_xp is sys.modules.get('dask.array'):
685+
if device not in ("cpu", _DASK_DEVICE):
685686
raise ValueError(f"Unsupported device for Dask: {device!r}")
687+
return contextlib.nullcontext()
688+
689+
if bare_xp is sys.modules.get('cupy'):
690+
if not isinstance(device, bare_xp.cuda.Device):
691+
raise TypeError(f"device is not a cupy.cuda.Device: {device!r}")
692+
return device
693+
694+
# PyTorch doesn't have a "current device" context manager and you
695+
# can't use array creation functions from common._aliases.
696+
raise AssertionError("unreachable") # pragma: nocover
697+
698+
699+
def _check_device(bare_xp: Namespace, device: Device) -> None:
700+
"""Validate dummy device on device-less array backends."""
701+
with _device_ctx(bare_xp, device):
702+
pass
686703

687704

688705
# Placeholder object to represent the dask device
@@ -781,42 +798,26 @@ def _cupy_to_device(
781798
/,
782799
stream: int | Any | None = None,
783800
) -> _CupyArray:
784-
import cupy as cp # pyright: ignore[reportMissingTypeStubs]
785-
from cupy.cuda import Device as _Device # pyright: ignore
786-
from cupy.cuda import stream as stream_module # pyright: ignore
787-
from cupy_backends.cuda.api import runtime # pyright: ignore
801+
import cupy as cp
788802

789-
if device == x.device:
790-
return x
791-
elif device == "cpu":
803+
if device == "cpu":
792804
# allowing us to use `to_device(x, "cpu")`
793805
# is useful for portable test swapping between
794806
# host and device backends
795807
return x.get()
796-
elif not isinstance(device, _Device):
797-
raise ValueError(f"Unsupported device {device!r}")
798-
else:
799-
# see cupy/cupy#5985 for the reason how we handle device/stream here
800-
prev_device: Any = runtime.getDevice() # pyright: ignore[reportUnknownMemberType]
801-
prev_stream = None
802-
if stream is not None:
803-
prev_stream: Any = stream_module.get_current_stream() # pyright: ignore
804-
# stream can be an int as specified in __dlpack__, or a CuPy stream
805-
if isinstance(stream, int):
806-
stream = cp.cuda.ExternalStream(stream) # pyright: ignore
807-
elif isinstance(stream, cp.cuda.Stream): # pyright: ignore[reportUnknownMemberType]
808-
pass
809-
else:
810-
raise ValueError("the input stream is not recognized")
811-
stream.use() # pyright: ignore[reportUnknownMemberType]
812-
try:
813-
runtime.setDevice(device.id) # pyright: ignore[reportUnknownMemberType]
814-
arr = x.copy()
815-
finally:
816-
runtime.setDevice(prev_device) # pyright: ignore[reportUnknownMemberType]
817-
if stream is not None:
818-
prev_stream.use()
819-
return arr
808+
if not isinstance(device, cp.cuda.Device):
809+
raise TypeError(f"Unsupported device {device!r}")
810+
811+
# stream can be an int as specified in __dlpack__, or a CuPy stream
812+
if isinstance(stream, int):
813+
stream = cp.cuda.ExternalStream(stream)
814+
elif stream is None:
815+
stream = contextlib.nullcontext()
816+
elif not isinstance(stream, cp.cuda.Stream):
817+
raise TypeError('the input stream is not recognized')
818+
819+
with device, stream:
820+
return cp.asarray(x)
820821

821822

822823
def _torch_to_device(

‎array_api_compat/cupy/_aliases.py

Copy file name to clipboardExpand all lines: array_api_compat/cupy/_aliases.py
+10-23Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@
6464
finfo = get_xp(cp)(_aliases.finfo)
6565
iinfo = get_xp(cp)(_aliases.iinfo)
6666

67-
_copy_default = object()
68-
6967

7068
# asarray also adds the copy keyword, which is not present in numpy 1.0.
7169
def asarray(
@@ -79,7 +77,7 @@ def asarray(
7977
*,
8078
dtype: Optional[DType] = None,
8179
device: Optional[Device] = None,
82-
copy: Optional[bool] = _copy_default,
80+
copy: Optional[bool] = None,
8381
**kwargs,
8482
) -> Array:
8583
"""
@@ -88,26 +86,15 @@ def asarray(
8886
See the corresponding documentation in the array library and/or the array API
8987
specification for more details.
9088
"""
91-
with cp.cuda.Device(device):
92-
# cupy is like NumPy 1.26 (except without _CopyMode). See the comments
93-
# in asarray in numpy/_aliases.py.
94-
if copy is not _copy_default:
95-
# A future version of CuPy will change the meaning of copy=False
96-
# to mean no-copy. We don't know for certain what version it will
97-
# be yet, so to avoid breaking that version, we use a different
98-
# default value for copy so asarray(obj) with no copy kwarg will
99-
# always do the copy-if-needed behavior.
100-
101-
# This will still need to be updated to remove the
102-
# NotImplementedError for copy=False, but at least this won't
103-
# break the default or existing behavior.
104-
if copy is None:
105-
copy = False
106-
elif copy is False:
107-
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
108-
kwargs['copy'] = copy
109-
110-
return cp.array(obj, dtype=dtype, **kwargs)
89+
if copy is False:
90+
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
91+
92+
like = obj if isinstance(obj, cp.ndarray) else None
93+
with _helpers._device_ctx(cp, device, like=like):
94+
if copy is None:
95+
return cp.asarray(obj, dtype=dtype, **kwargs)
96+
else:
97+
return cp.array(obj, dtype=dtype, copy=True, **kwargs)
11198

11299

113100
def astype(

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.