@@ -22,87 +22,74 @@ using common::is_complex;
22
22
namespace cuda {
23
23
24
24
template <typename T>
25
- void copyData (T *dst, const Array<T> &src) {
26
- if (src.elements () == 0 ) { return ; }
27
-
28
- // FIXME: Merge this with copyArray
29
- src.eval ();
30
-
31
- Array<T> out = src;
32
- const T *ptr = NULL ;
33
-
34
- if (src.isLinear () || // No offsets, No strides
35
- src.ndims () == 1 // Simple offset, no strides.
36
- ) {
37
- // A.get() gets data with offsets
38
- ptr = src.get ();
39
- } else {
40
- // FIXME: Think about implementing eval
41
- out = copyArray (src);
42
- ptr = out.get ();
25
+ void copyData (T *data, const Array<T> &src) {
26
+ if (src.elements () > 0 ) {
27
+ Array<T> lin = src.isReady () && src.isLinear () ? src : copyArray (src);
28
+ // out is now guaranteed linear
29
+ auto stream = cuda::getActiveStream ();
30
+ CUDA_CHECK (cudaMemcpyAsync (data, lin.get (), lin.elements () * sizeof (T),
31
+ cudaMemcpyDeviceToHost, stream));
32
+ CUDA_CHECK (cudaStreamSynchronize (stream));
43
33
}
44
-
45
- auto stream = cuda::getActiveStream ();
46
- CUDA_CHECK (cudaMemcpyAsync (dst, ptr, src.elements () * sizeof (T),
47
- cudaMemcpyDeviceToHost, stream));
48
- CUDA_CHECK (cudaStreamSynchronize (stream));
49
34
}
50
35
51
36
template <typename T>
52
37
Array<T> copyArray (const Array<T> &src) {
53
38
Array<T> out = createEmptyArray<T>(src.dims ());
54
- if (src.elements () == 0 ) { return out; }
55
-
56
- if (src.isLinear ()) {
57
- CUDA_CHECK (
58
- cudaMemcpyAsync (out.get (), src.get (), src.elements () * sizeof (T),
59
- cudaMemcpyDeviceToDevice, cuda::getActiveStream ()));
60
- } else {
39
+ if (src.isReady ()) {
61
40
kernel::memcopy<T>(out, src, src.ndims ());
41
+ } else {
42
+ Param<T> info (out.get (), src.dims ().dims , src.strides ().dims );
43
+ evalNodes (info, src.getNode ().get ());
62
44
}
63
45
return out;
64
46
}
65
47
66
48
template <typename T>
67
- void multiply_inplace (Array<T> &in , double val ) {
68
- kernel::copy<T, T>(in, in, in .ndims (), scalar<T>(0 ), val );
49
+ void multiply_inplace (Array<T> &src , double norm ) {
50
+ kernel::copy<T, T>(src, src, src .ndims (), scalar<T>(0 ), norm );
69
51
}
70
52
71
53
template <typename inType, typename outType>
72
54
struct copyWrapper {
73
- void operator ()(Array<outType> &out , Array<inType> const &in ) {
74
- kernel::copy<inType, outType>(out, in, in .ndims (), scalar<outType>(0 ),
75
- 1 );
55
+ void operator ()(Array<outType> &dst , Array<inType> const &src ) {
56
+ kernel::copy<inType, outType>(dst, src, src .ndims (), scalar<outType>(0 ),
57
+ 1.0 );
76
58
}
77
59
};
78
60
79
61
template <typename T>
80
62
struct copyWrapper <T, T> {
81
- void operator ()(Array<T> &out, Array<T> const &in) {
82
- if (out.isLinear () && in.isLinear () &&
83
- out.elements () == in.elements ()) {
84
- CUDA_CHECK (cudaMemcpyAsync (
85
- out.get (), in.get (), in.elements () * sizeof (T),
86
- cudaMemcpyDeviceToDevice, cuda::getActiveStream ()));
63
+ void operator ()(Array<T> &dst, Array<T> const &src) {
64
+ if (dst.isLinear () && src.isLinear () &&
65
+ dst.elements () == src.elements ()) {
66
+ if (src.isReady ()) {
67
+ CUDA_CHECK (cudaMemcpyAsync (
68
+ dst.get (), src.get (), src.elements () * sizeof (T),
69
+ cudaMemcpyDeviceToDevice, cuda::getActiveStream ()));
70
+ } else {
71
+ Param<T> info (dst.get (), src.dims ().dims , dst.strides ().dims );
72
+ evalNodes (info, src.getNode ().get ());
73
+ }
87
74
} else {
88
- kernel::copy<T, T>(out, in, in .ndims (), scalar<T>(0 ), 1 );
75
+ kernel::copy<T, T>(dst, src, src .ndims (), scalar<T>(0 ), 1.0 );
89
76
}
90
77
}
91
78
};
92
79
93
80
template <typename inType, typename outType>
94
- void copyArray (Array<outType> &out , Array<inType> const &in ) {
81
+ void copyArray (Array<outType> &dst , Array<inType> const &src ) {
95
82
static_assert (!(is_complex<inType>::value && !is_complex<outType>::value),
96
83
" Cannot copy from complex value to a non complex value" );
97
- ARG_ASSERT (1 , (in .ndims () == out. dims () .ndims ()));
84
+ ARG_ASSERT (1 , (src .ndims () == dst .ndims ()));
98
85
copyWrapper<inType, outType> copyFn;
99
- copyFn (out, in );
86
+ copyFn (dst, src );
100
87
}
101
88
102
- #define INSTANTIATE (T ) \
103
- template void copyData<T>(T * dst , const Array<T> &src); \
104
- template Array<T> copyArray<T>(const Array<T> &src); \
105
- template void multiply_inplace<T>(Array<T> & in , double norm);
89
+ #define INSTANTIATE (T ) \
90
+ template void copyData<T>(T * data , const Array<T> &src); \
91
+ template Array<T> copyArray<T>(const Array<T> &src); \
92
+ template void multiply_inplace<T>(Array<T> & src , double norm);
106
93
107
94
INSTANTIATE (float )
108
95
INSTANTIATE (double )
@@ -168,9 +155,9 @@ INSTANTIATE_COPY_ARRAY_COMPLEX(cfloat)
168
155
INSTANTIATE_COPY_ARRAY_COMPLEX (cdouble)
169
156
170
157
template <typename T>
171
- T getScalar (const Array<T> &in ) {
158
+ T getScalar (const Array<T> &src ) {
172
159
T retVal{};
173
- CUDA_CHECK (cudaMemcpyAsync (&retVal, in .get (), sizeof (T),
160
+ CUDA_CHECK (cudaMemcpyAsync (&retVal, src .get (), sizeof (T),
174
161
cudaMemcpyDeviceToHost,
175
162
cuda::getActiveStream ()));
176
163
CUDA_CHECK (cudaStreamSynchronize (cuda::getActiveStream ()));
0 commit comments