Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 8c31248

Browse filesBrowse files
committed
ENH: Simplify CuPy asarray and to_device
1 parent 205c967 commit 8c31248
Copy full SHA for 8c31248

File tree

Expand file treeCollapse file tree

3 files changed

+25
-56
lines changed
Filter options
Expand file treeCollapse file tree

3 files changed

+25
-56
lines changed

‎array_api_compat/common/_helpers.py

Copy file name to clipboardExpand all lines: array_api_compat/common/_helpers.py
+17-31Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -781,42 +781,28 @@ def _cupy_to_device(
781781
/,
782782
stream: int | Any | None = None,
783783
) -> _CupyArray:
784-
import cupy as cp # pyright: ignore[reportMissingTypeStubs]
785-
from cupy.cuda import Device as _Device # pyright: ignore
786-
from cupy.cuda import stream as stream_module # pyright: ignore
787-
from cupy_backends.cuda.api import runtime # pyright: ignore
784+
import cupy as cp
788785

789-
if device == x.device:
790-
return x
791-
elif device == "cpu":
786+
if device == "cpu":
792787
# allowing us to use `to_device(x, "cpu")`
793788
# is useful for portable test swapping between
794789
# host and device backends
795790
return x.get()
796-
elif not isinstance(device, _Device):
797-
raise ValueError(f"Unsupported device {device!r}")
798-
else:
799-
# see cupy/cupy#5985 for the reason how we handle device/stream here
800-
prev_device: Any = runtime.getDevice() # pyright: ignore[reportUnknownMemberType]
801-
prev_stream = None
802-
if stream is not None:
803-
prev_stream: Any = stream_module.get_current_stream() # pyright: ignore
804-
# stream can be an int as specified in __dlpack__, or a CuPy stream
805-
if isinstance(stream, int):
806-
stream = cp.cuda.ExternalStream(stream) # pyright: ignore
807-
elif isinstance(stream, cp.cuda.Stream): # pyright: ignore[reportUnknownMemberType]
808-
pass
809-
else:
810-
raise ValueError("the input stream is not recognized")
811-
stream.use() # pyright: ignore[reportUnknownMemberType]
812-
try:
813-
runtime.setDevice(device.id) # pyright: ignore[reportUnknownMemberType]
814-
arr = x.copy()
815-
finally:
816-
runtime.setDevice(prev_device) # pyright: ignore[reportUnknownMemberType]
817-
if stream is not None:
818-
prev_stream.use()
819-
return arr
791+
if not isinstance(device, cp.cuda.Device):
792+
raise TypeError(f"Unsupported device type {device!r}")
793+
794+
if stream is None:
795+
with device:
796+
return cp.asarray(x)
797+
798+
# stream can be an int as specified in __dlpack__, or a CuPy stream
799+
if isinstance(stream, int):
800+
stream = cp.cuda.ExternalStream(stream)
801+
elif not isinstance(stream, cp.cuda.Stream):
802+
raise TypeError(f"Unsupported stream type {stream!r}")
803+
804+
with device, stream:
805+
return cp.asarray(x)
820806

821807

822808
def _torch_to_device(

‎array_api_compat/cupy/_aliases.py

Copy file name to clipboardExpand all lines: array_api_compat/cupy/_aliases.py
+8-22Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@
6464
finfo = get_xp(cp)(_aliases.finfo)
6565
iinfo = get_xp(cp)(_aliases.iinfo)
6666

67-
_copy_default = object()
68-
6967

7068
# asarray also adds the copy keyword, which is not present in numpy 1.0.
7169
def asarray(
@@ -79,7 +77,7 @@ def asarray(
7977
*,
8078
dtype: Optional[DType] = None,
8179
device: Optional[Device] = None,
82-
copy: Optional[bool] = _copy_default,
80+
copy: Optional[bool] = None,
8381
**kwargs,
8482
) -> Array:
8583
"""
@@ -89,25 +87,13 @@ def asarray(
8987
specification for more details.
9088
"""
9189
with cp.cuda.Device(device):
92-
# cupy is like NumPy 1.26 (except without _CopyMode). See the comments
93-
# in asarray in numpy/_aliases.py.
94-
if copy is not _copy_default:
95-
# A future version of CuPy will change the meaning of copy=False
96-
# to mean no-copy. We don't know for certain what version it will
97-
# be yet, so to avoid breaking that version, we use a different
98-
# default value for copy so asarray(obj) with no copy kwarg will
99-
# always do the copy-if-needed behavior.
100-
101-
# This will still need to be updated to remove the
102-
# NotImplementedError for copy=False, but at least this won't
103-
# break the default or existing behavior.
104-
if copy is None:
105-
copy = False
106-
elif copy is False:
107-
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
108-
kwargs['copy'] = copy
109-
110-
return cp.array(obj, dtype=dtype, **kwargs)
90+
if copy is None:
91+
return cp.asarray(obj, dtype=dtype, **kwargs)
92+
else:
93+
res = cp.array(obj, dtype=dtype, copy=copy, **kwargs)
94+
if not copy and res is not obj:
95+
raise ValueError("Unable to avoid copy while creating an array as requested")
96+
return res
11197

11298

11399
def astype(

‎cupy-xfails.txt

Copy file name to clipboardExpand all lines: cupy-xfails.txt
-3Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@ array_api_tests/test_array_object.py::test_scalar_casting[__index__(int64)]
1111
# testsuite bug (https://github.com/data-apis/array-api-tests/issues/172)
1212
array_api_tests/test_array_object.py::test_getitem
1313

14-
# copy=False is not yet implemented
15-
array_api_tests/test_creation_functions.py::test_asarray_arrays
16-
1714
# attributes are np.float32 instead of float
1815
# (see also https://github.com/data-apis/array-api/issues/405)
1916
array_api_tests/test_data_type_functions.py::test_finfo[float32]

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.