Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 1962175

Browse filesBrowse files
ShadyBoukharyumar456
authored andcommitted
Added statistics function: af_var_*
1 parent 1f4cd41 commit 1962175
Copy full SHA for 1962175

File tree

3 files changed

+202
-95
lines changed
Filter options

3 files changed

+202
-95
lines changed

‎com/arrayfire/Statistics.java

Copy file name to clipboardExpand all lines: com/arrayfire/Statistics.java
+65Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,22 @@ public class Statistics extends ArrayFire {
1717

1818
static private native DoubleComplex afMeanAllDoubleComplexWeighted(long ref, long weightsRef);
1919

20+
static private native long afVar(long ref, boolean isBiased, int dim);
21+
22+
static private native long afVarWeighted(long ref, long weightsRef, int dim);
23+
24+
static private native double afVarAll(long ref, boolean isBiased);
25+
26+
static private native double afVarAllWeighted(long ref, long weightsRef);
27+
28+
static private native FloatComplex afVarAllFloatComplex(long ref, boolean isBiased);
29+
30+
static private native DoubleComplex afVarAllDoubleComplex(long ref, boolean isBiased);
31+
32+
static private native FloatComplex afVarAllFloatComplexWeighted(long ref, long weightsRef);
33+
34+
static private native DoubleComplex afVarAllDoubleComplexWeighted(long ref, long weightsRef);
35+
2036
static public Array mean(final Array in, int dim) {
2137
return new Array(afMean(in.ref, dim));
2238
}
@@ -65,4 +81,53 @@ static public <T> T mean(final Array in, final Array weights, Class<T> type) thr
6581
}
6682
throw new Exception("Unknown type");
6783
}
84+
85+
static public Array var(final Array in, boolean isBiased, int dim) {
86+
return new Array(afVar(in.ref, isBiased, dim));
87+
}
88+
89+
static public Array var(final Array in, final Array weights, int dim) {
90+
return new Array(afVarWeighted(in.ref, weights.ref, dim));
91+
}
92+
93+
static public <T> T var(final Array in, boolean isBiased, Class<T> type) throws Exception {
94+
if (type == FloatComplex.class) {
95+
FloatComplex res = (FloatComplex) afVarAllFloatComplex(in.ref, isBiased);
96+
return type.cast(res);
97+
} else if (type == DoubleComplex.class) {
98+
DoubleComplex res = (DoubleComplex) afVarAllDoubleComplex(in.ref, isBiased);
99+
return type.cast(res);
100+
}
101+
102+
double res = afVarAll(in.ref, isBiased);
103+
if (type == Float.class) {
104+
return type.cast(Float.valueOf((float) res));
105+
} else if (type == Double.class) {
106+
return type.cast(Double.valueOf((double) res));
107+
} else if (type == Integer.class) {
108+
return type.cast(Integer.valueOf((int) res));
109+
}
110+
throw new Exception("Unknown type");
111+
}
112+
113+
static public <T> T var(final Array in, final Array weights, Class<T> type) throws Exception {
114+
if (type == FloatComplex.class) {
115+
FloatComplex res = (FloatComplex) afVarAllFloatComplexWeighted(in.ref, weights.ref);
116+
return type.cast(res);
117+
} else if (type == DoubleComplex.class) {
118+
System.out.println(Long.toString(weights.ref));
119+
DoubleComplex res = (DoubleComplex) afVarAllDoubleComplexWeighted(in.ref, weights.ref);
120+
return type.cast(res);
121+
}
122+
123+
double res = afVarAllWeighted(in.ref, weights.ref);
124+
if (type == Float.class) {
125+
return type.cast(Float.valueOf((float) res));
126+
} else if (type == Double.class) {
127+
return type.cast(Double.valueOf((double) res));
128+
} else if (type == Integer.class) {
129+
return type.cast(Integer.valueOf((int) res));
130+
}
131+
throw new Exception("Unknown type");
132+
}
68133
}

‎examples/HelloWorld.java

