Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit c64a418

Browse filesBrowse files
ShadyBoukharyumar456
authored andcommitted
Added stdev functions.
1 parent 44a2b17 commit c64a418
Copy full SHA for c64a418

File tree

3 files changed

+97
-50
lines changed
Filter options

3 files changed

+97
-50
lines changed

‎com/arrayfire/Statistics.java

Copy file name to clipboardExpand all lines: com/arrayfire/Statistics.java
+32
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ public class Statistics extends ArrayFire {
3333

3434
static private native DoubleComplex afVarAllDoubleComplexWeighted(long ref, long weightsRef);
3535

36+
static private native long afStdev(long ref, int dim);
37+
38+
static private native double afStdevAll(long ref);
39+
40+
static private native FloatComplex afStdevAllFloatComplex(long ref);
41+
42+
static private native DoubleComplex afStdevAllDoubleComplex(long ref);
43+
3644
static public Array mean(final Array in, int dim) {
3745
return new Array(afMean(in.ref, dim));
3846
}
@@ -128,4 +136,28 @@ static public <T> T var(final Array in, final Array weights, Class<T> type) thro
128136
}
129137
throw new Exception("Unknown type");
130138
}
139+
140+
static public Array stdev(final Array in, int dim) {
141+
return new Array(afStdev(in.ref, dim));
142+
}
143+
144+
static public <T> T stdev(final Array in, Class<T> type) throws Exception {
145+
if (type == FloatComplex.class) {
146+
FloatComplex res = (FloatComplex)afStdevAllFloatComplex(in.ref);
147+
return type.cast(res);
148+
} else if (type == DoubleComplex.class) {
149+
DoubleComplex res = (DoubleComplex)afStdevAllDoubleComplex(in.ref);
150+
return type.cast(res);
151+
}
152+
153+
double res = afStdevAll(in.ref);
154+
if (type == Float.class) {
155+
return type.cast(Float.valueOf((float) res));
156+
} else if (type == Double.class) {
157+
return type.cast(Double.valueOf((double) res));
158+
} else if (type == Integer.class) {
159+
return type.cast(Integer.valueOf((int) res));
160+
}
161+
throw new Exception("Unknown type");
162+
}
131163
}

‎examples/HelloWorld.java

