11package lambdacloud .test ;
22
3+ import static lambdacloud .core .LambdaCloud .CPU ;
34
45import java .util .HashMap ;
56import java .util .Map ;
67
7- import lambdacloud .core .CloudConfig ;
88import lambdacloud .core .CloudSD ;
99import lambdacloud .core .Session ;
10- import lambdacloud .core .graph .GraphBuilder ;
11- import lambdacloud .core .graph .Node ;
1210import lambdacloud .core .lang .LCDevice ;
11+ import lambdacloud .core .lang .LCReturn ;
1312import symjava .bytecode .BytecodeBatchFunc ;
1413import symjava .matrix .SymMatrix ;
1514import symjava .matrix .SymVector ;
2019
2120public class TestMatrix {
2221 public static void main (String [] args ) {
23- // test1();
24- // test2();
25- // test3();
26- // test4();
27- // test5();
28- // test6();
29- test7 ();
22+ testBasic1 ();
23+ testBasic2 ();
24+ testBasic3 ();
25+ testBasic4 ();
26+ testConcat1 ();
27+ testConcat2 ();
28+ testMatrixSplit1 ();
29+ testMatrixSplit2 ();
30+ testMatrixSplit3 ();
3031 }
31- public static void test1 () {
32- // TODO Auto-generated method stub
32+
33+ public static boolean assertEqual (double [] a , double [] b ) {
34+ if (a .length != b .length ) {
35+ System .err .println ("Failed! a.length != b.length: " +a .length +" != " +b .length );
36+ return false ;
37+ }
38+ for (int i =0 ; i <a .length ; i ++) {
39+ if (Math .abs (a [i ]-b [i ])>1e-8 ) {
40+ System .err .println ("Failed! a[" +i +"] != b[" +i +"]: " +a [i ]+" != " +b [i ]);
41+ return false ;
42+ }
43+ }
44+ System .out .println ("Passed!" );
45+ return true ;
46+ }
47+
48+ public static void testBasic1 () {
3349 Matrix A = new Matrix ("A" ,3 ,3 );
50+ BytecodeBatchFunc fun = CompileUtils .compileVec (new LCReturn (A ));
51+ /**
52+ * 1 2 3
53+ * 4 5 6
54+ * 7 8 9
55+ */
56+ double [] data_A = new double [] {1 ,4 ,7 ,2 ,5 ,8 ,3 ,6 ,9 }; // Column-wise
57+ double [] outAry = new double [9 ];
58+ fun .apply (outAry , 0 , data_A );
59+ assertEqual (new double []{1 ,4 ,7 ,2 ,5 ,8 ,3 ,6 ,9 }, outAry );
60+ }
61+
62+ public static void testBasic2 () {
3463 Vector x = new Vector ("x" ,3 );
35-
36- //CompileUtils.compile("test1", A, A);
37- //CompileUtils.compile("test2", x, x);
38-
39- //BytecodeFunc fun = CompileUtils.compile("test2", A*x, A, x);
40- //double ret = fun.apply(new double[9]);
41- //System.out.println(ret);
42-
43- BytecodeBatchFunc fun = CompileUtils .compileVec (A *x , A , x );
44- double [] outAry = new double [4 ];
45- double [] data_A = new double [] {1 ,4 ,7 ,2 ,5 ,8 ,3 ,6 ,9 }; //columewise
64+ BytecodeBatchFunc fun = CompileUtils .compileVec (new LCReturn (x ));
65+ double [] data_x = new double [] {1 ,2 ,3 };
66+ double [] outAry = new double [3 ];
67+ fun .apply (outAry , 0 , data_x );
68+ assertEqual (new double []{1 ,2 ,3 }, outAry );
69+ }
70+
71+ public static void testBasic3 () {
72+ Matrix A = new Matrix ("A" ,3 ,3 );
73+ Vector x = new Vector ("x" ,3 );
74+ BytecodeBatchFunc fun = CompileUtils .compileVec (new LCReturn (A *x ));
75+ /**
76+ * 1 2 3
77+ * 4 5 6
78+ * 7 8 9
79+ */
80+ double [] data_A = new double [] {1 ,4 ,7 ,2 ,5 ,8 ,3 ,6 ,9 }; // Column-wise
4681 double [] data_x = new double [] {1 ,2 ,3 };
47- fun . apply ( outAry , 1 , data_A , data_x ) ;
48- for ( double i : outAry )
49- System . out . println ( i );
82+ double [] outAry = new double [ 3 ] ;
83+ fun . apply ( outAry , 0 , data_A , data_x );
84+ assertEqual ( new double []{ 14 , 32 , 50 }, outAry );
5085 }
5186
52- public static void test2 () {
87+ public static void testBasic4 () {
5388 Vector x = new Vector ("x" ,3 );
5489 Vector y = new Vector ("y" ,3 );
5590
56- BytecodeBatchFunc fun = CompileUtils .compileVec (x +y , x , y );
91+ BytecodeBatchFunc fun = CompileUtils .compileVec (new LCReturn ( x +y ) , x , y );
5792 double [] outAry = new double [4 ];
58- double [] data_x = new double [] {1 ,2 ,3 }; //columewise
93+ double [] data_x = new double [] {1 ,2 ,3 };
5994 double [] data_y = new double [] {1 ,2 ,3 };
60- fun .apply (outAry , 0 , data_x , data_y );
61- for (double i : outAry )
62- System .out .println (i );
95+ fun .apply (outAry , 1 , data_x , data_y ); //output at position 1
96+ assertEqual (new double []{0 ,2 ,4 ,6 }, outAry );
6397 }
6498
65- public static void test3 () {
99+ public static void testConcat1 () {
66100 Vector x = new Vector ("x" ,3 );
67101 Vector y = new Vector ("y" ,2 );
68- Vector z = new Vector ("z" ,3 );
102+ Vector z = new Vector ("z" ,4 );
69103
70- BytecodeBatchFunc fun = CompileUtils .compileVec (new Concat (x ,y ,z ), x , y , z );
104+ BytecodeBatchFunc fun = CompileUtils .compileVec (new LCReturn ( new Concat (x ,y ,z ) ), x , y , z );
71105 double [] outAry = new double [9 ];
72- double [] data_x = new double [] {1 ,2 ,3 }; //columewise
106+ double [] data_x = new double [] {1 ,2 ,3 };
73107 double [] data_y = new double [] {4 ,5 };
74- double [] data_z = new double [] {6 ,7 ,8 };
108+ double [] data_z = new double [] {6 ,7 ,8 , 9 };
75109 fun .apply (outAry , 0 , data_x , data_y , data_z );
76- for (double i : outAry )
77- System .out .println (i );
110+ assertEqual (new double []{1 ,2 ,3 ,4 ,5 ,6 ,7 ,8 ,9 }, outAry );
78111 }
79112
80- public static void test4 () {
113+ public static void testConcat2 () {
81114 Vector x = new Vector ("x" ,3 );
82115 Vector y = new Vector ("y" ,2 );
83116 Vector z = new Vector ("z" ,5 );
84117
85- BytecodeBatchFunc fun = CompileUtils .compileVec (new Concat (x ,y )+z , x , y , z );
86- double [] outAry = new double [9 ];
87- double [] data_x = new double [] {1 ,2 ,3 }; //columewise
118+ BytecodeBatchFunc fun = CompileUtils .compileVec (new LCReturn ( new Concat (x ,y )+z ) , x , y , z );
119+ double [] outAry = new double [5 ];
120+ double [] data_x = new double [] {1 ,2 ,3 };
88121 double [] data_y = new double [] {4 ,5 };
89122 double [] data_z = new double [] {1 ,2 ,3 ,4 ,5 };
90123 fun .apply (outAry , 0 , data_x , data_y , data_z );
91- for (double i : outAry )
92- System .out .println (i );
124+ assertEqual (new double []{2 ,4 ,6 ,8 ,10 }, outAry );
93125 }
94- public static void test5 () {
126+
127+ public static void testMatrixSplit1 () {
95128 int dim = 4 ;
96129 Matrix A = new Matrix ("A" , dim , dim );
97130 Vector x = new Vector ("x" , dim );
@@ -101,15 +134,15 @@ public static void test5() {
101134 SymVector xx = x .split (2 );
102135 SymVector yy = (SymVector )(AA *xx );
103136 System .out .println (yy );
104- yy [0 ].runOn (new LCDevice ("/cpu:0" ));
105- yy [1 ].runOn (new LCDevice ("/cpu:0" ));
137+ yy [0 ].runOn (new LCDevice (0 ));
138+ yy [1 ].runOn (new LCDevice (1 ));
106139
107140 Expr res = new Concat (yy [0 ],yy [1 ])+y0 ;
108141
109142 System .out .println (res );
110- //BytecodeBatchFunc fun = CompileUtils.compileVec(res, A,x,y0,AA[0][0],AA[1][0],xx[0],xx[1]);
111- BytecodeBatchFunc fun = CompileUtils .compileVec (res );
112- //void apply(double[] output, int outPos, double[] A_1_1, double[] A_1_0, double[] A_0_1, double[] A_0_0, double[] x_0, double[] x_1, double[] y0 );
143+ //Doesn't work
144+ // BytecodeBatchFunc fun = CompileUtils.compileVec(new LCReturn( res), A,x,y0,AA[0][0],AA[1][0],xx[0],xx[1] );
145+ BytecodeBatchFunc fun = CompileUtils . compileVec ( new LCReturn ( res ) );
113146/*
1141471 2 3 4 0 1 13
1151481 2 1 3 * 1 + 2 = 9
@@ -124,16 +157,12 @@ public static void test5() {
124157 double [] data_x_0 = new double [] {0 ,1 };
125158 double [] data_x_1 = new double [] {2 ,1 };
126159 double [] data_y0 = new double [] {1 ,2 ,3 ,4 };
160+ //void apply(double[] output, int outPos, double[] A_1_1, double[] A_1_0, double[] A_0_1, double[] A_0_0, double[] x_0, double[] x_1, double[] y0);
127161 fun .apply (outAry , 0 , data_A_11 , data_A_10 , data_A_01 , data_A_00 , data_x_0 , data_x_1 , data_y0 );
128- for (double i : outAry )
129- System .out .println (i );
130- //13.0
131- //9.0
132- //10.0
133- //13.0
162+ assertEqual (new double []{13 ,9 ,10 ,13 }, outAry );
134163 }
135164
136- public static void test6 () {
165+ public static void testMatrixSplit2 () {
137166 int dim = 4 ;
138167 Matrix A = new Matrix ("A" , dim , dim );
139168 Vector x = new Vector ("x" , dim );
@@ -143,8 +172,8 @@ public static void test6() {
143172 SymVector xx = x .split (2 );
144173 SymVector yy = (SymVector )(AA *xx );
145174 System .out .println (yy );
146- yy [0 ].runOn (new LCDevice ("/cpu:0" ));
147- yy [1 ].runOn (new LCDevice ("/cpu:0" ));
175+ yy [0 ].runOn (new LCDevice (0 ));
176+ yy [1 ].runOn (new LCDevice (1 ));
148177
149178 Expr res = new Concat (yy [0 ],yy [1 ])+y0 ;
150179
@@ -159,12 +188,12 @@ public static void test6() {
159188 dict .put (x .toString (), new double []{0 ,1 ,2 ,1 });
160189 dict .put (y0 .toString (), new double []{1 ,2 ,3 ,4 });
161190
162- //test
163- //those parameters should be able automatically generated accroding to the definition of AA and xx
164- double [] data_A_11 = new double [] {2 ,1 ,1 ,4 }; //columewise
165- double [] data_A_10 = new double [] {1 ,2 ,2 ,3 }; //columewise
166- double [] data_A_01 = new double [] {3 ,1 ,4 ,3 }; //columewise
167- double [] data_A_00 = new double [] {1 ,1 ,2 ,2 }; //columewise
191+ //those parameters should be able automatically generated according to the definition of AA and xx
192+ //see testMatrixSplit3()
193+ double [] data_A_11 = new double [] {2 ,1 ,1 ,4 };
194+ double [] data_A_10 = new double [] {1 ,2 ,2 ,3 };
195+ double [] data_A_01 = new double [] {3 ,1 ,4 ,3 };
196+ double [] data_A_00 = new double [] {1 ,1 ,2 ,2 };
168197 double [] data_x_0 = new double [] {0 ,1 };
169198 double [] data_x_1 = new double [] {2 ,1 };
170199 dict .put (AA [0 ][0 ].toString (), data_A_00 );
@@ -174,19 +203,16 @@ public static void test6() {
174203 dict .put (xx [0 ].toString (), data_x_0 );
175204 dict .put (xx [1 ].toString (), data_x_1 );
176205
177- CloudConfig .setGlobalTarget ("job_local.conf" );
178- Node n = GraphBuilder .build (res );
179206 Session sess1 = new Session ();
180- CloudSD rlt = sess1 .runVec (n , dict );
181- System .out .println ("------------" );
182- for (double d : rlt .getData ())
183- System .out .println (d );
207+ CloudSD rlt = sess1 .runVec (res , dict );
208+ rlt .fetchToLocal ();
209+ assertEqual (new double []{13 ,9 ,10 ,13 }, rlt .getData ());
184210 }
185211
186212 /**
187213 * Automatic data dict split for matrices and vectors
188214 */
189- public static void test7 () {
215+ public static void testMatrixSplit3 () {
190216 int dim = 4 ;
191217 Matrix A = new Matrix ("A" , dim , dim );
192218 Vector x = new Vector ("x" , dim );
@@ -196,12 +222,9 @@ public static void test7() {
196222 SymVector xx = x .split (2 );
197223 //yy = AA * xx
198224 SymVector yy = (SymVector )(AA *xx );
199- System .out .println ("Test: yy=" +yy );
200- yy [0 ].runOn (new LCDevice ("2" ));
201- yy [1 ].runOn (new LCDevice ("1" ));
202225 // res = yy + y0
203- Expr res = new Concat (yy [0 ], yy [1 ])+y0 ;
204- res . runOn ( new LCDevice ( "0" ) );
226+ Expr res = CPU ( new Concat ( CPU ( yy [0 ]), CPU ( yy [1 ]) ) +y0 ) ;
227+ System . out . println ( "Test: res=" + res );
205228
206229 Map <String , double []> dict = new HashMap <String , double []>();
207230 /*
@@ -214,11 +237,10 @@ public static void test7() {
214237 dict .put (x .toString (), new double []{0 ,1 ,2 ,1 });
215238 dict .put (y0 .toString (), new double []{1 ,2 ,3 ,4 });
216239
217- Session sess1 = new Session ();
218- CloudSD rlt = sess1 .runVec (res , dict );
219- System .out .println ("Test done, fetch data:" );
220- for (double d : rlt .getData ())
221- System .out .println (d );
240+ Session sess = new Session ();
241+ CloudSD rlt = sess .runVec (res , dict );
242+ rlt .fetchToLocal ();
243+ assertEqual (new double []{13 ,9 ,10 ,13 }, rlt .getData ());
222244 }
223245
224246}
0 commit comments