From fb7a1fa5290005d338608dc0320d085743147af8 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Thu, 6 Oct 2022 09:19:43 -0700 Subject: [PATCH 01/12] Add serialization for LogisticRegression --- algorithms/linfa-logistic/Cargo.toml | 1 + algorithms/linfa-logistic/src/lib.rs | 22 ++++++++++++++++------ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/algorithms/linfa-logistic/Cargo.toml b/algorithms/linfa-logistic/Cargo.toml index be32724c2..1b09a83e6 100644 --- a/algorithms/linfa-logistic/Cargo.toml +++ b/algorithms/linfa-logistic/Cargo.toml @@ -26,3 +26,4 @@ linfa = { version = "0.6.0", path = "../..", features=["serde"] } [dev-dependencies] approx = "0.4" linfa-datasets = { version = "0.6.0", path = "../../datasets", features = ["winequality"] } +rmp-serde = "1" diff --git a/algorithms/linfa-logistic/src/lib.rs b/algorithms/linfa-logistic/src/lib.rs index a50055a69..a0018dcf7 100644 --- a/algorithms/linfa-logistic/src/lib.rs +++ b/algorithms/linfa-logistic/src/lib.rs @@ -30,6 +30,7 @@ use ndarray::{ Dimension, IntoDimension, Ix1, Ix2, RemoveAxis, Slice, Zip, }; use ndarray_stats::QuantileExt; +use serde::{Deserialize, Serialize}; use std::default::Default; mod argmin_param; @@ -524,8 +525,8 @@ fn multi_logistic_grad>( } /// A fitted logistic regression which can make predictions -#[derive(PartialEq, Debug, Clone)] -pub struct FittedLogisticRegression { +#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)] +pub struct FittedLogisticRegression { threshold: F, intercept: F, params: Array1, @@ -610,8 +611,8 @@ impl> } /// A fitted multinomial logistic regression which can make predictions -#[derive(PartialEq, Debug, Clone)] -pub struct MultiFittedLogisticRegression { +#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)] +pub struct MultiFittedLogisticRegression { intercept: Array1, params: Array2, classes: Vec, @@ -685,8 +686,8 @@ impl> } } -#[derive(PartialEq, Debug, Clone)] -struct ClassLabel { +#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)] +struct ClassLabel { class: C, label: F, } @@ -1066,6 +1067,15 @@ mod test { &res.predict(dataset.records()), dataset.targets().as_single_targets() ); + + // Test serialization + let ser = rmp_serde::to_vec(&res).unwrap(); + let unser: FittedLogisticRegression = rmp_serde::from_slice(&ser).unwrap(); + + let x = array![[1.0]]; + let y_hat = unser.predict(&x); + + assert!(y_hat[0] == 0.0); } #[test] From 378c3b4f13fab2fb58419d5d4eea3ad2ebebd671 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Thu, 6 Oct 2022 10:38:48 -0700 Subject: [PATCH 02/12] Serialization for multi-class --- algorithms/linfa-logistic/src/hyperparams.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/algorithms/linfa-logistic/src/hyperparams.rs b/algorithms/linfa-logistic/src/hyperparams.rs index 523062f3c..1447198db 100644 --- a/algorithms/linfa-logistic/src/hyperparams.rs +++ b/algorithms/linfa-logistic/src/hyperparams.rs @@ -4,13 +4,15 @@ use ndarray::{Array, Dimension}; use crate::error::Error; use crate::float::Float; +use serde::{Deserialize, Serialize}; + /// A generalized logistic regression type that specializes as either binomial logistic regression /// or multinomial logistic regression. -#[derive(Debug, Clone, PartialEq)] -pub struct LogisticRegressionParams(LogisticRegressionValidParams); +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct LogisticRegressionParams(LogisticRegressionValidParams); -#[derive(Debug, Clone, PartialEq)] -pub struct LogisticRegressionValidParams { +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct LogisticRegressionValidParams { pub(crate) alpha: F, pub(crate) fit_intercept: bool, pub(crate) max_iterations: u64, From fe3ae530c7e9f8249f5171caa3f4f474799acf92 Mon Sep 17 00:00:00 2001 From: Gorka Kobeaga Date: Fri, 7 Oct 2022 01:18:56 +0200 Subject: [PATCH 03/12] Float type restriction with handwritten bounds --- algorithms/linfa-logistic/src/hyperparams.rs | 6 ++++-- algorithms/linfa-logistic/src/lib.rs | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/algorithms/linfa-logistic/src/hyperparams.rs b/algorithms/linfa-logistic/src/hyperparams.rs index 1447198db..1705aed3a 100644 --- a/algorithms/linfa-logistic/src/hyperparams.rs +++ b/algorithms/linfa-logistic/src/hyperparams.rs @@ -9,10 +9,12 @@ use serde::{Deserialize, Serialize}; /// A generalized logistic regression type that specializes as either binomial logistic regression /// or multinomial logistic regression. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct LogisticRegressionParams(LogisticRegressionValidParams); +#[serde(bound(deserialize = "D: Deserialize<'de>"))] +pub struct LogisticRegressionParams(LogisticRegressionValidParams); #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct LogisticRegressionValidParams { +#[serde(bound(deserialize = "D: Deserialize<'de>"))] +pub struct LogisticRegressionValidParams { pub(crate) alpha: F, pub(crate) fit_intercept: bool, pub(crate) max_iterations: u64, diff --git a/algorithms/linfa-logistic/src/lib.rs b/algorithms/linfa-logistic/src/lib.rs index a0018dcf7..edc3b66cd 100644 --- a/algorithms/linfa-logistic/src/lib.rs +++ b/algorithms/linfa-logistic/src/lib.rs @@ -526,7 +526,8 @@ fn multi_logistic_grad>( /// A fitted logistic regression which can make predictions #[derive(PartialEq, Debug, Clone, Serialize, Deserialize)] -pub struct FittedLogisticRegression { +#[serde(bound(deserialize = "C: Deserialize<'de>"))] +pub struct FittedLogisticRegression { threshold: F, intercept: F, params: Array1, From c44940bca6adf929207d9662b699e03248f6fa18 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sun, 16 Oct 2022 19:42:01 -0700 Subject: [PATCH 04/12] Confusion matrix should use labels from predictions and ground truth --- src/dataset/mod.rs | 13 +++++++++++++ src/metrics_classification.rs | 15 ++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/dataset/mod.rs b/src/dataset/mod.rs index b04e48109..077302a59 100644 --- a/src/dataset/mod.rs +++ b/src/dataset/mod.rs @@ -323,6 +323,19 @@ pub trait Labels { fn labels(&self) -> Vec { self.label_set().into_iter().flatten().collect() } + + fn combined_labels(&self, other: Vec) -> Vec { + let mut combined = self.labels(); + combined.extend(other.clone()); + + combined + .iter() + .map(|x| x) + .collect::>() + .into_iter() + .map(|x| x.clone()) + .collect() + } } #[cfg(test)] diff --git a/src/metrics_classification.rs b/src/metrics_classification.rs index 3cf2f30f5..74d053dd7 100644 --- a/src/metrics_classification.rs +++ b/src/metrics_classification.rs @@ -290,7 +290,7 @@ where return Err(Error::MismatchedShapes(targets.len(), ground_truth.len())); } - let classes = self.labels(); + let classes = self.combined_labels(ground_truth.labels()); let indices = map_prediction_to_idx( targets.as_slice().unwrap(), @@ -636,6 +636,19 @@ mod tests { ); } + #[test] + fn test_division_by_zero_cm() { + let ground_truth = Array1::from(vec![1, 1, 0, 1, 0, 1]); + let predicted = Array1::from(vec![0, 0, 0, 0, 0, 0]); + let labels = array![0, 1]; + + let x = predicted.confusion_matrix(ground_truth).unwrap(); + + let f1 = x.f1_score(); + + assert!(f1.is_nan()); + } + #[test] fn test_roc_curve() { let predicted = ArrayView1::from(&[0.1, 0.3, 0.5, 0.7, 0.8, 0.9]).mapv(Pr::new); From d91de55080fd715045d583946c90c02d936f5247 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sun, 16 Oct 2022 19:54:15 -0700 Subject: [PATCH 05/12] Clippy fixes --- src/dataset/mod.rs | 5 ++--- src/metrics_classification.rs | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/dataset/mod.rs b/src/dataset/mod.rs index 077302a59..e395364dd 100644 --- a/src/dataset/mod.rs +++ b/src/dataset/mod.rs @@ -326,14 +326,13 @@ pub trait Labels { fn combined_labels(&self, other: Vec) -> Vec { let mut combined = self.labels(); - combined.extend(other.clone()); + combined.extend(other); combined .iter() - .map(|x| x) .collect::>() .into_iter() - .map(|x| x.clone()) + .cloned() .collect() } } diff --git a/src/metrics_classification.rs b/src/metrics_classification.rs index 74d053dd7..cbac7e97c 100644 --- a/src/metrics_classification.rs +++ b/src/metrics_classification.rs @@ -640,7 +640,6 @@ mod tests { fn test_division_by_zero_cm() { let ground_truth = Array1::from(vec![1, 1, 0, 1, 0, 1]); let predicted = Array1::from(vec![0, 0, 0, 0, 0, 0]); - let labels = array![0, 1]; let x = predicted.confusion_matrix(ground_truth).unwrap(); From 4ac3ec88b16a206682c776637b2fe1a3a365ed09 Mon Sep 17 00:00:00 2001 From: Lev Date: Sun, 16 Oct 2022 21:58:52 -0700 Subject: [PATCH 06/12] This is the correct test --- src/metrics_classification.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/metrics_classification.rs b/src/metrics_classification.rs index cbac7e97c..07ad7beeb 100644 --- a/src/metrics_classification.rs +++ b/src/metrics_classification.rs @@ -641,11 +641,10 @@ mod tests { let ground_truth = Array1::from(vec![1, 1, 0, 1, 0, 1]); let predicted = Array1::from(vec![0, 0, 0, 0, 0, 0]); - let x = predicted.confusion_matrix(ground_truth).unwrap(); - + let x = ground_truth.confusion_matrix(predicted).unwrap(); let f1 = x.f1_score(); - assert!(f1.is_nan()); + assert_eq!(f1, 0.5); } #[test] From 7dee254009462486b3cf5b3ac3ad93225d7b9e3f Mon Sep 17 00:00:00 2001 From: Montana Low Date: Fri, 10 Jan 2025 17:57:01 -0800 Subject: [PATCH 07/12] fix warnings --- src/correlation.rs | 1 - src/dataset/impl_dataset.rs | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/correlation.rs b/src/correlation.rs index e77905813..52258022e 100644 --- a/src/correlation.rs +++ b/src/correlation.rs @@ -153,7 +153,6 @@ impl PearsonCorrelation { /// lamotrigine +0.47 (0.14) /// blood sugar level /// ``` - pub fn from_dataset, T>( dataset: &DatasetBase, T>, num_iter: Option, diff --git a/src/dataset/impl_dataset.rs b/src/dataset/impl_dataset.rs index b8148793f..a68ad9f31 100644 --- a/src/dataset/impl_dataset.rs +++ b/src/dataset/impl_dataset.rs @@ -208,7 +208,7 @@ where /// println!("{} => {}", x, y); /// } /// ``` - pub fn sample_iter(&'a self) -> Iter<'a, '_, F, T::Elem, T::Ix> { + pub fn sample_iter(&'a self) -> Iter<'a, 'a, F, T::Elem, T::Ix> { Iter::new(self.records.view(), self.targets.as_targets()) } } @@ -232,7 +232,7 @@ where /// /// This iterator produces dataset views with only a single feature, while the set of targets remain /// complete. It can be useful to compare each feature individual to all targets. - pub fn feature_iter(&'a self) -> DatasetIter<'a, '_, ArrayBase, T> { + pub fn feature_iter(&'a self) -> DatasetIter<'a, 'a, ArrayBase, T> { DatasetIter::new(self, true) } @@ -241,7 +241,7 @@ where /// This functions creates an iterator which produces dataset views complete records, but only /// a single target each. Useful to train multiple single target models for a multi-target /// dataset. - pub fn target_iter(&'a self) -> DatasetIter<'a, '_, ArrayBase, T> { + pub fn target_iter(&'a self) -> DatasetIter<'a, 'a, ArrayBase, T> { DatasetIter::new(self, false) } } @@ -732,7 +732,7 @@ where &'a mut self, k: usize, fit_closure: C, - ) -> impl Iterator, ArrayView>)> { + ) -> impl Iterator, ArrayView<'a, E, I>>)> { assert!(k > 0); assert!(k <= self.nsamples()); let samples_count = self.nsamples(); From e9904a8ebf0fd3ac25019dd27e96e84896213f6c Mon Sep 17 00:00:00 2001 From: Montana Low Date: Fri, 10 Jan 2025 18:01:21 -0800 Subject: [PATCH 08/12] remove lifetimes --- src/dataset/impl_dataset.rs | 2 +- src/metrics_clustering.rs | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/dataset/impl_dataset.rs b/src/dataset/impl_dataset.rs index a68ad9f31..af49ea25c 100644 --- a/src/dataset/impl_dataset.rs +++ b/src/dataset/impl_dataset.rs @@ -318,7 +318,7 @@ impl, R: Records> Labels for DatasetBase { } #[allow(clippy::type_complexity)] -impl<'a, 'b: 'a, F, L: Label, T, D> DatasetBase, T> +impl DatasetBase, T> where D: Data, T: AsSingleTargets + Labels, diff --git a/src/metrics_clustering.rs b/src/metrics_clustering.rs index b49f4c5be..5eab97cb8 100644 --- a/src/metrics_clustering.rs +++ b/src/metrics_clustering.rs @@ -63,9 +63,8 @@ impl DistanceCount { } impl< - 'a, F: Float, - L: 'a + Label, + L: Label, D: Data, T: AsSingleTargets + Labels, > SilhouetteScore for DatasetBase, T> From 4f8ccef4eb7f052a909236919362a1cc1b8b3389 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Fri, 10 Jan 2025 18:05:01 -0800 Subject: [PATCH 09/12] clippy lints --- algorithms/linfa-ftrl/Cargo.toml | 2 +- src/correlation.rs | 2 +- src/dataset/impl_dataset.rs | 10 +++++----- src/dataset/mod.rs | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/algorithms/linfa-ftrl/Cargo.toml b/algorithms/linfa-ftrl/Cargo.toml index 1b8abbb87..c5a2cacbe 100644 --- a/algorithms/linfa-ftrl/Cargo.toml +++ b/algorithms/linfa-ftrl/Cargo.toml @@ -24,7 +24,7 @@ version = "1.0" features = ["derive"] [dependencies] -ndarray = { version = "0.15.4", features = ["serde"] } +ndarray = { version = "0.15", features = ["serde"] } ndarray-rand = "0.14.0" argmin = { version = "0.9.0", default-features = false } argmin-math = { version = "0.3", features = ["ndarray_v0_15-nolinalg"] } diff --git a/src/correlation.rs b/src/correlation.rs index 52258022e..d157d10b0 100644 --- a/src/correlation.rs +++ b/src/correlation.rs @@ -128,7 +128,7 @@ impl PearsonCorrelation { /// /// * `dataset`: Data for the correlation analysis /// * `num_iter`: optionally number of iterations of the p-value test, if none then no p-value - /// are calculate + /// are calculated /// /// # Example /// diff --git a/src/dataset/impl_dataset.rs b/src/dataset/impl_dataset.rs index af49ea25c..cb1e28a90 100644 --- a/src/dataset/impl_dataset.rs +++ b/src/dataset/impl_dataset.rs @@ -680,8 +680,8 @@ where /// - `k`: the number of folds to apply to the dataset /// - `params`: the desired parameters for the fittable algorithm at hand /// - `fit_closure`: a closure of the type `(params, training_data) -> fitted_model` - /// that will be used to produce the trained model for each fold. The training data given in input - /// won't outlive the closure. + /// that will be used to produce the trained model for each fold. The training data given in + /// input won't outlive the closure. /// /// ## Returns /// @@ -794,9 +794,9 @@ where /// - `k`: the number of folds to apply /// - `parameters`: a list of models to compare /// - `eval`: closure used to evaluate the performance of each trained model. This closure is - /// called on the model output and validation targets of each fold and outputs the performance - /// score for each target. For single-target dataset the signature is `(Array1, Array1) -> - /// Array0`. For multi-target dataset the signature is `(Array2, Array2) -> Array1`. + /// called on the model output and validation targets of each fold and outputs the performance + /// score for each target. For single-target dataset the signature is `(Array1, Array1) -> + /// Array0`. For multi-target dataset the signature is `(Array2, Array2) -> Array1`. /// /// ### Returns /// diff --git a/src/dataset/mod.rs b/src/dataset/mod.rs index 544d8243a..ab8f5417d 100644 --- a/src/dataset/mod.rs +++ b/src/dataset/mod.rs @@ -161,7 +161,7 @@ impl Deref for Pr { /// # Fields /// /// * `records`: a two-dimensional matrix with dimensionality (nsamples, nfeatures), in case of -/// kernel methods a quadratic matrix with dimensionality (nsamples, nsamples), which may be sparse +/// kernel methods a quadratic matrix with dimensionality (nsamples, nsamples), which may be sparse /// * `targets`: a two-/one-dimension matrix with dimensionality (nsamples, ntargets) /// * `weights`: optional weights for each sample with dimensionality (nsamples) /// * `feature_names`: optional descriptive feature names with dimensionality (nfeatures) @@ -170,7 +170,7 @@ impl Deref for Pr { /// /// * `R: Records`: generic over feature matrices or kernel matrices /// * `T`: generic over any `ndarray` matrix which can be used as targets. The `AsTargets` trait -/// bound is omitted here to avoid some repetition in implementation `src/dataset/impl_dataset.rs` +/// bound is omitted here to avoid some repetition in implementation `src/dataset/impl_dataset.rs` #[derive(Debug, Clone, PartialEq)] pub struct DatasetBase where From 5ec7b2f3161ded3044eeb23ee9c97050c68d2dd8 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Fri, 10 Jan 2025 19:15:12 -0800 Subject: [PATCH 10/12] fix ownership --- algorithms/linfa-kernel/src/inner.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithms/linfa-kernel/src/inner.rs b/algorithms/linfa-kernel/src/inner.rs index a0bdfb036..00c712878 100644 --- a/algorithms/linfa-kernel/src/inner.rs +++ b/algorithms/linfa-kernel/src/inner.rs @@ -61,7 +61,7 @@ impl Inner for CsMat { type Elem = F; fn dot(&self, rhs: &ArrayView2) -> Array2 { - self.mul(rhs) + self.mul(&rhs.to_owned()) } fn sum(&self) -> Array1 { let mut sum = Array1::zeros(self.cols()); @@ -106,7 +106,7 @@ impl<'a, F: Float> Inner for CsMatView<'a, F> { type Elem = F; fn dot(&self, rhs: &ArrayView2) -> Array2 { - self.mul(rhs) + self.mul(&rhs.to_owned()) } fn sum(&self) -> Array1 { let mut sum = Array1::zeros(self.cols()); From 97d52e742d576081bf1dc7e30fc3eff72be621f8 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Fri, 10 Jan 2025 19:18:20 -0800 Subject: [PATCH 11/12] fix ownership --- Cargo.toml | 2 +- algorithms/linfa-kernel/Cargo.toml | 2 +- algorithms/linfa-kernel/src/inner.rs | 28 +++++++++++++++++++++-- algorithms/linfa-preprocessing/Cargo.toml | 2 +- 4 files changed, 29 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c5552775e..1093c7616 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,7 +42,7 @@ approx = "0.4" ndarray = { version = "0.15", features = ["approx"] } ndarray-linalg = { version = "0.16", optional = true } -sprs = { version = "0.11", default-features = false } +sprs = { version = "=0.11.1", default-features = false } thiserror = "1.0" diff --git a/algorithms/linfa-kernel/Cargo.toml b/algorithms/linfa-kernel/Cargo.toml index 4646cea27..2d4dba924 100644 --- a/algorithms/linfa-kernel/Cargo.toml +++ b/algorithms/linfa-kernel/Cargo.toml @@ -26,7 +26,7 @@ features = ["std", "derive"] [dependencies] ndarray = "0.15" num-traits = "0.2" -sprs = { version="0.11", default-features = false } +sprs = { version="0.11.1", default-features = false } linfa = { version = "0.7.0", path = "../.." } linfa-nn = { version = "0.7.0", path = "../linfa-nn" } diff --git a/algorithms/linfa-kernel/src/inner.rs b/algorithms/linfa-kernel/src/inner.rs index 00c712878..fda7ffcaa 100644 --- a/algorithms/linfa-kernel/src/inner.rs +++ b/algorithms/linfa-kernel/src/inner.rs @@ -61,7 +61,19 @@ impl Inner for CsMat { type Elem = F; fn dot(&self, rhs: &ArrayView2) -> Array2 { - self.mul(&rhs.to_owned()) + let mut result = Array2::zeros((self.rows(), rhs.ncols())); + + // Handle potential sparse matrices + for j in 0..rhs.ncols() { + let col = rhs.column(j); + let col_result = self.mul(&col.to_owned()); + // Copy result into appropriate column of output + for i in 0..self.rows() { + result[[i, j]] = col_result[i]; + } + } + + result } fn sum(&self) -> Array1 { let mut sum = Array1::zeros(self.cols()); @@ -106,7 +118,19 @@ impl<'a, F: Float> Inner for CsMatView<'a, F> { type Elem = F; fn dot(&self, rhs: &ArrayView2) -> Array2 { - self.mul(&rhs.to_owned()) + let mut result = Array2::zeros((self.rows(), rhs.ncols())); + + // Handle potential sparse matrices + for j in 0..rhs.ncols() { + let col = rhs.column(j); + let col_result = self.mul(&col.to_owned()); + // Copy result into appropriate column of output + for i in 0..self.rows() { + result[[i, j]] = col_result[i]; + } + } + + result } fn sum(&self) -> Array1 { let mut sum = Array1::zeros(self.cols()); diff --git a/algorithms/linfa-preprocessing/Cargo.toml b/algorithms/linfa-preprocessing/Cargo.toml index 6f9687aa8..637ee9b5c 100644 --- a/algorithms/linfa-preprocessing/Cargo.toml +++ b/algorithms/linfa-preprocessing/Cargo.toml @@ -29,7 +29,7 @@ ndarray-rand = { version = "0.14" } unicode-normalization = "0.1.8" regex = "1.4.5" encoding = "0.2" -sprs = { version = "0.11.0", default-features = false } +sprs = { version = "0.11.1", default-features = false } serde_regex = { version = "1.1", optional = true } From d4a5744e94249a503f6088dcb4aae264eeff7e96 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Sat, 11 Jan 2025 14:27:01 -0800 Subject: [PATCH 12/12] cleanup lints --- algorithms/linfa-kernel/Cargo.toml | 2 +- algorithms/linfa-kernel/src/inner.rs | 28 ++--------------------- algorithms/linfa-preprocessing/Cargo.toml | 2 +- 3 files changed, 4 insertions(+), 28 deletions(-) diff --git a/algorithms/linfa-kernel/Cargo.toml b/algorithms/linfa-kernel/Cargo.toml index 2d4dba924..266ce1e0e 100644 --- a/algorithms/linfa-kernel/Cargo.toml +++ b/algorithms/linfa-kernel/Cargo.toml @@ -26,7 +26,7 @@ features = ["std", "derive"] [dependencies] ndarray = "0.15" num-traits = "0.2" -sprs = { version="0.11.1", default-features = false } +sprs = { version="=0.11.1", default-features = false } linfa = { version = "0.7.0", path = "../.." } linfa-nn = { version = "0.7.0", path = "../linfa-nn" } diff --git a/algorithms/linfa-kernel/src/inner.rs b/algorithms/linfa-kernel/src/inner.rs index fda7ffcaa..a0bdfb036 100644 --- a/algorithms/linfa-kernel/src/inner.rs +++ b/algorithms/linfa-kernel/src/inner.rs @@ -61,19 +61,7 @@ impl Inner for CsMat { type Elem = F; fn dot(&self, rhs: &ArrayView2) -> Array2 { - let mut result = Array2::zeros((self.rows(), rhs.ncols())); - - // Handle potential sparse matrices - for j in 0..rhs.ncols() { - let col = rhs.column(j); - let col_result = self.mul(&col.to_owned()); - // Copy result into appropriate column of output - for i in 0..self.rows() { - result[[i, j]] = col_result[i]; - } - } - - result + self.mul(rhs) } fn sum(&self) -> Array1 { let mut sum = Array1::zeros(self.cols()); @@ -118,19 +106,7 @@ impl<'a, F: Float> Inner for CsMatView<'a, F> { type Elem = F; fn dot(&self, rhs: &ArrayView2) -> Array2 { - let mut result = Array2::zeros((self.rows(), rhs.ncols())); - - // Handle potential sparse matrices - for j in 0..rhs.ncols() { - let col = rhs.column(j); - let col_result = self.mul(&col.to_owned()); - // Copy result into appropriate column of output - for i in 0..self.rows() { - result[[i, j]] = col_result[i]; - } - } - - result + self.mul(rhs) } fn sum(&self) -> Array1 { let mut sum = Array1::zeros(self.cols()); diff --git a/algorithms/linfa-preprocessing/Cargo.toml b/algorithms/linfa-preprocessing/Cargo.toml index 637ee9b5c..25f030b04 100644 --- a/algorithms/linfa-preprocessing/Cargo.toml +++ b/algorithms/linfa-preprocessing/Cargo.toml @@ -29,7 +29,7 @@ ndarray-rand = { version = "0.14" } unicode-normalization = "0.1.8" regex = "1.4.5" encoding = "0.2" -sprs = { version = "0.11.1", default-features = false } +sprs = { version = "=0.11.1", default-features = false } serde_regex = { version = "1.1", optional = true }