Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit c4697fb

Browse filesBrowse files
willybornumar456
authored andcommitted
OPT: join
1 parent 9af3288 commit c4697fb
Copy full SHA for c4697fb

File tree

Expand file treeCollapse file tree

11 files changed

+397
-370
lines changed
Filter options
Expand file treeCollapse file tree

11 files changed

+397
-370
lines changed

‎src/api/c/join.cpp

Copy file name to clipboardExpand all lines: src/api/c/join.cpp
+47-66Lines changed: 47 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
#include <handle.hpp>
1515
#include <join.hpp>
1616
#include <af/data.h>
17+
1718
#include <algorithm>
19+
#include <climits>
1820
#include <vector>
1921

2022
using af::dim4;
@@ -43,55 +45,43 @@ static inline af_array join_many(const int dim, const unsigned n_arrays,
4345
vector<Array<T>> inputs_;
4446
inputs_.reserve(n_arrays);
4547

46-
for (unsigned i = 0; i < n_arrays; i++) {
47-
inputs_.push_back(getArray<T>(inputs[i]));
48-
if (inputs_.back().isEmpty()) { inputs_.pop_back(); }
48+
dim_t dim_size{0};
49+
for (unsigned i{0}; i < n_arrays; ++i) {
50+
const Array<T> &iArray = getArray<T>(inputs[i]);
51+
if (!iArray.isEmpty()) {
52+
inputs_.push_back(iArray);
53+
dim_size += iArray.dims().dims[dim];
54+
}
4955
}
5056

5157
// All dimensions except join dimension must be equal
5258
// calculate odims size
53-
std::vector<af::dim4> idims(inputs_.size());
54-
dim_t dim_size = 0;
55-
for (unsigned i = 0; i < idims.size(); i++) {
56-
idims[i] = inputs_[i].dims();
57-
dim_size += idims[i][dim];
58-
}
59-
60-
af::dim4 odims;
61-
for (int i = 0; i < 4; i++) {
62-
if (i == dim) {
63-
odims[i] = dim_size;
64-
} else {
65-
odims[i] = idims[0][i];
66-
}
67-
}
59+
af::dim4 odims{inputs_[0].dims()};
60+
odims.dims[dim] = dim_size;
6861

69-
Array<T> out = createEmptyArray<T>(odims);
62+
Array<T> out{createEmptyArray<T>(odims)};
7063
join<T>(out, dim, inputs_);
7164
return getHandle(out);
7265
}
7366

7467
af_err af_join(af_array *out, const int dim, const af_array first,
7568
const af_array second) {
7669
try {
77-
const ArrayInfo &finfo = getInfo(first);
78-
const ArrayInfo &sinfo = getInfo(second);
79-
dim4 fdims = finfo.dims();
80-
dim4 sdims = sinfo.dims();
70+
const ArrayInfo &finfo{getInfo(first)};
71+
const ArrayInfo &sinfo{getInfo(second)};
72+
const dim4 &fdims{finfo.dims()};
73+
const dim4 &sdims{sinfo.dims()};
8174

8275
ARG_ASSERT(1, dim >= 0 && dim < 4);
8376
ARG_ASSERT(2, finfo.getType() == sinfo.getType());
8477
if (sinfo.elements() == 0) { return af_retain_array(out, first); }
85-
8678
if (finfo.elements() == 0) { return af_retain_array(out, second); }
87-
88-
DIM_ASSERT(2, sinfo.elements() > 0);
89-
DIM_ASSERT(3, finfo.elements() > 0);
79+
DIM_ASSERT(2, finfo.elements() > 0);
80+
DIM_ASSERT(3, sinfo.elements() > 0);
9081

9182
// All dimensions except join dimension must be equal
92-
// Compute output dims
93-
for (int i = 0; i < 4; i++) {
94-
if (i != dim) { DIM_ASSERT(2, fdims[i] == sdims[i]); }
83+
for (int i{0}; i < AF_MAX_DIMS; i++) {
84+
if (i != dim) { DIM_ASSERT(2, fdims.dims[i] == sdims.dims[i]); }
9585
}
9686

9787
af_array output;
@@ -125,55 +115,46 @@ af_err af_join_many(af_array *out, const int dim, const unsigned n_arrays,
125115
ARG_ASSERT(3, inputs != nullptr);
126116

127117
if (n_arrays == 1) {
128-
af_array ret = nullptr;
129-
AF_CHECK(af_retain_array(&ret, inputs[0]));
118+
af_array ret{nullptr};
119+
AF_CHECK(af_retain_array(&ret, *inputs));
130120
std::swap(*out, ret);
131121
return AF_SUCCESS;
132122
}
133123

134-
vector<ArrayInfo> info;
135-
info.reserve(n_arrays);
136-
vector<af::dim4> dims(n_arrays);
137-
for (unsigned i = 0; i < n_arrays; i++) {
138-
info.push_back(getInfo(inputs[i]));
139-
dims[i] = info[i].dims();
140-
}
124+
ARG_ASSERT(1, dim >= 0 && dim < AF_MAX_DIMS);
125+
ARG_ASSERT(2, n_arrays > 0);
141126

142-
ARG_ASSERT(1, dim >= 0 && dim < 4);
143-
144-
bool allEmpty = std::all_of(
145-
info.begin(), info.end(),
146-
[](const ArrayInfo &i) -> bool { return i.elements() <= 0; });
147-
if (allEmpty) {
127+
const af_array *inputIt{inputs};
128+
const af_array *inputEnd{inputs + n_arrays};
129+
while ((inputIt != inputEnd) && (getInfo(*inputIt).elements() == 0)) {
130+
++inputIt;
131+
}
132+
if (inputIt == inputEnd) {
133+
// All arrays have 0 elements
148134
af_array ret = nullptr;
149-
AF_CHECK(af_retain_array(&ret, inputs[0]));
135+
AF_CHECK(af_retain_array(&ret, *inputs));
150136
std::swap(*out, ret);
151137
return AF_SUCCESS;
152138
}
153139

154-
auto first_valid_afinfo = std::find_if(
155-
info.begin(), info.end(),
156-
[](const ArrayInfo &i) -> bool { return i.elements() > 0; });
157-
158-
af_dtype assertType = first_valid_afinfo->getType();
159-
for (unsigned i = 1; i < n_arrays; i++) {
160-
if (info[i].elements() > 0) {
161-
ARG_ASSERT(3, assertType == info[i].getType());
162-
}
163-
}
164-
165-
// All dimensions except join dimension must be equal
166-
af::dim4 assertDims = first_valid_afinfo->dims();
167-
for (int i = 0; i < 4; i++) {
168-
if (i != dim) {
169-
for (unsigned j = 0; j < n_arrays; j++) {
170-
if (info[j].elements() > 0) {
171-
DIM_ASSERT(3, assertDims[i] == dims[j][i]);
140+
// inputIt points to first non empty array
141+
const af_dtype assertType{getInfo(*inputIt).getType()};
142+
const dim4 &assertDims{getInfo(*inputIt).dims()};
143+
144+
// Check all remaining arrays on assertType and assertDims
145+
while (++inputIt != inputEnd) {
146+
const ArrayInfo &info = getInfo(*inputIt);
147+
if (info.elements() > 0) {
148+
ARG_ASSERT(3, assertType == info.getType());
149+
const dim4 &infoDims{getInfo(*inputIt).dims()};
150+
// All dimensions except join dimension must be equal
151+
for (int i{0}; i < AF_MAX_DIMS; i++) {
152+
if (i != dim) {
153+
DIM_ASSERT(3, assertDims.dims[i] == infoDims.dims[i]);
172154
}
173155
}
174156
}
175157
}
176-
177158
af_array output;
178159

179160
switch (assertType) {
@@ -190,7 +171,7 @@ af_err af_join_many(af_array *out, const int dim, const unsigned n_arrays,
190171
case u16: output = join_many<ushort>(dim, n_arrays, inputs); break;
191172
case u8: output = join_many<uchar>(dim, n_arrays, inputs); break;
192173
case f16: output = join_many<half>(dim, n_arrays, inputs); break;
193-
default: TYPE_ERROR(1, info[0].getType());
174+
default: TYPE_ERROR(1, assertType);
194175
}
195176
swap(*out, output);
196177
}

‎src/backend/cuda/CMakeLists.txt

Copy file name to clipboardExpand all lines: src/backend/cuda/CMakeLists.txt
+1-2Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,6 @@ set(nvrtc_src
208208
${CMAKE_CURRENT_SOURCE_DIR}/kernel/index.cuh
209209
${CMAKE_CURRENT_SOURCE_DIR}/kernel/iota.cuh
210210
${CMAKE_CURRENT_SOURCE_DIR}/kernel/ireduce.cuh
211-
${CMAKE_CURRENT_SOURCE_DIR}/kernel/join.cuh
212211
${CMAKE_CURRENT_SOURCE_DIR}/kernel/lookup.cuh
213212
${CMAKE_CURRENT_SOURCE_DIR}/kernel/lu_split.cuh
214213
${CMAKE_CURRENT_SOURCE_DIR}/kernel/match_template.cuh
@@ -455,7 +454,6 @@ cuda_add_library(afcuda
455454
kernel/interp.hpp
456455
kernel/iota.hpp
457456
kernel/ireduce.hpp
458-
kernel/join.hpp
459457
kernel/lookup.hpp
460458
kernel/lu_split.hpp
461459
kernel/match_template.hpp
@@ -656,6 +654,7 @@ cuda_add_library(afcuda
656654
svd.hpp
657655
tile.cpp
658656
tile.hpp
657+
threadsMgt.hpp
659658
topk.hpp
660659
traits.hpp
661660
transform.hpp

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.