Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit c5f0ea1

Browse filesBrowse files
authored
Support Vector Machines (#319)
1 parent 877a40b commit c5f0ea1
Copy full SHA for c5f0ea1

File tree

4 files changed

+455
-19
lines changed
Filter options

4 files changed

+455
-19
lines changed

‎pgml-docs/docs/index.md

Copy file name to clipboardExpand all lines: pgml-docs/docs/index.md
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ article.md-content__inner.md-typeset a.md-content__button.md-icon {
3939
}
4040
</style>
4141

42-
<h1 align="center">End-to-end<br/>machine learning solution <br/>for everyone</h1>
42+
<h1 align="center">End-to-end<br/>machine learning platform <br/>for everyone</h1>
4343

4444
<p align="center" class="subtitle">
4545
Train and deploy models to make online predictions using only SQL, with an open source extension for Postgres. Manage your projects and visualize datasets using the built-in dashboard.

‎pgml-extension/pgml_rust/src/orm/algorithm.rs

Copy file name to clipboardExpand all lines: pgml-extension/pgml_rust/src/orm/algorithm.rs
+3Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use serde::Deserialize;
66
pub enum Algorithm {
77
linear,
88
xgboost,
9+
svm,
910
}
1011

1112
impl std::str::FromStr for Algorithm {
@@ -15,6 +16,7 @@ impl std::str::FromStr for Algorithm {
1516
match input {
1617
"linear" => Ok(Algorithm::linear),
1718
"xgboost" => Ok(Algorithm::xgboost),
19+
"svm" => Ok(Algorithm::svm),
1820
_ => Err(()),
1921
}
2022
}
@@ -25,6 +27,7 @@ impl std::string::ToString for Algorithm {
2527
match *self {
2628
Algorithm::linear => "linear".to_string(),
2729
Algorithm::xgboost => "xgboost".to_string(),
30+
Algorithm::svm => "svm".to_string(),
2831
}
2932
}
3033
}

‎pgml-extension/pgml_rust/src/orm/estimator.rs

Copy file name to clipboardExpand all lines: pgml-extension/pgml_rust/src/orm/estimator.rs
+203-2Lines changed: 203 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
2626
}
2727
}
2828

29-
let (task, algorithm, data) = Spi::get_three_with_args::<String, String, Vec<u8>>(
29+
let (task, algorithm, model_id) = Spi::get_three_with_args::<String, String, i64>(
3030
"
31-
SELECT projects.task::TEXT, models.algorithm::TEXT, files.data
31+
SELECT projects.task::TEXT, models.algorithm::TEXT, models.id AS model_id
3232
FROM pgml_rust.files
3333
JOIN pgml_rust.models
3434
ON models.id = files.model_id
@@ -55,6 +55,17 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
5555
)
5656
}))
5757
.unwrap();
58+
59+
let (data, hyperparams) = Spi::get_two_with_args::<Vec<u8>, JsonB>(
60+
"SELECT data, hyperparams FROM pgml_rust.models
61+
INNER JOIN pgml_rust.files
62+
ON models.id = files.model_id WHERE models.id = $1
63+
LIMIT 1",
64+
vec![(PgBuiltInOids::INT8OID.oid(), model_id.into_datum())],
65+
);
66+
67+
let hyperparams = hyperparams.unwrap();
68+
5869
let data = data.unwrap_or_else(|| {
5970
panic!(
6071
"Project {} does not have a trained and deployed model.",
@@ -75,6 +86,54 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
7586
let bst = Booster::load_buffer(&*data).unwrap();
7687
Box::new(BoosterBox::new(bst))
7788
}
89+
Algorithm::svm => match &hyperparams.0.as_object().unwrap().get("kernel") {
90+
Some(kernel) => match kernel.as_str().unwrap_or("linear") {
91+
"poly" => {
92+
let estimator: smartcore::svm::svr::SVR<
93+
f32,
94+
Array2<f32>,
95+
smartcore::svm::PolynomialKernel<f32>,
96+
> = rmp_serde::from_read(&*data).unwrap();
97+
Box::new(estimator)
98+
}
99+
100+
"sigmoid" => {
101+
let estimator: smartcore::svm::svr::SVR<
102+
f32,
103+
Array2<f32>,
104+
smartcore::svm::SigmoidKernel<f32>,
105+
> = rmp_serde::from_read(&*data).unwrap();
106+
Box::new(estimator)
107+
}
108+
109+
"rbf" => {
110+
let estimator: smartcore::svm::svr::SVR<
111+
f32,
112+
Array2<f32>,
113+
smartcore::svm::RBFKernel<f32>,
114+
> = rmp_serde::from_read(&*data).unwrap();
115+
Box::new(estimator)
116+
}
117+
118+
_ => {
119+
let estimator: smartcore::svm::svr::SVR<
120+
f32,
121+
Array2<f32>,
122+
smartcore::svm::LinearKernel,
123+
> = rmp_serde::from_read(&*data).unwrap();
124+
Box::new(estimator)
125+
}
126+
},
127+
128+
None => {
129+
let estimator: smartcore::svm::svr::SVR<
130+
f32,
131+
Array2<f32>,
132+
smartcore::svm::LinearKernel,
133+
> = rmp_serde::from_read(&*data).unwrap();
134+
Box::new(estimator)
135+
}
136+
},
78137
},
79138
Task::classification => match algorithm {
80139
Algorithm::linear => {
@@ -88,6 +147,54 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
88147
let bst = Booster::load_buffer(&*data).unwrap();
89148
Box::new(BoosterBox::new(bst))
90149
}
150+
Algorithm::svm => match &hyperparams.0.as_object().unwrap().get("kernel") {
151+
Some(kernel) => match kernel.as_str().unwrap_or("linear") {
152+
"poly" => {
153+
let estimator: smartcore::svm::svc::SVC<
154+
f32,
155+
Array2<f32>,
156+
smartcore::svm::PolynomialKernel<f32>,
157+
> = rmp_serde::from_read(&*data).unwrap();
158+
Box::new(estimator)
159+
}
160+
161+
"sigmoid" => {
162+
let estimator: smartcore::svm::svc::SVC<
163+
f32,
164+
Array2<f32>,
165+
smartcore::svm::SigmoidKernel<f32>,
166+
> = rmp_serde::from_read(&*data).unwrap();
167+
Box::new(estimator)
168+
}
169+
170+
"rbf" => {
171+
let estimator: smartcore::svm::svc::SVC<
172+
f32,
173+
Array2<f32>,
174+
smartcore::svm::RBFKernel<f32>,
175+
> = rmp_serde::from_read(&*data).unwrap();
176+
Box::new(estimator)
177+
}
178+
179+
_ => {
180+
let estimator: smartcore::svm::svc::SVC<
181+
f32,
182+
Array2<f32>,
183+
smartcore::svm::LinearKernel,
184+
> = rmp_serde::from_read(&*data).unwrap();
185+
Box::new(estimator)
186+
}
187+
},
188+
189+
None => {
190+
let estimator: smartcore::svm::svc::SVC<
191+
f32,
192+
Array2<f32>,
193+
smartcore::svm::LinearKernel,
194+
> = rmp_serde::from_read(&*data).unwrap();
195+
Box::new(estimator)
196+
}
197+
},
91198
},
92199
};
93200

