forked from yuemingl/SymJava
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathBenchmarkSqrt.java
More file actions
107 lines (94 loc) · 2.59 KB
/
BenchmarkSqrt.java
File metadata and controls
107 lines (94 loc) · 2.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
package symjava.examples;
import static symjava.math.SymMath.pow;
import static symjava.symbolic.Symbol.x;
import java.util.ArrayList;
import java.util.List;
import symjava.bytecode.BytecodeFunc;
import symjava.bytecode.BytecodeVecFunc;
import symjava.symbolic.Expr;
import symjava.symbolic.Func;
import symjava.symbolic.utils.JIT;
/**
 * Benchmarks evaluation of SymJava-compiled expressions of the form
 * {@code x^(1/1) + x^(1/2) + ... + x^(1/k)} for k = 1..{@value #MAX_TERMS}.
 *
 * <p>{@link #test()} times one call per argument through {@link BytecodeFunc}.
 * {@link #testBatchEval()} times batched calls through {@link BytecodeVecFunc}.
 */
public class BenchmarkSqrt {
    /** Total number of function evaluations timed for each compiled expression. */
    private static final int TOTAL_EVALS = 10000000;

    /** Number of terms in the largest benchmarked expression. */
    private static final int MAX_TERMS = 9;

    /** Number of arguments evaluated per vectorized call in {@link #testBatchEval()}. */
    private static final int BATCH_LEN = 10000;

    /**
     * Computes {@code n!} as a double.
     *
     * @param n the integer whose factorial is computed
     * @return the product 1 * 2 * ... * n, or 1 when {@code n <= 0}
     */
    public static double factorial(int n) {
        double rlt = 1;
        for (int i = 1; i <= n; i++) {
            rlt *= i;
        }
        return rlt;
    }

    /**
     * Runs the scalar benchmark and then the batched benchmark.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        test();
        testBatchEval();
    }

    /**
     * Builds the partial sums of {@code x^(1/j)}.
     * Element {@code i} of the result is {@code x^(1/1) + ... + x^(1/(i+1))}.
     *
     * @param n number of partial sums (and terms in the last one)
     * @return a list of {@code n} expressions of increasing length
     */
    private static List<Expr> buildPartialSums(int n) {
        List<Expr> exprs = new ArrayList<Expr>(n);
        // NOTE: `Expr = 0` and `Expr + Expr` are not plain Java. They rely on the
        // operator-overloading support this project is built with, exactly as the
        // original code did.
        Expr expr = 0;
        for (int i = 1; i <= n; i++) {
            Expr term = pow(x, 1.0 / i);
            expr = expr + term;
            exprs.add(expr);
        }
        return exprs;
    }

    /**
     * Compiles each partial-sum expression with {@link Func#toBytecodeFunc()}.
     * It then times {@value #TOTAL_EVALS} single-argument calls per function and
     * prints the elapsed seconds for each.
     */
    public static void test() {
        List<Expr> exprs = buildPartialSums(MAX_TERMS);
        List<BytecodeFunc> funcs = new ArrayList<BytecodeFunc>(exprs.size());
        for (int i = 0; i < exprs.size(); i++) {
            Func func = new Func("func" + i, exprs.get(i));
            BytecodeFunc bfunc = func.toBytecodeFunc();
            // Sanity-check print of each compiled function at x = 0.1 before timing.
            System.out.println(bfunc.apply(0.1));
            funcs.add(bfunc);
        }

        double xx = 0.1;
        double out = 0.0;
        for (int i = 0; i < funcs.size(); i++) {
            // Fetch once, outside the timed loop, so list access is not measured.
            BytecodeFunc f = funcs.get(i);
            // nanoTime is the monotonic clock intended for elapsed-time measurement.
            long begin = System.nanoTime();
            for (int j = 0; j < TOTAL_EVALS; j++) {
                // Perturb the argument slightly so every call sees a different input.
                xx += 1e-15;
                out += f.apply(xx);
            }
            long end = System.nanoTime();
            System.out.println("Time: " + ((end - begin) / 1e9) + " expr=" + exprs.get(i));
        }
        // Printing the accumulated result keeps it observably used; presumably this
        // stops the JIT from eliminating the timed loops as dead code.
        System.out.println("Test Value=" + out);
    }

    /**
     * Compiles each partial-sum expression with {@link JIT#compileVecFunc}.
     * It then times {@value #TOTAL_EVALS} evaluations per function, issued in
     * batches of {@value #BATCH_LEN} arguments, and prints the elapsed seconds.
     */
    public static void testBatchEval() {
        List<Expr> exprs = buildPartialSums(MAX_TERMS);
        List<BytecodeVecFunc> funcs = new ArrayList<BytecodeVecFunc>(exprs.size());
        for (int i = 0; i < exprs.size(); i++) {
            Func func = new Func("func" + i, exprs.get(i));
            funcs.add(JIT.compileVecFunc(func.args(), func));
        }

        double[] outAry = new double[BATCH_LEN];
        double[] args = new double[BATCH_LEN];
        int numBatches = TOTAL_EVALS / BATCH_LEN;
        double out = 0.0;
        double xx = 0.1;
        for (int i = 0; i < funcs.size(); i++) {
            // Fetch once, outside the timed loop, so list access is not measured.
            BytecodeVecFunc f = funcs.get(i);
            long begin = System.nanoTime();
            for (int j = 0; j < numBatches; j++) {
                // Fill the batch with slightly perturbed arguments.
                for (int k = 0; k < BATCH_LEN; k++) {
                    xx += 1e-15;
                    args[k] = xx;
                }
                f.apply(outAry, 0, args);
                for (int k = 0; k < BATCH_LEN; k++) {
                    out += outAry[k];
                }
            }
            long end = System.nanoTime();
            System.out.println("Time: " + ((end - begin) / 1e9) + " expr=" + exprs.get(i));
        }
        // Printing the accumulated result keeps it observably used; presumably this
        // stops the JIT from eliminating the timed loops as dead code.
        System.out.println("Test Value=" + out);
    }
}