forked from yuemingl/SymJava
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path NewtonOptimization.java
More file actions
62 lines (55 loc) · 1.72 KB
/
NewtonOptimization.java
File metadata and controls
62 lines (55 loc) · 1.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
package symjava.examples;
import symjava.matrix.SymMatrix;
import symjava.matrix.SymVector;
import symjava.numeric.NumMatrix;
import symjava.numeric.NumVector;
import symjava.relational.Eq;
import symjava.symbolic.Expr;
import symjava.symbolic.Symbol;
import Jama.Matrix;
/**
 * Example: finds a stationary point of a symbolic expression with Newton's method.
 * The gradient and Hessian are derived symbolically with SymJava, compiled to
 * bytecode for fast evaluation, and each Newton step is solved with JAMA.
 */
public class NewtonOptimization {

    /**
     * Builds the gradient and Hessian of {@code eq.lhs} symbolically, prints them, and
     * (unless {@code displayOnly}) runs the Newton iteration
     * {@code x <- x - H(x)^-1 * grad(x)} starting from {@code init}.
     *
     * <p>Each iterate is printed as {@code name=value} pairs. The iteration stops after
     * {@code maxIter} steps or as soon as the Euclidean norm of a step is below
     * {@code eps}; that last (small) step is still applied to the result.
     *
     * @param eq          equation {@code f = 0}; only the left-hand side {@code f} is
     *                    differentiated. If the right-hand side is not the constant 0, a
     *                    message is printed and the method returns without doing anything else.
     * @param init        initial guess with one entry per unknown, in the order returned by
     *                    {@code eq.getUnknowns()}; overwritten in place with the iterates, so
     *                    on return it holds the final point
     * @param maxIter     maximum number of Newton steps
     * @param eps         tolerance on the Euclidean norm of a Newton step
     * @param displayOnly if {@code true}, only print the symbolic Hessian and gradient
     * @throws IllegalArgumentException if {@code displayOnly} is {@code false} and
     *                                  {@code init.length} differs from the number of unknowns
     */
    public static void solve(Eq eq, double[] init, int maxIter, double eps, boolean displayOnly) {
        if(!Symbol.C0.symEquals(eq.rhs)) {
            System.out.println("The right hand side of the equation must be 0.");
            return;
        }
        Expr[] unknowns = eq.getUnknowns();
        int n = unknowns.length;

        // Construct the gradient and Hessian matrix of the left-hand side symbolically.
        // (Indexing syntax on SymVector/SymMatrix relies on SymJava's operator-overloading setup.)
        SymVector grad = new SymVector(n);
        SymMatrix hess = new SymMatrix(n, n);
        Expr lhs = eq.lhs;
        for(int i=0; i<n; i++) {
            grad[i] = lhs.diff(unknowns[i]);
            for(int j=0; j<n; j++) {
                hess[i][j] = grad[i].diff(unknowns[j]);
            }
        }
        System.out.println("Hessian Matrix = ");
        System.out.println(hess);
        System.out.println("Gradient = ");
        System.out.println(grad);
        if(displayOnly) return;

        // Checked only here so display-only calls do not need a meaningful initial guess.
        if(init.length != n) {
            throw new IllegalArgumentException("Initial guess has " + init.length
                    + " entries but the equation has " + n + " unknowns");
        }

        // Convert the symbolic expressions to bytecode to speed up repeated evaluation
        NumMatrix NH = new NumMatrix(hess, unknowns);
        NumVector NG = new NumVector(grad, unknowns);

        System.out.println("Iteratively solve ... ");
        for(int iter=0; iter<maxIter; iter++) {
            printPoint(unknowns, init);
            // Solve H(x) * dx = grad(x) with JAMA; for a square matrix JAMA uses an
            // LU decomposition, which throws if the Hessian is singular at this point.
            Matrix A = new Matrix(NH.eval(init));
            Matrix b = new Matrix(NG.eval(init), NG.dim()); // n x 1 column vector
            Matrix dx = A.solve(b);
            // Apply the step before testing convergence so the final correction is kept.
            for(int j=0; j<n; j++) {
                init[j] -= dx.get(j, 0);
            }
            if(dx.norm2() < eps)
                break;
        }
        printPoint(unknowns, init);
    }

    /** Prints one point as space-separated {@code name=value} pairs (5 decimals) on a line. */
    private static void printPoint(Expr[] unknowns, double[] point) {
        StringBuilder line = new StringBuilder();
        for(int j=0; j<unknowns.length; j++) {
            line.append(String.format("%s=%.5f", unknowns[j], point[j])).append(' ');
        }
        System.out.println(line);
    }
}