Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings
Merged
1 change: 1 addition & 0 deletions — dpctl/tensor/CMakeLists.txt
Original file line number | Diff line number | Diff line change
Expand Up @@ -37,6 +37,7 @@ pybind11_add_module(${python_module_name} MODULE
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_and_cast_usm_to_usm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_reshape.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_roll.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/integer_advanced_indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_advanced_indexing.cpp
Expand Down
36 changes: 13 additions & 23 deletions — dpctl/tensor/_manipulation_functions.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -15,7 +15,8 @@
# limitations under the License.


from itertools import chain, product, repeat
import operator
from itertools import chain, repeat

import numpy as np
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
Expand Down Expand Up @@ -426,10 +427,11 @@ def roll(X, shift, axis=None):
if not isinstance(X, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
if axis is None:
shift = operator.index(shift)
res = dpt.empty(
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=X.sycl_queue
)
hev, _ = ti._copy_usm_ndarray_for_reshape(
hev, _ = ti._copy_usm_ndarray_for_roll_1d(
src=X, dst=res, shift=shift, sycl_queue=X.sycl_queue
)
hev.wait()
Expand All @@ -438,31 +440,20 @@ def roll(X, shift, axis=None):
broadcasted = np.broadcast(shift, axis)
if broadcasted.ndim > 1:
raise ValueError("'shift' and 'axis' should be scalars or 1D sequences")
shifts = {ax: 0 for ax in range(X.ndim)}
shifts = [
0,
] * X.ndim
for sh, ax in broadcasted:
shifts[ax] += sh
rolls = [((np.s_[:], np.s_[:]),)] * X.ndim
for ax, offset in shifts.items():
offset %= X.shape[ax] or 1
if offset:
# (original, result), (original, result)
rolls[ax] = (
(np.s_[:-offset], np.s_[offset:]),
(np.s_[-offset:], np.s_[:offset]),
)

exec_q = X.sycl_queue
res = dpt.empty(
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=X.sycl_queue
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=exec_q
)
hev_list = []
for indices in product(*rolls):
arr_index, res_index = zip(*indices)
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
src=X[arr_index], dst=res[res_index], sycl_queue=X.sycl_queue
)
hev_list.append(hev)

dpctl.SyclEvent.wait_for(hev_list)
ht_e, _ = ti._copy_usm_ndarray_for_roll_nd(
src=X, dst=res, shifts=shifts, sycl_queue=exec_q
)
ht_e.wait()
return res


Expand Down Expand Up @@ -550,7 +541,6 @@ def _concat_axis_None(arrays):
hev, _ = ti._copy_usm_ndarray_for_reshape(
src=src_,
dst=res[fill_start:fill_end],
shift=0,
sycl_queue=exec_q,
)
fill_start = fill_end
Expand Down
2 changes: 1 addition & 1 deletion — dpctl/tensor/_reshape.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -165,7 +165,7 @@ def reshape(X, shape, order="C", copy=None):
)
if order == "C":
hev, _ = _copy_usm_ndarray_for_reshape(
src=X, dst=flat_res, shift=0, sycl_queue=X.sycl_queue
src=X, dst=flat_res, sycl_queue=X.sycl_queue
)
hev.wait()
else:
Expand Down
Loading
Morty Proxy: This is a proxified and sanitized view of the page; visit the original site.