Copy file name to clipboardExpand all lines: examples/HelloWorld.java
+54-55Lines changed: 54 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -8,72 +8,71 @@ public static void main(String[] args) {
88
Array a = new Array(), b = new Array(), c = new Array(), d = new Array();
99
Array f = new Array();
1010
try {
11-
Util.info();
12-
System.out.println("Create a 5-by-3 matrix of random floats on the GPU");
13-
Data.randu(a, new int[] {5, 3}, Array.FloatType);
14-
a.print("a");
11+
Util.info();
12+
System.out.println("Create a 5-by-3 matrix of random floats on the GPU");
13+
Data.randu(a, new int[] { 5, 3 }, Array.FloatType);
14+
a.print("a");
1515

16-
System.out.println("Element-wise arithmetic");
17-
Arith.sin(b, a);
18-
b.print("b");
16+
System.out.println("Element-wise arithmetic");
17+
Arith.sin(b, a);
18+
b.print("b");
1919

20-
System.out.println("Fourier transform the result");
21-
Signal.fft(c, b);
22-
c.print("c");
20+
System.out.println("Fourier transform the result");
21+
Signal.fft(c, b);
22+
c.print("c");
2323

24-
System.out.println("Matmul b and c");
25-
Arith.mul(d, b, c);
26-
d.print("d");
24+
System.out.println("Matmul b and c");
25+
Arith.mul(d, b, c);
26+
d.print("d");
2727

28+
System.out.println("Calculate weighted variance.");
29+
Array forVar = new Array();
30+
Array weights = new Array();
31+
Data.randn(forVar, new int[] { 5, 3 }, Array.DoubleType);
32+
Data.randn(weights, new int[] { 5, 3 }, Array.DoubleType);
33+
forVar.print("forVar");
2834

29-
Array forMean = new Array();
30-
Array weights = new Array();
31-
Data.randn(forMean, new int[] {3, 3}, Array.FloatComplexType);
32-
Data.randn(weights, new int[] {3, 3}, Array.FloatType);
33-
forMean.print("forMean");
35+
double abc = Statistics.mean(forVar, weights, Double.class);
36+
System.out.println(String.format("Variance is: %f", abc));
3437

35-
FloatComplex abc = Statistics.mean(forMean, weights, FloatComplex.class);
36-
System.out.println(String.format("Mean is: %f and %f", abc.real(), abc.imag()));
38+
System.out.println("Create a 2-by-3 matrix from host data");
39+
int[] dims = new int[] { 2, 3 };
40+
int total = 1;
41+
for (int dim : dims) {
42+
total *= dim;
43+
}
44+
float[] data = new float[total];
45+
Random rand = new Random();
3746

38-
System.out.println("Create a 2-by-3 matrix from host data");
39-
int[] dims = new int[] { 2, 3 };
40-
int total = 1;
41-
for (int dim : dims) {
42-
total *= dim;
43-
}
44-
float[] data = new float[total];
45-
Random rand = new Random();
46-
47-
for (int i = 0; i < total; i++) {
48-
double tmp = Math.ceil(rand.nextDouble() * 10) / 10;
49-
data[i] = (float) (tmp);
50-
}
51-
Array e = new Array(dims, data);
52-
e.print("e");
53-
54-
System.out.println("Add e and random array");
55-
Array randa = new Array();
56-
Data.randu(randa, dims, Array.FloatType);
57-
Arith.add(f, e, randa);
58-
f.print("f");
47+
for (int i = 0; i < total; i++) {
48+
double tmp = Math.ceil(rand.nextDouble() * 10) / 10;
49+
data[i] = (float) (tmp);
50+
}
51+
Array e = new Array(dims, data);
52+
e.print("e");
5953

54+
System.out.println("Add e and random array");
55+
Array randa = new Array();
56+
Data.randu(randa, dims, Array.FloatType);
57+
Arith.add(f, e, randa);
58+
f.print("f");
6059

61-
System.out.println("Copy result back to host.");
62-
float[] result = f.getFloatArray();
63-
for (int i = 0; i < dims[0]; i++) {
64-
for (int y = 0; y < dims[1]; y++) {
65-
System.out.print(result[y * dims[0] + i] + " ");
66-
}
67-
System.out.println();
60+
System.out.println("Copy result back to host.");
61+
float[] result = f.getFloatArray();
62+
for (int i = 0; i < dims[0]; i++) {
63+
for (int y = 0; y < dims[1]; y++) {
64+
System.out.print(result[y * dims[0] + i] + " ");
6865
}
69-
a.close();
70-
b.close();
71-
c.close();
72-
d.close();
73-
e.close();
74-
f.close();
66+
System.out.println();
67+
}
68+
a.close();
69+
b.close();
70+
c.close();
71+
d.close();
72+
e.close();
73+
f.close();
7574
} catch (Exception ex) {
76-
ex.printStackTrace();
75+
ex.printStackTrace();
7776
}
7877
}
7978
}

‎src/statistics.cpp

