@@ -26,9 +26,9 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
26
26
}
27
27
}
28
28
29
- let ( task, algorithm, data ) = Spi :: get_three_with_args :: < String , String , Vec < u8 > > (
29
+ let ( task, algorithm, model_id ) = Spi :: get_three_with_args :: < String , String , i64 > (
30
30
"
31
- SELECT projects.task::TEXT, models.algorithm::TEXT, files.data
31
+ SELECT projects.task::TEXT, models.algorithm::TEXT, models.id AS model_id
32
32
FROM pgml_rust.files
33
33
JOIN pgml_rust.models
34
34
ON models.id = files.model_id
@@ -55,6 +55,17 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
55
55
)
56
56
} ) )
57
57
. unwrap ( ) ;
58
+
59
+ let ( data, hyperparams) = Spi :: get_two_with_args :: < Vec < u8 > , JsonB > (
60
+ "SELECT data, hyperparams FROM pgml_rust.models
61
+ INNER JOIN pgml_rust.files
62
+ ON models.id = files.model_id WHERE models.id = $1
63
+ LIMIT 1" ,
64
+ vec ! [ ( PgBuiltInOids :: INT8OID . oid( ) , model_id. into_datum( ) ) ] ,
65
+ ) ;
66
+
67
+ let hyperparams = hyperparams. unwrap ( ) ;
68
+
58
69
let data = data. unwrap_or_else ( || {
59
70
panic ! (
60
71
"Project {} does not have a trained and deployed model." ,
@@ -75,6 +86,54 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
75
86
let bst = Booster :: load_buffer ( & * data) . unwrap ( ) ;
76
87
Box :: new ( BoosterBox :: new ( bst) )
77
88
}
89
+ Algorithm :: svm => match & hyperparams. 0 . as_object ( ) . unwrap ( ) . get ( "kernel" ) {
90
+ Some ( kernel) => match kernel. as_str ( ) . unwrap_or ( "linear" ) {
91
+ "poly" => {
92
+ let estimator: smartcore:: svm:: svr:: SVR <
93
+ f32 ,
94
+ Array2 < f32 > ,
95
+ smartcore:: svm:: PolynomialKernel < f32 > ,
96
+ > = rmp_serde:: from_read ( & * data) . unwrap ( ) ;
97
+ Box :: new ( estimator)
98
+ }
99
+
100
+ "sigmoid" => {
101
+ let estimator: smartcore:: svm:: svr:: SVR <
102
+ f32 ,
103
+ Array2 < f32 > ,
104
+ smartcore:: svm:: SigmoidKernel < f32 > ,
105
+ > = rmp_serde:: from_read ( & * data) . unwrap ( ) ;
106
+ Box :: new ( estimator)
107
+ }
108
+
109
+ "rbf" => {
110
+ let estimator: smartcore:: svm:: svr:: SVR <
111
+ f32 ,
112
+ Array2 < f32 > ,
113
+ smartcore:: svm:: RBFKernel < f32 > ,
114
+ > = rmp_serde:: from_read ( & * data) . unwrap ( ) ;
115
+ Box :: new ( estimator)
116
+ }
117
+
118
+ _ => {
119
+ let estimator: smartcore:: svm:: svr:: SVR <
120
+ f32 ,
121
+ Array2 < f32 > ,
122
+ smartcore:: svm:: LinearKernel ,
123
+ > = rmp_serde:: from_read ( & * data) . unwrap ( ) ;
124
+ Box :: new ( estimator)
125
+ }
126
+ } ,
127
+
128
+ None => {
129
+ let estimator: smartcore:: svm:: svr:: SVR <
130
+ f32 ,
131
+ Array2 < f32 > ,
132
+ smartcore:: svm:: LinearKernel ,
133
+ > = rmp_serde:: from_read ( & * data) . unwrap ( ) ;
134
+ Box :: new ( estimator)
135
+ }
136
+ } ,
78
137
} ,
79
138
Task :: classification => match algorithm {
80
139
Algorithm :: linear => {
@@ -88,6 +147,54 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
88
147
let bst = Booster :: load_buffer ( & * data) . unwrap ( ) ;
89
148
Box :: new ( BoosterBox :: new ( bst) )
90
149
}
150
+ Algorithm :: svm => match & hyperparams. 0 . as_object ( ) . unwrap ( ) . get ( "kernel" ) {
151
+ Some ( kernel) => match kernel. as_str ( ) . unwrap_or ( "linear" ) {
152
+ "poly" => {
153
+ let estimator: smartcore:: svm:: svc:: SVC <
154
+ f32 ,
155
+ Array2 < f32 > ,
156
+ smartcore:: svm:: PolynomialKernel < f32 > ,
157
+ > = rmp_serde:: from_read ( & * data) . unwrap ( ) ;
158
+ Box :: new ( estimator)
159
+ }
160
+
161
+ "sigmoid" => {
162
+ let estimator: smartcore:: svm:: svc:: SVC <
163
+ f32 ,
164
+ Array2 < f32 > ,
165
+ smartcore:: svm:: SigmoidKernel < f32 > ,
166
+ > = rmp_serde:: from_read ( & * data) . unwrap ( ) ;
167
+ Box :: new ( estimator)
168
+ }
169
+
170
+ "rbf" => {
171
+ let estimator: smartcore:: svm:: svc:: SVC <
172
+ f32 ,
173
+ Array2 < f32 > ,
174
+ smartcore:: svm:: RBFKernel < f32 > ,
175
+ > = rmp_serde:: from_read ( & * data) . unwrap ( ) ;
176
+ Box :: new ( estimator)
177
+ }
178
+
179
+ _ => {
180
+ let estimator: smartcore:: svm:: svc:: SVC <
181
+ f32 ,
182
+ Array2 < f32 > ,
183
+ smartcore:: svm:: LinearKernel ,
184
+ > = rmp_serde:: from_read ( & * data) . unwrap ( ) ;
185
+ Box :: new ( estimator)
186
+ }
187
+ } ,
188
+
189
+ None => {
190
+ let estimator: smartcore:: svm:: svc:: SVC <
191
+ f32 ,
192
+ Array2 < f32 > ,
193
+ smartcore:: svm:: LinearKernel ,
194
+ > = rmp_serde:: from_read ( & * data) . unwrap ( ) ;
195
+ Box :: new ( estimator)
196
+ }
197
+ } ,
91
198
} ,
92
199
} ;
93
200
@@ -194,6 +301,100 @@ impl Estimator for smartcore::linear::logistic_regression::LogisticRegression<f3
194
301
}
195
302
}
196
303
304
+ // All the SVM kernels :popcorn:
305
+
306
+ #[ typetag:: serialize]
307
+ impl Estimator for smartcore:: svm:: svc:: SVC < f32 , Array2 < f32 > , smartcore:: svm:: LinearKernel > {
308
+ fn test ( & self , task : Task , data : & Dataset ) -> HashMap < String , f32 > {
309
+ test_smartcore ( self , task, data)
310
+ }
311
+
312
+ fn predict ( & self , features : Vec < f32 > ) -> f32 {
313
+ predict_smartcore ( self , features)
314
+ }
315
+ }
316
+
317
+ #[ typetag:: serialize]
318
+ impl Estimator for smartcore:: svm:: svr:: SVR < f32 , Array2 < f32 > , smartcore:: svm:: LinearKernel > {
319
+ fn test ( & self , task : Task , data : & Dataset ) -> HashMap < String , f32 > {
320
+ test_smartcore ( self , task, data)
321
+ }
322
+
323
+ fn predict ( & self , features : Vec < f32 > ) -> f32 {
324
+ predict_smartcore ( self , features)
325
+ }
326
+ }
327
+
328
+ #[ typetag:: serialize]
329
+ impl Estimator for smartcore:: svm:: svc:: SVC < f32 , Array2 < f32 > , smartcore:: svm:: SigmoidKernel < f32 > > {
330
+ fn test ( & self , task : Task , data : & Dataset ) -> HashMap < String , f32 > {
331
+ test_smartcore ( self , task, data)
332
+ }
333
+
334
+ fn predict ( & self , features : Vec < f32 > ) -> f32 {
335
+ predict_smartcore ( self , features)
336
+ }
337
+ }
338
+
339
+ #[ typetag:: serialize]
340
+ impl Estimator for smartcore:: svm:: svr:: SVR < f32 , Array2 < f32 > , smartcore:: svm:: SigmoidKernel < f32 > > {
341
+ fn test ( & self , task : Task , data : & Dataset ) -> HashMap < String , f32 > {
342
+ test_smartcore ( self , task, data)
343
+ }
344
+
345
+ fn predict ( & self , features : Vec < f32 > ) -> f32 {
346
+ predict_smartcore ( self , features)
347
+ }
348
+ }
349
+
350
+ #[ typetag:: serialize]
351
+ impl Estimator
352
+ for smartcore:: svm:: svc:: SVC < f32 , Array2 < f32 > , smartcore:: svm:: PolynomialKernel < f32 > >
353
+ {
354
+ fn test ( & self , task : Task , data : & Dataset ) -> HashMap < String , f32 > {
355
+ test_smartcore ( self , task, data)
356
+ }
357
+
358
+ fn predict ( & self , features : Vec < f32 > ) -> f32 {
359
+ predict_smartcore ( self , features)
360
+ }
361
+ }
362
+
363
+ #[ typetag:: serialize]
364
+ impl Estimator
365
+ for smartcore:: svm:: svr:: SVR < f32 , Array2 < f32 > , smartcore:: svm:: PolynomialKernel < f32 > >
366
+ {
367
+ fn test ( & self , task : Task , data : & Dataset ) -> HashMap < String , f32 > {
368
+ test_smartcore ( self , task, data)
369
+ }
370
+
371
+ fn predict ( & self , features : Vec < f32 > ) -> f32 {
372
+ predict_smartcore ( self , features)
373
+ }
374
+ }
375
+
376
+ #[ typetag:: serialize]
377
+ impl Estimator for smartcore:: svm:: svc:: SVC < f32 , Array2 < f32 > , smartcore:: svm:: RBFKernel < f32 > > {
378
+ fn test ( & self , task : Task , data : & Dataset ) -> HashMap < String , f32 > {
379
+ test_smartcore ( self , task, data)
380
+ }
381
+
382
+ fn predict ( & self , features : Vec < f32 > ) -> f32 {
383
+ predict_smartcore ( self , features)
384
+ }
385
+ }
386
+
387
+ #[ typetag:: serialize]
388
+ impl Estimator for smartcore:: svm:: svr:: SVR < f32 , Array2 < f32 > , smartcore:: svm:: RBFKernel < f32 > > {
389
+ fn test ( & self , task : Task , data : & Dataset ) -> HashMap < String , f32 > {
390
+ test_smartcore ( self , task, data)
391
+ }
392
+
393
+ fn predict ( & self , features : Vec < f32 > ) -> f32 {
394
+ predict_smartcore ( self , features)
395
+ }
396
+ }
397
+
197
398
pub struct BoosterBox {
198
399
contents : Box < xgboost:: Booster > ,
199
400
}
0 commit comments