@@ -194,6 +301,100 @@ impl Estimator for smartcore::linear::logistic_regression::LogisticRegression<f3
194301
}
195302
}
196303

304+
// All the SVM kernels :popcorn:
305+
306+
#[typetag::serialize]
307+
impl Estimator for smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::svm::LinearKernel> {
308+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
309+
test_smartcore(self, task, data)
310+
}
311+
312+
fn predict(&self, features: Vec<f32>) -> f32 {
313+
predict_smartcore(self, features)
314+
}
315+
}
316+
317+
#[typetag::serialize]
318+
impl Estimator for smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::LinearKernel> {
319+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
320+
test_smartcore(self, task, data)
321+
}
322+
323+
fn predict(&self, features: Vec<f32>) -> f32 {
324+
predict_smartcore(self, features)
325+
}
326+
}
327+
328+
#[typetag::serialize]
329+
impl Estimator for smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::svm::SigmoidKernel<f32>> {
330+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
331+
test_smartcore(self, task, data)
332+
}
333+
334+
fn predict(&self, features: Vec<f32>) -> f32 {
335+
predict_smartcore(self, features)
336+
}
337+
}
338+
339+
#[typetag::serialize]
340+
impl Estimator for smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::SigmoidKernel<f32>> {
341+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
342+
test_smartcore(self, task, data)
343+
}
344+
345+
fn predict(&self, features: Vec<f32>) -> f32 {
346+
predict_smartcore(self, features)
347+
}
348+
}
349+
350+
#[typetag::serialize]
351+
impl Estimator
352+
for smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::svm::PolynomialKernel<f32>>
353+
{
354+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
355+
test_smartcore(self, task, data)
356+
}
357+
358+
fn predict(&self, features: Vec<f32>) -> f32 {
359+
predict_smartcore(self, features)
360+
}
361+
}
362+
363+
#[typetag::serialize]
364+
impl Estimator
365+
for smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::PolynomialKernel<f32>>
366+
{
367+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
368+
test_smartcore(self, task, data)
369+
}
370+
371+
fn predict(&self, features: Vec<f32>) -> f32 {
372+
predict_smartcore(self, features)
373+
}
374+
}
375+
376+
#[typetag::serialize]
377+
impl Estimator for smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::svm::RBFKernel<f32>> {
378+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
379+
test_smartcore(self, task, data)
380+
}
381+
382+
fn predict(&self, features: Vec<f32>) -> f32 {
383+
predict_smartcore(self, features)
384+
}
385+
}
386+
387+
#[typetag::serialize]
388+
impl Estimator for smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::RBFKernel<f32>> {
389+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
390+
test_smartcore(self, task, data)
391+
}
392+
393+
fn predict(&self, features: Vec<f32>) -> f32 {
394+
predict_smartcore(self, features)
395+
}
396+
}
397+
197398
pub struct BoosterBox {
198399
contents: Box<xgboost::Booster>,
199400
}

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.