14
14
#include < handle.hpp>
15
15
#include < join.hpp>
16
16
#include < af/data.h>
17
+
17
18
#include < algorithm>
19
+ #include < climits>
18
20
#include < vector>
19
21
20
22
using af::dim4;
@@ -43,55 +45,43 @@ static inline af_array join_many(const int dim, const unsigned n_arrays,
43
45
vector<Array<T>> inputs_;
44
46
inputs_.reserve (n_arrays);
45
47
46
- for (unsigned i = 0 ; i < n_arrays; i++) {
47
- inputs_.push_back (getArray<T>(inputs[i]));
48
- if (inputs_.back ().isEmpty ()) { inputs_.pop_back (); }
48
+ dim_t dim_size{0 };
49
+ for (unsigned i{0 }; i < n_arrays; ++i) {
50
+ const Array<T> &iArray = getArray<T>(inputs[i]);
51
+ if (!iArray.isEmpty ()) {
52
+ inputs_.push_back (iArray);
53
+ dim_size += iArray.dims ().dims [dim];
54
+ }
49
55
}
50
56
51
57
// All dimensions except join dimension must be equal
52
58
// calculate odims size
53
- std::vector<af::dim4> idims (inputs_.size ());
54
- dim_t dim_size = 0 ;
55
- for (unsigned i = 0 ; i < idims.size (); i++) {
56
- idims[i] = inputs_[i].dims ();
57
- dim_size += idims[i][dim];
58
- }
59
-
60
- af::dim4 odims;
61
- for (int i = 0 ; i < 4 ; i++) {
62
- if (i == dim) {
63
- odims[i] = dim_size;
64
- } else {
65
- odims[i] = idims[0 ][i];
66
- }
67
- }
59
+ af::dim4 odims{inputs_[0 ].dims ()};
60
+ odims.dims [dim] = dim_size;
68
61
69
- Array<T> out = createEmptyArray<T>(odims);
62
+ Array<T> out{ createEmptyArray<T>(odims)} ;
70
63
join<T>(out, dim, inputs_);
71
64
return getHandle (out);
72
65
}
73
66
74
67
af_err af_join (af_array *out, const int dim, const af_array first,
75
68
const af_array second) {
76
69
try {
77
- const ArrayInfo &finfo = getInfo (first);
78
- const ArrayInfo &sinfo = getInfo (second);
79
- dim4 fdims = finfo.dims ();
80
- dim4 sdims = sinfo.dims ();
70
+ const ArrayInfo &finfo{ getInfo (first)} ;
71
+ const ArrayInfo &sinfo{ getInfo (second)} ;
72
+ const dim4 & fdims{ finfo.dims ()} ;
73
+ const dim4 & sdims{ sinfo.dims ()} ;
81
74
82
75
ARG_ASSERT (1 , dim >= 0 && dim < 4 );
83
76
ARG_ASSERT (2 , finfo.getType () == sinfo.getType ());
84
77
if (sinfo.elements () == 0 ) { return af_retain_array (out, first); }
85
-
86
78
if (finfo.elements () == 0 ) { return af_retain_array (out, second); }
87
-
88
- DIM_ASSERT (2 , sinfo.elements () > 0 );
89
- DIM_ASSERT (3 , finfo.elements () > 0 );
79
+ DIM_ASSERT (2 , finfo.elements () > 0 );
80
+ DIM_ASSERT (3 , sinfo.elements () > 0 );
90
81
91
82
// All dimensions except join dimension must be equal
92
- // Compute output dims
93
- for (int i = 0 ; i < 4 ; i++) {
94
- if (i != dim) { DIM_ASSERT (2 , fdims[i] == sdims[i]); }
83
+ for (int i{0 }; i < AF_MAX_DIMS; i++) {
84
+ if (i != dim) { DIM_ASSERT (2 , fdims.dims [i] == sdims.dims [i]); }
95
85
}
96
86
97
87
af_array output;
@@ -125,55 +115,46 @@ af_err af_join_many(af_array *out, const int dim, const unsigned n_arrays,
125
115
ARG_ASSERT (3 , inputs != nullptr );
126
116
127
117
if (n_arrays == 1 ) {
128
- af_array ret = nullptr ;
129
- AF_CHECK (af_retain_array (&ret, inputs[ 0 ] ));
118
+ af_array ret{ nullptr } ;
119
+ AF_CHECK (af_retain_array (&ret, * inputs));
130
120
std::swap (*out, ret);
131
121
return AF_SUCCESS;
132
122
}
133
123
134
- vector<ArrayInfo> info;
135
- info.reserve (n_arrays);
136
- vector<af::dim4> dims (n_arrays);
137
- for (unsigned i = 0 ; i < n_arrays; i++) {
138
- info.push_back (getInfo (inputs[i]));
139
- dims[i] = info[i].dims ();
140
- }
124
+ ARG_ASSERT (1 , dim >= 0 && dim < AF_MAX_DIMS);
125
+ ARG_ASSERT (2 , n_arrays > 0 );
141
126
142
- ARG_ASSERT (1 , dim >= 0 && dim < 4 );
143
-
144
- bool allEmpty = std::all_of (
145
- info.begin (), info.end (),
146
- [](const ArrayInfo &i) -> bool { return i.elements () <= 0 ; });
147
- if (allEmpty) {
127
+ const af_array *inputIt{inputs};
128
+ const af_array *inputEnd{inputs + n_arrays};
129
+ while ((inputIt != inputEnd) && (getInfo (*inputIt).elements () == 0 )) {
130
+ ++inputIt;
131
+ }
132
+ if (inputIt == inputEnd) {
133
+ // All arrays have 0 elements
148
134
af_array ret = nullptr ;
149
- AF_CHECK (af_retain_array (&ret, inputs[ 0 ] ));
135
+ AF_CHECK (af_retain_array (&ret, * inputs));
150
136
std::swap (*out, ret);
151
137
return AF_SUCCESS;
152
138
}
153
139
154
- auto first_valid_afinfo = std::find_if (
155
- info.begin (), info.end (),
156
- [](const ArrayInfo &i) -> bool { return i.elements () > 0 ; });
157
-
158
- af_dtype assertType = first_valid_afinfo->getType ();
159
- for (unsigned i = 1 ; i < n_arrays; i++) {
160
- if (info[i].elements () > 0 ) {
161
- ARG_ASSERT (3 , assertType == info[i].getType ());
162
- }
163
- }
164
-
165
- // All dimensions except join dimension must be equal
166
- af::dim4 assertDims = first_valid_afinfo->dims ();
167
- for (int i = 0 ; i < 4 ; i++) {
168
- if (i != dim) {
169
- for (unsigned j = 0 ; j < n_arrays; j++) {
170
- if (info[j].elements () > 0 ) {
171
- DIM_ASSERT (3 , assertDims[i] == dims[j][i]);
140
+ // inputIt points to first non empty array
141
+ const af_dtype assertType{getInfo (*inputIt).getType ()};
142
+ const dim4 &assertDims{getInfo (*inputIt).dims ()};
143
+
144
+ // Check all remaining arrays on assertType and assertDims
145
+ while (++inputIt != inputEnd) {
146
+ const ArrayInfo &info = getInfo (*inputIt);
147
+ if (info.elements () > 0 ) {
148
+ ARG_ASSERT (3 , assertType == info.getType ());
149
+ const dim4 &infoDims{getInfo (*inputIt).dims ()};
150
+ // All dimensions except join dimension must be equal
151
+ for (int i{0 }; i < AF_MAX_DIMS; i++) {
152
+ if (i != dim) {
153
+ DIM_ASSERT (3 , assertDims.dims [i] == infoDims.dims [i]);
172
154
}
173
155
}
174
156
}
175
157
}
176
-
177
158
af_array output;
178
159
179
160
switch (assertType) {
@@ -190,7 +171,7 @@ af_err af_join_many(af_array *out, const int dim, const unsigned n_arrays,
190
171
case u16 : output = join_many<ushort>(dim, n_arrays, inputs); break ;
191
172
case u8 : output = join_many<uchar>(dim, n_arrays, inputs); break ;
192
173
case f16 : output = join_many<half>(dim, n_arrays, inputs); break ;
193
- default : TYPE_ERROR (1 , info[ 0 ]. getType () );
174
+ default : TYPE_ERROR (1 , assertType );
194
175
}
195
176
swap (*out, output);
196
177
}
0 commit comments