Copy file name to clipboardExpand all lines: src/statistics.cpp
+83-40Lines changed: 83 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,64 +5,107 @@ BEGIN_EXTERN_C
55

66
#define STATISTICS_FUNC(FUNC) AF_MANGLE(Statistics, FUNC)
77

8+
#define INSTANTIATE_MEAN(jtype, param) \
9+
JNIEXPORT jobject JNICALL STATISTICS_FUNC(afMeanAll##jtype)( \
10+
JNIEnv * env, jclass clazz, jlong ref) { \
11+
double real = 0, img = 0; \
12+
AF_CHECK(af_mean_all(&real, &img, ARRAY(ref))); \
13+
jclass cls = env->FindClass("com/arrayfire/" #jtype); \
14+
jmethodID id = env->GetMethodID(cls, "<init>", "(" #param ")V"); \
15+
jobject obj = env->NewObject(cls, id, real, img); \
16+
return obj; \
17+
}
18+
19+
#define INSTANTIATE_WEIGHTED(jtype, param, Name, name) \
20+
JNIEXPORT jobject JNICALL STATISTICS_FUNC(af##Name##All##jtype##Weighted)( \
21+
JNIEnv * env, jclass clazz, jlong ref, jlong weightsRef) { \
22+
double real = 0, img = 0; \
23+
AF_CHECK( \
24+
af_##name##_all_weighted(&real, &img, ARRAY(ref), ARRAY(weightsRef))); \
25+
jclass cls = env->FindClass("com/arrayfire/" #jtype); \
26+
jmethodID id = env->GetMethodID(cls, "<init>", "(" #param ")V"); \
27+
jobject obj = env->NewObject(cls, id, real, img); \
28+
return obj; \
29+
}
30+
31+
#define INSTANTIATE_ALL_REAL_WEIGHTED(Name, name) \
32+
JNIEXPORT jdouble JNICALL STATISTICS_FUNC(af##Name##AllWeighted)( \
33+
JNIEnv * env, jclass clazz, jlong ref, jlong weightsRef) { \
34+
double ret = 0; \
35+
AF_CHECK( \
36+
af_##name##_all_weighted(&ret, NULL, ARRAY(ref), ARRAY(weightsRef))); \
37+
return (jdouble)ret; \
38+
}
39+
40+
#define INSTANTIATE_REAL_WEIGHTED(Name, name) \
41+
JNIEXPORT jlong JNICALL STATISTICS_FUNC(af##Name##Weighted)( \
42+
JNIEnv * env, jclass clazz, jlong ref, jlong weightsRef, jint dim) { \
43+
af_array ret = 0; \
44+
AF_CHECK(af_##name##_weighted(&ret, ARRAY(ref), ARRAY(weightsRef), dim)); \
45+
return JLONG(ret); \
46+
}
47+
48+
#define INSTANTIATE_VAR(jtype, param) \
49+
JNIEXPORT jobject JNICALL STATISTICS_FUNC(afVarAll##jtype)( \
50+
JNIEnv * env, jclass clazz, jlong ref, jboolean isBiased) { \
51+
double real = 0, img = 0; \
52+
AF_CHECK(af_var_all(&real, &img, ARRAY(ref), isBiased)); \
53+
jclass cls = env->FindClass("com/arrayfire/" #jtype); \
54+
jmethodID id = env->GetMethodID(cls, "<init>", "(" #param ")V"); \
55+
jobject obj = env->NewObject(cls, id, real, img); \
56+
return obj; \
57+
}
58+
859
JNIEXPORT jlong JNICALL STATISTICS_FUNC(afMean)(JNIEnv *env, jclass clazz,
960
jlong ref, jint dim) {
1061
af_array ret = 0;
1162
AF_CHECK(af_mean(&ret, ARRAY(ref), dim));
1263
return JLONG(ret);
1364
}
1465

15-
JNIEXPORT jlong JNICALL STATISTICS_FUNC(afMeanWeighted)(JNIEnv *env,
16-
jclass clazz, jlong ref,
17-
jlong weightsRef,
18-
jint dim) {
19-
af_array ret = 0;
20-
AF_CHECK(af_mean_weighted(&ret, ARRAY(ref), ARRAY(weightsRef), dim));
21-
return JLONG(ret);
22-
}
23-
2466
JNIEXPORT jdouble JNICALL STATISTICS_FUNC(afMeanAll)(JNIEnv *env, jclass clazz,
2567
jlong ref) {
2668
double ret = 0;
2769
AF_CHECK(af_mean_all(&ret, NULL, ARRAY(ref)));
2870
return (jdouble)ret;
2971
}
3072

31-
JNIEXPORT jdouble JNICALL STATISTICS_FUNC(afMeanAllWeighted)(JNIEnv *env,
32-
jclass clazz,
33-
jlong ref,
34-
jlong weightsRef) {
73+
INSTANTIATE_MEAN(FloatComplex, FF)
74+
INSTANTIATE_MEAN(DoubleComplex, DD)
75+
INSTANTIATE_ALL_REAL_WEIGHTED(Mean, mean)
76+
INSTANTIATE_REAL_WEIGHTED(Mean, mean)
77+
INSTANTIATE_WEIGHTED(FloatComplex, FF, Mean, mean)
78+
INSTANTIATE_WEIGHTED(DoubleComplex, DD, Mean, mean)
79+
80+
#undef INSTANTIATE_MEAN
81+
82+
JNIEXPORT jlong JNICALL STATISTICS_FUNC(afVar)(JNIEnv *env, jclass clazz,
83+
jlong ref, jboolean isBiased,
84+
jint dim) {
85+
af_array ret = 0;
86+
AF_CHECK(af_var(&ret, ARRAY(ref), isBiased, dim));
87+
return JLONG(ret);
88+
}
89+
90+
JNIEXPORT jdouble JNICALL STATISTICS_FUNC(afVarAll)(JNIEnv *env, jclass clazz,
91+
jlong ref,
92+
jboolean isBiased) {
3593
double ret = 0;
36-
AF_CHECK(af_mean_all_weighted(&ret, NULL, ARRAY(ref), ARRAY(weightsRef)));
94+
AF_CHECK(af_var_all(&ret, NULL, ARRAY(ref), isBiased));
3795
return (jdouble)ret;
3896
}
3997

40-
#define INSTANTIATE_MEAN(jtype, param) \
41-
JNIEXPORT jobject JNICALL STATISTICS_FUNC(afMeanAll##jtype)( \
42-
JNIEnv * env, jclass clazz, jlong ref) { \
43-
double real = 0, img = 0; \
44-
AF_CHECK(af_mean_all(&real, &img, ARRAY(ref))); \
45-
jclass cls = env->FindClass("com/arrayfire/" #jtype); \
46-
jmethodID id = env->GetMethodID(cls, "<init>", "(" #param ")V"); \
47-
jobject obj = env->NewObject(cls, id, real, img); \
48-
return obj; \
49-
}
50-
51-
#define INSTANTIATE_MEAN_WEIGHTED(jtype, param) \
52-
JNIEXPORT jobject JNICALL STATISTICS_FUNC(afMeanAll##jtype##Weighted)( \
53-
JNIEnv * env, jclass clazz, jlong ref, jlong weightsRef) { \
54-
double real = 0, img = 0; \
55-
AF_CHECK( \
56-
af_mean_all_weighted(&real, &img, ARRAY(ref), ARRAY(weightsRef))); \
57-
jclass cls = env->FindClass("com/arrayfire/" #jtype); \
58-
jmethodID id = env->GetMethodID(cls, "<init>", "(" #param ")V"); \
59-
jobject obj = env->NewObject(cls, id, real, img); \
60-
return obj; \
61-
}
98+
INSTANTIATE_VAR(FloatComplex, FF)
99+
INSTANTIATE_VAR(DoubleComplex, DD)
100+
INSTANTIATE_REAL_WEIGHTED(Var, var)
101+
INSTANTIATE_ALL_REAL_WEIGHTED(Var, var)
102+
INSTANTIATE_WEIGHTED(FloatComplex, FF, Var, var)
103+
INSTANTIATE_WEIGHTED(DoubleComplex, DD, Var, var)
62104

63-
INSTANTIATE_MEAN(FloatComplex, FF)
64-
INSTANTIATE_MEAN(DoubleComplex, DD)
65-
INSTANTIATE_MEAN_WEIGHTED(FloatComplex, FF)
66-
INSTANTIATE_MEAN_WEIGHTED(DoubleComplex, DD)
105+
#undef INSTANTIATE_VAR
106+
#undef INSTANTIATE_MEAN
107+
#undef INSTANTIATE_WEIGHTED
108+
#undef INSTANTIATE_REAL_WEIGHTED
109+
#undef INSTANTIATE_ALL_REAL_WEIGHTED
67110

68111
END_EXTERN_C

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.