Copy file name to clipboardExpand all lines: examples/HelloWorld.java
+16-6
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,35 @@ public static void main(String[] args) {
1515

1616
System.out.println("Element-wise arithmetic");
1717
Arith.sin(b, a);
18-
System.out.println(a.toString("b"));
18+
System.out.println(b.toString("b"));
1919

2020
System.out.println("Fourier transform the result");
2121
Signal.fft(c, b);
22-
System.out.println(a.toString("c"));
22+
System.out.println(c.toString("c"));
2323

2424
System.out.println("Matmul b and c");
2525
Arith.mul(d, b, c);
26-
System.out.println(a.toString("d"));
26+
System.out.println(d.toString("d"));
2727

2828
System.out.println("Calculate weighted variance.");
2929
Array forVar = new Array();
3030
Array weights = new Array();
3131
Data.randn(forVar, new int[] { 5, 3 }, Array.DoubleType);
3232
Data.randn(weights, new int[] { 5, 3 }, Array.DoubleType);
33-
System.out.println(a.toString("forVar"));
33+
System.out.println(forVar.toString("forVar"));
3434

3535
double abc = Statistics.var(forVar, weights, Double.class);
3636
System.out.println(String.format("Variance is: %f", abc));
37+
forVar.close();
38+
weights.close();
39+
40+
System.out.println("Calculate standard deviation");
41+
Array forStdev = new Array();
42+
Data.randu(forStdev, new int[] {5, 3}, Array.DoubleType);
43+
System.out.println(forStdev.toString("forVar"));
44+
double stdev = Statistics.stdev(forStdev, Double.class);
45+
46+
System.out.println(String.format("Stdev is: %f", stdev));
3747

3848
System.out.println("Create a 2-by-3 matrix from host data");
3949
int[] dims = new int[] { 2, 3 };
@@ -49,13 +59,13 @@ public static void main(String[] args) {
4959
data[i] = (float) (tmp);
5060
}
5161
Array e = new Array(dims, data);
52-
System.out.println(a.toString("e"));
62+
System.out.println(e.toString("e"));
5363

5464
System.out.println("Add e and random array");
5565
Array randa = new Array();
5666
Data.randu(randa, dims, Array.FloatType);
5767
Arith.add(f, e, randa);
58-
System.out.println(a.toString("f"));
68+
System.out.println(f.toString("f"));
5969

6070
System.out.println("Copy result back to host.");
6171
float[] result = f.getFloatArray();

‎src/statistics.cpp

Copy file name to clipboardExpand all lines: src/statistics.cpp
+49-44
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,7 @@ BEGIN_EXTERN_C
55

66
#define STATISTICS_FUNC(FUNC) AF_MANGLE(Statistics, FUNC)
77

8-
#define INSTANTIATE_MEAN(jtype) \
9-
JNIEXPORT jobject JNICALL STATISTICS_FUNC(afMeanAll##jtype)( \
10-
JNIEnv * env, jclass clazz, jlong ref) { \
11-
double real = 0, img = 0; \
12-
AF_CHECK(af_mean_all(&real, &img, ARRAY(ref))); \
13-
return java::createJavaObject(env, java::JavaObjects::jtype, real, img); \
14-
}
15-
16-
#define INSTANTIATE_WEIGHTED(jtype, Name, name) \
8+
#define INSTANTIATE_STAT_WEIGHTED_COMPLEX(jtype, Name, name) \
179
JNIEXPORT jobject JNICALL STATISTICS_FUNC(af##Name##All##jtype##Weighted)( \
1810
JNIEnv * env, jclass clazz, jlong ref, jlong weightsRef) { \
1911
double real = 0, img = 0; \
@@ -22,7 +14,7 @@ BEGIN_EXTERN_C
2214
return java::createJavaObject(env, java::JavaObjects::jtype, real, img); \
2315
}
2416

25-
#define INSTANTIATE_ALL_REAL_WEIGHTED(Name, name) \
17+
#define INSTANTIATE_STAT_ALL_WEIGHTED(Name, name) \
2618
JNIEXPORT jdouble JNICALL STATISTICS_FUNC(af##Name##AllWeighted)( \
2719
JNIEnv * env, jclass clazz, jlong ref, jlong weightsRef) { \
2820
double ret = 0; \
@@ -31,7 +23,7 @@ BEGIN_EXTERN_C
3123
return (jdouble)ret; \
3224
}
3325

34-
#define INSTANTIATE_REAL_WEIGHTED(Name, name) \
26+
#define INSTANTIATE_STAT_WEIGHTED(Name, name) \
3527
JNIEXPORT jlong JNICALL STATISTICS_FUNC(af##Name##Weighted)( \
3628
JNIEnv * env, jclass clazz, jlong ref, jlong weightsRef, jint dim) { \
3729
af_array ret = 0; \
@@ -47,35 +39,39 @@ BEGIN_EXTERN_C
4739
return java::createJavaObject(env, java::JavaObjects::jtype, real, img); \
4840
}
4941

50-
#define INSTANTIATE_STAT(Name, name) \
51-
JNIEXPORT jlong JNICALL STATISTICS_FUNC(af##Name)(JNIEnv *env, jclass clazz, \
52-
jlong ref, jint dim) { \
53-
af_array ret = 0; \
54-
AF_CHECK(af_##name(&ret, ARRAY(ref), dim)); \
55-
return JLONG(ret); \
56-
}
42+
#define INSTANTIATE_STAT(Name, name) \
43+
JNIEXPORT jlong JNICALL STATISTICS_FUNC(af##Name)( \
44+
JNIEnv * env, jclass clazz, jlong ref, jint dim) { \
45+
af_array ret = 0; \
46+
AF_CHECK(af_##name(&ret, ARRAY(ref), dim)); \
47+
return JLONG(ret); \
48+
}
5749

58-
INSTANTIATE_STAT(Mean, mean)
59-
INSTANTIATE_STAT(Stdev, stdev)
50+
#define INSTANTIATE_STAT_ALL(Name, name) \
51+
JNIEXPORT jdouble JNICALL STATISTICS_FUNC(af##Name##All)( \
52+
JNIEnv * env, jclass clazz, jlong ref) { \
53+
double ret = 0; \
54+
AF_CHECK(af_##name##_all(&ret, NULL, ARRAY(ref))); \
55+
return (jdouble)ret; \
56+
}
6057

61-
#define INSTANTIATE_STAT_ALL(Name, name) \
62-
JNIEXPORT jdouble JNICALL STATISTICS_FUNC(af##Name##All)(JNIEnv *env, jclass clazz, \
63-
jlong ref) { \
64-
double ret = 0; \
65-
AF_CHECK(af_##name##_all(&ret, NULL, ARRAY(ref))); \
66-
return (jdouble)ret; \
67-
}
58+
#define INSTANTIATE_STAT_ALL_COMPLEX(Name, name, jtype) \
59+
JNIEXPORT jobject JNICALL STATISTICS_FUNC(af##Name##All##jtype)( \
60+
JNIEnv * env, jclass clazz, jlong ref) { \
61+
double real = 0, img = 0; \
62+
AF_CHECK(af_##name##_all(&real, &img, ARRAY(ref))); \
63+
return java::createJavaObject(env, java::JavaObjects::jtype, real, img); \
64+
}
6865

66+
// Mean
67+
INSTANTIATE_STAT(Mean, mean)
6968
INSTANTIATE_STAT_ALL(Mean, mean)
70-
71-
INSTANTIATE_MEAN(FloatComplex)
72-
INSTANTIATE_MEAN(DoubleComplex)
73-
INSTANTIATE_ALL_REAL_WEIGHTED(Mean, mean)
74-
INSTANTIATE_REAL_WEIGHTED(Mean, mean)
75-
INSTANTIATE_WEIGHTED(FloatComplex, Mean, mean)
76-
INSTANTIATE_WEIGHTED(DoubleComplex, Mean, mean)
77-
78-
#undef INSTANTIATE_MEAN
69+
INSTANTIATE_STAT_ALL_COMPLEX(Mean, mean, FloatComplex)
70+
INSTANTIATE_STAT_ALL_COMPLEX(Mean, mean, DoubleComplex)
71+
INSTANTIATE_STAT_ALL_WEIGHTED(Mean, mean)
72+
INSTANTIATE_STAT_WEIGHTED(Mean, mean)
73+
INSTANTIATE_STAT_WEIGHTED_COMPLEX(FloatComplex, Mean, mean)
74+
INSTANTIATE_STAT_WEIGHTED_COMPLEX(DoubleComplex, Mean, mean)
7975

8076
JNIEXPORT jlong JNICALL STATISTICS_FUNC(afVar)(JNIEnv *env, jclass clazz,
8177
jlong ref, jboolean isBiased,
@@ -93,17 +89,26 @@ JNIEXPORT jdouble JNICALL STATISTICS_FUNC(afVarAll)(JNIEnv *env, jclass clazz,
9389
return (jdouble)ret;
9490
}
9591

92+
// Variance
9693
INSTANTIATE_VAR(FloatComplex)
9794
INSTANTIATE_VAR(DoubleComplex)
98-
INSTANTIATE_REAL_WEIGHTED(Var, var)
99-
INSTANTIATE_ALL_REAL_WEIGHTED(Var, var)
100-
INSTANTIATE_WEIGHTED(FloatComplex, Var, var)
101-
INSTANTIATE_WEIGHTED(DoubleComplex, Var, var)
95+
INSTANTIATE_STAT_WEIGHTED(Var, var)
96+
INSTANTIATE_STAT_ALL_WEIGHTED(Var, var)
97+
INSTANTIATE_STAT_WEIGHTED_COMPLEX(FloatComplex, Var, var)
98+
INSTANTIATE_STAT_WEIGHTED_COMPLEX(DoubleComplex, Var, var)
99+
100+
// Standard dev
101+
INSTANTIATE_STAT(Stdev, stdev)
102+
INSTANTIATE_STAT_ALL(Stdev, stdev)
103+
INSTANTIATE_STAT_ALL_COMPLEX(Stdev, stdev, FloatComplex)
104+
INSTANTIATE_STAT_ALL_COMPLEX(Stdev, stdev, DoubleComplex)
102105

103106
#undef INSTANTIATE_VAR
104-
#undef INSTANTIATE_MEAN
105-
#undef INSTANTIATE_WEIGHTED
106-
#undef INSTANTIATE_REAL_WEIGHTED
107-
#undef INSTANTIATE_ALL_REAL_WEIGHTED
107+
#undef INSTANTIATE_STAT_WEIGHTED_COMPLEX
108+
#undef INSTANTIATE_STAT_WEIGHTED
109+
#undef INSTANTIATE_STAT_ALL_WEIGHTED
110+
#undef INSTANTIATE_STAT
111+
#undef INSTANTIATE_STAT_ALL
112+
#undef INSTANTIATE_STAT_ALL_COMPLEX
108113

109114
END_EXTERN_C

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.