Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 7bac309

Browse filesBrowse files
authored
PyO3 & Scikit introduction (#323)
1 parent 5360b56 commit 7bac309
Copy full SHA for 7bac309

File tree

Expand file treeCollapse file tree

10 files changed

+449
-26
lines changed
Filter options
Expand file treeCollapse file tree

10 files changed

+449
-26
lines changed

‎pgml-extension/pgml_rust/Cargo.toml

Copy file name to clipboardExpand all lines: pgml-extension/pgml_rust/Cargo.toml
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ serde = { version = "1.0.2" }
2929
serde_json = { version = "1.0.85" }
3030
rmp-serde = { version = "1.1.0" }
3131
typetag = "0.2"
32+
pyo3 = { version = "0.17", features = ["auto-initialize"] }
3233
heapless = "0.7.13"
3334

3435
[dev-dependencies]

‎pgml-extension/pgml_rust/control

Copy file name to clipboardExpand all lines: pgml-extension/pgml_rust/control
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ Version: VERSION
33
Section: base
44
Priority: optional
55
Architecture: ARCH
6-
Depends: postgresql-PGVERSION, libopenblas-dev, postgresql-server-dev-PGVERSION
6+
Depends: postgresql-PGVERSION, libopenblas-dev, postgresql-server-dev-PGVERSION, python3-numpy, python3-sklearn, python3, python3-dev
77
Maintainer: PostgresML <team@postgresml.org>
88
Description: PostgresML - machine learning with PostgreSQL
99
PostgresML is a PostgreSQL extension that allows to do machine

‎pgml-extension/pgml_rust/sql/schema.sql

Copy file name to clipboardExpand all lines: pgml-extension/pgml_rust/sql/schema.sql
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ CREATE TABLE IF NOT EXISTS pgml_rust.models(
7777
project_id BIGINT NOT NULL,
7878
snapshot_id BIGINT NOT NULL,
7979
algorithm TEXT NOT NULL,
80+
backend TEXT DEFAULT 'smartcore',
8081
hyperparams JSONB NOT NULL,
8182
status TEXT NOT NULL,
8283
metrics JSONB,
+42Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
use pgx::*;
2+
use serde::Deserialize;
3+
4+
#[derive(PostgresEnum, Copy, Clone, Eq, PartialEq, Debug, Deserialize)]
5+
#[allow(non_camel_case_types)]
6+
pub enum Backend {
7+
xgboost,
8+
torch,
9+
lightdbm,
10+
sklearn,
11+
smartcore,
12+
linfa,
13+
}
14+
15+
impl std::str::FromStr for Backend {
16+
type Err = ();
17+
18+
fn from_str(input: &str) -> Result<Backend, Self::Err> {
19+
match input {
20+
"xgboost" => Ok(Backend::xgboost),
21+
"torch" => Ok(Backend::torch),
22+
"lightdbm" => Ok(Backend::lightdbm),
23+
"sklearn" => Ok(Backend::sklearn),
24+
"smartcore" => Ok(Backend::smartcore),
25+
"linfa" => Ok(Backend::linfa),
26+
_ => Err(()),
27+
}
28+
}
29+
}
30+
31+
impl std::string::ToString for Backend {
32+
fn to_string(&self) -> String {
33+
match *self {
34+
Backend::xgboost => "xgboost".to_string(),
35+
Backend::torch => "torch".to_string(),
36+
Backend::lightdbm => "lightdbm".to_string(),
37+
Backend::sklearn => "sklearn".to_string(),
38+
Backend::smartcore => "smartcore".to_string(),
39+
Backend::linfa => "linfa".to_string(),
40+
}
41+
}
42+
}
+2Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
pub mod backend;
2+
pub mod sklearn;
+145Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
use pgx::*;
2+
use pyo3::prelude::*;
3+
use pyo3::types::PyTuple;
4+
5+
use std::collections::HashMap;
6+
7+
use crate::orm::dataset::Dataset;
8+
use crate::orm::estimator::SklearnBox;
9+
10+
#[pg_extern]
11+
pub fn sklearn_version() -> String {
12+
let mut version = String::new();
13+
14+
Python::with_gil(|py| {
15+
let sklearn = py.import("sklearn").unwrap();
16+
version = sklearn.getattr("__version__").unwrap().extract().unwrap();
17+
});
18+
19+
version
20+
}
21+
22+
pub fn sklearn_train(
23+
algorithm_name: &str,
24+
dataset: &Dataset,
25+
hyperparams: HashMap<String, f32>,
26+
) -> SklearnBox {
27+
let module = include_str!(concat!(
28+
env!("CARGO_MANIFEST_DIR"),
29+
"/src/backends/wrappers.py"
30+
));
31+
32+
let estimator = Python::with_gil(|py| -> Py<PyAny> {
33+
let module = PyModule::from_code(py, module, "", "").unwrap();
34+
let estimator: Py<PyAny> = module.getattr("estimator").unwrap().into();
35+
36+
let train: Py<PyAny> = estimator
37+
.call1(
38+
py,
39+
PyTuple::new(
40+
py,
41+
&[
42+
String::from(algorithm_name).into_py(py),
43+
dataset.num_features.into_py(py),
44+
hyperparams.into_py(py),
45+
],
46+
),
47+
)
48+
.unwrap();
49+
50+
train
51+
.call1(
52+
py,
53+
PyTuple::new(py, &[dataset.x_train(), dataset.y_train()]),
54+
)
55+
.unwrap()
56+
});
57+
58+
SklearnBox::new(estimator)
59+
}
60+
61+
pub fn sklearn_test(estimator: &SklearnBox, x_test: &[f32], num_features: usize) -> Vec<f32> {
62+
let module = include_str!(concat!(
63+
env!("CARGO_MANIFEST_DIR"),
64+
"/src/backends/wrappers.py"
65+
));
66+
67+
let y_hat: Vec<f32> = Python::with_gil(|py| -> Vec<f32> {
68+
let module = PyModule::from_code(py, module, "", "").unwrap();
69+
let predictor = module.getattr("predictor").unwrap();
70+
let predict = predictor
71+
.call1(PyTuple::new(
72+
py,
73+
&[estimator.contents.as_ref(), &num_features.into_py(py)],
74+
))
75+
.unwrap();
76+
77+
predict
78+
.call1(PyTuple::new(py, &[x_test]))
79+
.unwrap()
80+
.extract()
81+
.unwrap()
82+
});
83+
84+
y_hat
85+
}
86+
87+
pub fn sklearn_predict(estimator: &SklearnBox, x: &[f32]) -> Vec<f32> {
88+
let module = include_str!(concat!(
89+
env!("CARGO_MANIFEST_DIR"),
90+
"/src/backends/wrappers.py"
91+
));
92+
93+
let y_hat: Vec<f32> = Python::with_gil(|py| -> Vec<f32> {
94+
let module = PyModule::from_code(py, module, "", "").unwrap();
95+
let predictor = module.getattr("predictor").unwrap();
96+
let predict = predictor
97+
.call1(PyTuple::new(
98+
py,
99+
&[estimator.contents.as_ref(), &x.len().into_py(py)],
100+
))
101+
.unwrap();
102+
103+
predict
104+
.call1(PyTuple::new(py, &[x]))
105+
.unwrap()
106+
.extract()
107+
.unwrap()
108+
});
109+
110+
y_hat
111+
}
112+
113+
pub fn sklearn_save(estimator: &SklearnBox) -> Vec<u8> {
114+
let module = include_str!(concat!(
115+
env!("CARGO_MANIFEST_DIR"),
116+
"/src/backends/wrappers.py"
117+
));
118+
119+
Python::with_gil(|py| -> Vec<u8> {
120+
let module = PyModule::from_code(py, module, "", "").unwrap();
121+
let save = module.getattr("save").unwrap();
122+
save.call1(PyTuple::new(py, &[estimator.contents.as_ref()]))
123+
.unwrap()
124+
.extract()
125+
.unwrap()
126+
})
127+
}
128+
129+
pub fn sklearn_load(data: &Vec<u8>) -> SklearnBox {
130+
let module = include_str!(concat!(
131+
env!("CARGO_MANIFEST_DIR"),
132+
"/src/backends/wrappers.py"
133+
));
134+
135+
Python::with_gil(|py| -> SklearnBox {
136+
let module = PyModule::from_code(py, module, "", "").unwrap();
137+
let load = module.getattr("load").unwrap();
138+
let estimator = load
139+
.call1(PyTuple::new(py, &[data]))
140+
.unwrap()
141+
.extract()
142+
.unwrap();
143+
SklearnBox::new(estimator)
144+
})
145+
}
+108Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import sklearn.linear_model
2+
import sklearn.kernel_ridge
3+
import sklearn.svm
4+
import sklearn.ensemble
5+
import sklearn.multioutput
6+
import sklearn.gaussian_process
7+
import sklearn.model_selection
8+
import numpy as np
9+
import pickle
10+
11+
# Maps an algorithm name to the scikit-learn estimator class that implements it.
# Keys are looked up verbatim by `estimator_joint`, so they must not be altered.
_ALGORITHM_MAP = {
    # --- sklearn.linear_model ---
    "linear_regression": sklearn.linear_model.LinearRegression,
    "linear_classification": sklearn.linear_model.LogisticRegression,
    "ridge_regression": sklearn.linear_model.Ridge,
    "ridge_classification": sklearn.linear_model.RidgeClassifier,
    "lasso_regression": sklearn.linear_model.Lasso,
    "elastic_net_regression": sklearn.linear_model.ElasticNet,
    "least_angle_regression": sklearn.linear_model.Lars,
    "lasso_least_angle_regression": sklearn.linear_model.LassoLars,
    # NOTE(review): "orthoganl" is misspelled; kept because the key is a runtime name.
    "orthoganl_matching_pursuit_regression": sklearn.linear_model.OrthogonalMatchingPursuit,
    "bayesian_ridge_regression": sklearn.linear_model.BayesianRidge,
    "automatic_relevance_determination_regression": sklearn.linear_model.ARDRegression,
    "stochastic_gradient_descent_regression": sklearn.linear_model.SGDRegressor,
    "stochastic_gradient_descent_classification": sklearn.linear_model.SGDClassifier,
    "perceptron_classification": sklearn.linear_model.Perceptron,
    "passive_aggressive_regression": sklearn.linear_model.PassiveAggressiveRegressor,
    "passive_aggressive_classification": sklearn.linear_model.PassiveAggressiveClassifier,
    "ransac_regression": sklearn.linear_model.RANSACRegressor,
    "theil_sen_regression": sklearn.linear_model.TheilSenRegressor,
    "huber_regression": sklearn.linear_model.HuberRegressor,
    "quantile_regression": sklearn.linear_model.QuantileRegressor,
    # --- sklearn.kernel_ridge ---
    "kernel_ridge_regression": sklearn.kernel_ridge.KernelRidge,
    # --- sklearn.gaussian_process ---
    "gaussian_process_regression": sklearn.gaussian_process.GaussianProcessRegressor,
    "gaussian_process_classification": sklearn.gaussian_process.GaussianProcessClassifier,
    # --- sklearn.svm ---
    "svm_regression": sklearn.svm.SVR,
    "svm_classification": sklearn.svm.SVC,
    "nu_svm_regression": sklearn.svm.NuSVR,
    "nu_svm_classification": sklearn.svm.NuSVC,
    "linear_svm_regression": sklearn.svm.LinearSVR,
    "linear_svm_classification": sklearn.svm.LinearSVC,
    # --- sklearn.ensemble ---
    "ada_boost_regression": sklearn.ensemble.AdaBoostRegressor,
    "ada_boost_classification": sklearn.ensemble.AdaBoostClassifier,
    "bagging_regression": sklearn.ensemble.BaggingRegressor,
    "bagging_classification": sklearn.ensemble.BaggingClassifier,
    "extra_trees_regression": sklearn.ensemble.ExtraTreesRegressor,
    "extra_trees_classification": sklearn.ensemble.ExtraTreesClassifier,
    "gradient_boosting_trees_regression": sklearn.ensemble.GradientBoostingRegressor,
    "gradient_boosting_trees_classification": sklearn.ensemble.GradientBoostingClassifier,
    "hist_gradient_boosting_regression": sklearn.ensemble.HistGradientBoostingRegressor,
    "hist_gradient_boosting_classification": sklearn.ensemble.HistGradientBoostingClassifier,
    "random_forest_regression": sklearn.ensemble.RandomForestRegressor,
    "random_forest_classification": sklearn.ensemble.RandomForestClassifier,
}
54+
55+
56+
def estimator(algorithm_name, num_features, hyperparams):
    """Return a training closure for a single-target model.

    Thin wrapper around `estimator_joint` with the number of targets fixed to 1.
    """
    single_target = 1
    return estimator_joint(algorithm_name, num_features, single_target, hyperparams)
58+
59+
60+
def estimator_joint(algorithm_name, num_features, num_targets, hyperparams):
    """Return a `train(X_train, y_train)` closure for the named algorithm.

    The closure instantiates the class registered under `algorithm_name` in
    `_ALGORITHM_MAP` with `hyperparams` as keyword arguments (none when
    `hyperparams` is None), reshapes the flat inputs to
    `(-1, num_features)` / `(-1, num_targets)`, fits, and returns the instance.
    """
    params = {} if hyperparams is None else hyperparams

    def train(X_train, y_train):
        # Instantiate first so an unknown algorithm name fails before any reshaping.
        model_cls = _ALGORITHM_MAP[algorithm_name]
        model = model_cls(**params)

        features = np.asarray(X_train).reshape((-1, num_features))
        targets = np.asarray(y_train).reshape((-1, num_targets))

        model.fit(features, targets)
        return model

    return train
76+
77+
78+
def test(estimator, X_test):
79+
y_hat = estimator.predict(X_test)
80+
81+
# Single value models only just for now.
82+
return list(np.asarray(y_hat).flatten())
83+
84+
85+
def predictor(estimator, num_features):
    """Return a prediction closure for a single-target model.

    Thin wrapper around `predictor_joint` with the number of targets fixed to 1.
    """
    single_target = 1
    return predictor_joint(estimator, num_features, single_target)
87+
88+
89+
def predictor_joint(estimator, num_features, num_targets):
    """Return a `predict(X)` closure bound to `estimator`.

    The closure reshapes the flat input `X` to `(-1, num_features)` and calls
    `estimator.predict`. With a single target the result is flattened into a
    flat list; otherwise the predictions are returned as a list of rows.
    """

    def predict(X):
        rows = np.asarray(X).reshape((-1, num_features))
        y_hat = estimator.predict(rows)

        if num_targets != 1:
            return list(y_hat)
        return list(np.asarray(y_hat).flatten())

    return predict
101+
102+
103+
def save(estimator):
    """Serialize `estimator` with pickle and return the resulting bytes."""
    serialized = pickle.dumps(estimator)
    return serialized
105+
106+
107+
def load(data):
    """Rebuild an object from pickled `data`.

    `data` may be any value accepted by `bytes()` (e.g. a list of ints).
    """
    # NOTE(review): unpickling can execute arbitrary code; `data` must only
    # ever come from a trusted source.
    raw = bytes(data)
    return pickle.loads(raw)

‎pgml-extension/pgml_rust/src/lib.rs

Copy file name to clipboardExpand all lines: pgml-extension/pgml_rust/src/lib.rs
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use std::sync::Mutex;
1010
use xgboost::{Booster, DMatrix};
1111

1212
pub mod api;
13+
pub mod backends;
1314
pub mod orm;
1415
pub mod vectors;
1516

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.