Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit eb0d695

Browse filesBrowse files
ShadyBoukharyumar456
authored andcommitted
Added Index to ArrayFire and added some documentation
1 parent e5374fa commit eb0d695
Copy full SHA for eb0d695

File tree

5 files changed

+244
-4
lines changed
Filter options

5 files changed

+244
-4
lines changed

‎CMakeLists.txt

Copy file name to clipboardExpand all lines: CMakeLists.txt
+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ ADD_JAR(${AF_JAR}
3535
com/arrayfire/Graphics.java
3636
com/arrayfire/Window.java
3737
com/arrayfire/Seq.java
38+
com/arrayfire/Index.java
3839
com/arrayfire/AFLibLoader.java
3940
)
4041

‎com/arrayfire/ArrayFire.java

Copy file name to clipboardExpand all lines: com/arrayfire/ArrayFire.java
+187-3
Original file line numberDiff line numberDiff line change
@@ -1,86 +1,244 @@
11
package com.arrayfire;
22

3-
import com.arrayfire.Util;
3+
import com.arrayfire.Index;
44

55
public class ArrayFire extends AFLibLoader {
66

7+
public static final Seq SPAN = Seq.span();
78
/* ************* Algorithm ************* */
89

9-
// Scalar return operations
10+
11+
/**
12+
* Sum all the elements in an Array, wraps {@link http://arrayfire.org/docs/group__reduce__func__sum.htm }
13+
* @param a the array to be summed
14+
* @return the sum
15+
* @throws Exception
16+
* @see Array
17+
*/
1018
public static double sumAll(Array a) throws Exception {
1119
return Algorithm.sumAll(a);
1220
}
1321

22+
/**
23+
* Finds the maximum value in an array, wraps {@link http://arrayfire.org/docs/group__reduce__func__max.htm}
24+
* @param a the input array
25+
* @return the maximum value
26+
* @throws Exception
27+
* @see Array
28+
*/
1429
public static double maxAll(Array a) throws Exception {
1530
return Algorithm.maxAll(a);
1631
}
1732

33+
/**
34+
* Finds the minimum value in an array, wraps {@link http://arrayfire.org/docs/group__reduce__func__min.htm}.
35+
* @param a the input array
36+
* @return the minimum value
37+
* @throws Exception
38+
* @see Array
39+
*/
1840
public static double minAll(Array a) throws Exception {
1941
return Algorithm.minAll(a);
2042
}
2143

44+
/**
45+
* Finds the sum of an array across 1 dimension and stores the result in a 2nd array.
46+
* @param res the array in which to store the result
47+
* @param a the input array
48+
* @param dim the dimension across which to find the sum
49+
* @throws Exception
50+
* @see Array
51+
*/
2252
public static void sum(Array res, Array a, int dim) throws Exception {
2353
Algorithm.sum(res, a, dim);
2454
}
2555

56+
/**
57+
* Finds the maximum values of an array across 1 dimension and stores the result in a 2nd array.
58+
* @param res the array in which to store the result.
59+
* @param a the input array
60+
* @param dim the dimenstion across which to find the maximum values
61+
* @throws Exception
62+
* @see Array
63+
*/
2664
public static void max(Array res, Array a, int dim) throws Exception {
2765
Algorithm.max(res, a, dim);
2866
}
2967

68+
/**
69+
* Finds the minimum values in an array across 1 dimenstion and stores the result in a 2nd array.
70+
* @param res the array in which to store the result
71+
* @param a the input array
72+
* @param dim the dimension across which to find the maximum values
73+
* @throws Exception
74+
* @see Array
75+
*/
3076
public static void min(Array res, Array a, int dim) throws Exception {
3177
Algorithm.min(res, a, dim);
3278
}
3379

80+
81+
/**
82+
* Finds the sum of values in an array across all dimenstion and stores the result in a 2nd array.
83+
* @param res the array in which to store the result
84+
* @param a the input array
85+
* @throws Exception
86+
* @see Array
87+
*/
3488
public static void sum(Array res, Array a) throws Exception {
3589
Algorithm.sum(res, a);
3690
}
3791

92+
/**
93+
* Finds the max values in an array across all dimenstion and stores the result in a 2nd array.
94+
* @param res the array in which to store the result
95+
* @param a the input array
96+
* @throws Exception
97+
* @see Array
98+
*/
3899
public static void max(Array res, Array a) throws Exception {
39100
Algorithm.max(res, a);
40101
}
41102

103+
/**
104+
* Finds the minimum values in an array across all dimenstion and stores the result in a 2nd array.
105+
* @param res the array in which to store the result
106+
* @param a the input array
107+
* @throws Exception
108+
* @see Array
109+
*/
42110
public static void min(Array res, Array a) throws Exception {
43111
Algorithm.min(res, a, 0);
44112
}
45113

114+
/**
115+
* Finds the indices of all non-zero values in an input array, wraps {@link http://arrayfire.org/docs/group__scan__func__where.htm}
116+
* @param in the input array
117+
* @return an array containing the indices of all non-zero values in the input array
118+
* @throws Exception
119+
*/
120+
public static Array where(final Array in) throws Exception {
121+
return Algorithm.where(in);
122+
}
123+
46124
/* ************* Arith ************* */
47125

126+
/**
127+
* Performs element-wise addition between 2 arrays and stores the result in another array, wraps {@link http://arrayfire.org/docs/group__arith__func__add.htm}
128+
* @param c the resulting array
129+
* @param a the lhs array
130+
* @param b the rhs array
131+
* @return the resulting array
132+
* @throws Exception
133+
*/
48134
public static void add(Array c, Array a, Array b) throws Exception {
49135
Arith.add(c, a, b);
50136
}
51137

138+
/**
139+
* Subtracts 2 arrays, storing the result in another array, wraps {@link http://arrayfire.org/docs/group__arith__func__sub.htm}
140+
* @param c the resulting array
141+
* @param a the lhs array
142+
* @param b the rhs array
143+
* @return the resulting array
144+
* @throws Exception
145+
*/
52146
public static void sub(Array c, Array a, Array b) throws Exception {
53147
Arith.sub(c, a, b);
54148
}
55149

150+
/**
151+
* Performs element-wise multiplication between 2 arrays, wraps {@link http://arrayfire.org/docs/group__arith__func__mul.htm}
152+
* @param c the resulting array
153+
* @param a the lhs array
154+
* @param b the rhs array
155+
* @return the resulting array
156+
* @throws Exception
157+
*/
56158
public static void mul(Array c, Array a, Array b) throws Exception {
57159
Arith.mul(c, a, b);
58160
}
59161

162+
/**
163+
* Divides one array by another array, wraps {@link http://arrayfire.org/docs/group__arith__func__div.htm}
164+
* @param c the resulting array
165+
* @param a the lhs array
166+
* @param b the rhs array
167+
* @return the resulting array
168+
* @throws Exception
169+
*/
60170
public static void div(Array c, Array a, Array b) throws Exception {
61171
Arith.div(c, a, b);
62172
}
63173

174+
/**
175+
* Checks if an array is less than or equal to another, wraps {@link http://arrayfire.org/docs/group__arith__func__le.htm}.
176+
* @param c the resulting array
177+
* @param a the lhs array
178+
* @param b the rhs array
179+
* @return the resulting array
180+
* @throws Exception
181+
*/
64182
public static void le(Array c, Array a, Array b) throws Exception {
65183
Arith.le(c, a, b);
66184
}
67185

186+
/**
187+
* Checks if an array is less than another, wraps {@link http://arrayfire.org/docs/group__arith__func__lt.htm}
188+
* @param c the resulting array
189+
* @param a the lhs array
190+
* @param b the rhs array
191+
* @return the resulting array
192+
* @throws Exception
193+
*/
68194
public static void lt(Array c, Array a, Array b) throws Exception {
69195
Arith.lt(c, a, b);
70196
}
71197

198+
/**
199+
* Checks if an array is greater than or equal to another, wraps {@link http://arrayfire.org/docs/group__arith__func__ge.htm}
200+
* @param c the resulting array
201+
* @param a the lhs array
202+
* @param b the rhs array
203+
* @return the resulting array
204+
* @throws Exception
205+
*/
72206
public static void ge(Array c, Array a, Array b) throws Exception {
73207
Arith.ge(c, a, b);
74208
}
75209

210+
/**
211+
* Checks if an array is greater than another, wraps {@link http://arrayfire.org/docs/group__arith__func__gt.htm}
212+
* @param c the resulting array
213+
* @param a the lhs array
214+
* @param b the rhs array
215+
* @return the resulting array
216+
* @throws Exception
217+
*/
76218
public static void gt(Array c, Array a, Array b) throws Exception {
77219
Arith.gt(c, a, b);
78220
}
79221

222+
/**
223+
* Checks if 2 input arrays are equal, wraps {@link http://arrayfire.org/docs/group__arith__func__eq.htm}
224+
* @param c the resulting array
225+
* @param a the lhs array
226+
* @param b the rhs array
227+
* @return the resulting array
228+
* @throws Exception
229+
*/
80230
public static void eq(Array c, Array a, Array b) throws Exception {
81231
Arith.eq(c, a, b);
82232
}
83233

234+
/**
235+
* Checks if 2 input arrays are not equal, wraps {@link http://arrayfire.org/docs/group__arith__func__neq.htm}
236+
* @param c the resulting array
237+
* @param a the lhs array
238+
* @param b the rhs array
239+
* @return the resulting array
240+
* @throws Exception
241+
*/
84242
public static void neq(Array c, Array a, Array b) throws Exception {
85243
Arith.neq(c, a, b);
86244
}
@@ -491,6 +649,33 @@ static public <T> T castResult(DoubleComplex res, Class<T> type) throws Exceptio
491649
return Statistics.castResult(res, type);
492650
}
493651

652+
/* ************* Index ************* */
653+
654+
655+
public static Array lookup(final Array in, final Array idx, int dim) throws Exception {
656+
return Index.lookup(in, idx, dim);
657+
}
658+
659+
public static Array lookup(final Array in, final Array idx) throws Exception {
660+
return Index.lookup(in, idx, 0);
661+
}
662+
663+
public static void copy(Array dst, final Array src, Index idx0, Index idx1, Index idx2, Index idx3) throws Exception {
664+
Index.copy(dst, src, idx0, idx1, idx2, idx3);
665+
}
666+
667+
public static void copy(Array dst, final Array src, Index idx0, Index idx1, Index idx2) throws Exception {
668+
Index.copy(dst, src, idx0, idx1, idx2, new Index());
669+
}
670+
671+
public static void copy(Array dst, final Array src, Index idx0, Index idx1) throws Exception {
672+
Index.copy(dst, src, idx0, idx1, new Index(), new Index());
673+
}
674+
675+
public static void copy(Array dst, final Array src, Index idx0) throws Exception {
676+
Index.copy(dst, src, idx0, new Index(), new Index(), new Index());
677+
}
678+
494679
// Utils
495680

496681
public static String toString(Array a, String delim) {
@@ -505,7 +690,6 @@ public static void info() {
505690
Util.info();
506691
}
507692

508-
509693
// Enums
510694

511695
public static enum Type {

‎src/CMakeLists.txt

Copy file name to clipboardExpand all lines: src/CMakeLists.txt
+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ ADD_LIBRARY(${AF_LIB} SHARED
2222
signal.cpp
2323
statistics.cpp
2424
graphics.cpp
25+
index.cpp
2526
util.cpp
2627
)
2728

‎src/java/java.cpp

Copy file name to clipboardExpand all lines: src/java/java.cpp
+53-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ enum class JavaType {
1111
String,
1212
Short,
1313
Void,
14-
Boolean
14+
Boolean,
15+
Seq
1516
};
1617

1718
static const char *mapJavaTypeToString(JavaType type) {
@@ -26,6 +27,7 @@ static const char *mapJavaTypeToString(JavaType type) {
2627
case JavaType::Short: return "S";
2728
case JavaType::Void: return "V";
2829
case JavaType::Boolean: return "B";
30+
case JavaType::Seq: return "Lcom/arrayfire/Seq";
2931
}
3032
}
3133

@@ -94,6 +96,56 @@ jobject createJavaObject(JNIEnv *env, JavaObjects objectType, Args... args) {
9496
} break;
9597
}
9698
}
99+
100+
af_index_t jIndexToCIndex(JNIEnv *env, jobject obj) {
101+
af_index_t index;
102+
jclass cls = env->GetObjectClass(obj);
103+
assert(cls == env->FindClass("com/arrayfire/Index"));
104+
105+
std::string getIsSeqSig = generateFunctionSignature(JavaType::Boolean, {});
106+
jmethodID getIsSeqId = env->GetMethodID(cls, "isSeq", getIsSeqSig.c_str());
107+
assert(getIsSeqId != NULL);
108+
index.isSeq = env->CallBooleanMethod(obj, getIsSeqId);
109+
110+
std::string getIsBatchSig = generateFunctionSignature(JavaType::Boolean, {});
111+
jmethodID getIsBatchId = env->GetMethodID(cls, "isBatch", getIsBatchSig.c_str());
112+
assert(getIsBatchId != NULL);
113+
index.isBatch = env->CallBooleanMethod(obj, getIsBatchId);
114+
115+
if (index.isSeq) {
116+
// get seq object
117+
std::string getSeqSig = generateFunctionSignature(JavaType::Seq, {});
118+
jmethodID getSeqId = env->GetMethodID(cls, "getSeq", getSeqSig.c_str());
119+
assert(getSeqId != NULL);
120+
jobject seq = env->CallObjectMethod(obj, getSeqId);
121+
122+
// get seq fields
123+
jclass seqCls = env->GetObjectClass(seq);
124+
assert(seqCls == env->FindClass("com/arrayfire/Seq"));
125+
126+
jfieldID beginID = env->GetFieldID(seqCls, "begin", mapJavaTypeToString(JavaType::Double));
127+
assert(beginID != NULL);
128+
double begin = env->GetDoubleField(seq, beginID);
129+
130+
jfieldID endID = env->GetFieldID(seqCls, "end", mapJavaTypeToString(JavaType::Double));
131+
assert(endID != NULL);
132+
double end = env->GetDoubleField(seq, endID);
133+
134+
jfieldID stepID = env->GetFieldID(seqCls, "step", mapJavaTypeToString(JavaType::Double));
135+
assert(stepID != NULL);
136+
double step = env->GetDoubleField(seq, stepID);
137+
138+
index.idx.seq = af_make_seq(begin, end, step);
139+
} else {
140+
std::string getArrSig = generateFunctionSignature(JavaType::Long, {});
141+
jmethodID getArrId = env->GetMethodID(cls, "getArrRef", getArrSig.c_str());
142+
assert(getArrId != NULL);
143+
long arrRef = env->CallLongMethod(obj, getArrId);
144+
index.idx.arr = (af_array)arrRef;
145+
}
146+
return index;
147+
}
148+
97149
#define INSTANTIATE(type) \
98150
template jobject createJavaObject<type>(JNIEnv *, JavaObjects, type, type);
99151

‎src/java/java.h

Copy file name to clipboardExpand all lines: src/java/java.h
+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ enum class JavaObjects { FloatComplex, DoubleComplex };
1010
template <typename... Args>
1111
jobject createJavaObject(JNIEnv *env, JavaObjects objectType, Args... args);
1212

13+
af_index_t jIndexToCIndex(JNIEnv *env, jobject obj);
14+
1315
void throwArrayFireException(JNIEnv *env, const char *functionName,
1416
const char *file, const int line, const int code);
1517
} // namespace java

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.