Commit ceb10e3

ENH Support get_precision and get_covariance

1 parent dd4c9fc commit ceb10e3
File tree: 4 files changed, +54 -29 lines

doc/whats_new/v1.3.rst (3 additions, 0 deletions)
@@ -230,6 +230,9 @@ Changelog
   :class:`decomposition.MiniBatchNMF` which can produce different results than previous
   versions. :pr:`25438` by :user:`Yotam Avidar-Constantini <yotamcons>`.

+- |Enhancement| :class:`decomposition.PCA` now supports
+  `PyTorch <https://pytorch.org/>`__ tensors for the `full` solver. See :pr:`26315`.
+
 :mod:`sklearn.discriminant_analysis`
 ....................................

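For context, a minimal usage sketch of what this changelog entry describes. Hedged: whether array API dispatch must be enabled through `config_context` on this branch, and the exact way torch tensors reach `fit`, are assumptions modeled on upstream scikit-learn's experimental array API support.

import numpy as np
import torch
from sklearn import config_context
from sklearn.decomposition import PCA

X_torch = torch.asarray(np.random.RandomState(0).standard_normal((100, 5)))

# Assumption: dispatch is enabled explicitly, as in upstream scikit-learn.
with config_context(array_api_dispatch=True):
    pca = PCA(n_components=2, svd_solver="full").fit(X_torch)
    X_out = pca.transform(X_torch)  # stays a torch.Tensor
    cov = pca.get_covariance()      # also a torch.Tensor after this commit
    prec = pca.get_precision()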
sklearn/decomposition/_base.py (36 additions, 22 deletions)
@@ -9,10 +9,10 @@
 # License: BSD 3 clause

 import numpy as np
-from scipy import linalg

 from ..base import BaseEstimator, TransformerMixin, ClassNamePrefixFeaturesOutMixin
 from ..utils.validation import check_is_fitted
+from ..utils._array_api import get_namespace
 from abc import ABCMeta, abstractmethod

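The swapped-in import is the heart of the change: `get_namespace` inspects its array arguments and returns an array-API-style namespace `xp`, so one code path can call `xp.sqrt` or `xp.linalg.inv` on NumPy arrays and torch tensors alike. A small sketch of the intended usage (the second return value indicates whether array API dispatch is active):

import numpy as np
from sklearn.utils._array_api import get_namespace

xp, is_array_api_compliant = get_namespace(np.ones(3))
# For plain NumPy input without dispatch enabled, xp is a NumPy-backed
# namespace, so xp.sqrt behaves like np.sqrt.
assert float(xp.sqrt(xp.asarray(4.0))) == 2.0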
@@ -37,13 +37,18 @@ def get_covariance(self):
         cov : array of shape=(n_features, n_features)
             Estimated covariance of data.
         """
+        xp, _ = get_namespace(self.components_)
+
         components_ = self.components_
         exp_var = self.explained_variance_
         if self.whiten:
-            components_ = components_ * np.sqrt(exp_var[:, np.newaxis])
-        exp_var_diff = np.maximum(exp_var - self.noise_variance_, 0.0)
-        cov = np.dot(components_.T * exp_var_diff, components_)
-        cov.flat[:: len(cov) + 1] += self.noise_variance_  # modify diag inplace
+            components_ = components_ * xp.sqrt(exp_var[:, np.newaxis])
+        exp_var_diff = xp.maximum(
+            exp_var - self.noise_variance_, xp.zeros_like(exp_var)
+        )
+        cov = (components_.T * exp_var_diff) @ components_
+        # TODO use views instead?
+        cov.reshape(-1)[:: len(cov) + 1] += self.noise_variance_  # modify diag inplace
         return cov

     def get_precision(self):
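The rewritten body computes the same estimate as before, just through `xp`. In plain NumPy terms, a sketch of the formula being implemented, where W holds the components (rescaled first when whitening is on):

import numpy as np

def pca_covariance(W, explained_variance, noise_variance):
    # cov = W.T @ diag(max(explained_variance - noise, 0)) @ W + noise * I
    exp_var_diff = np.maximum(explained_variance - noise_variance, 0.0)
    cov = (W.T * exp_var_diff) @ W
    cov[np.diag_indices_from(cov)] += noise_variance  # add noise on the diagonal
    return cov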
@@ -57,26 +62,33 @@ def get_precision(self):
         precision : array, shape=(n_features, n_features)
             Estimated precision of data.
         """
+        xp, _ = get_namespace(self.components_)
+
         n_features = self.components_.shape[1]

         # handle corner cases first
         if self.n_components_ == 0:
-            return np.eye(n_features) / self.noise_variance_
+            return xp.eye(n_features) / self.noise_variance_

-        if np.isclose(self.noise_variance_, 0.0, atol=0.0):
-            return linalg.inv(self.get_covariance())
+        if xp.isclose(
+            self.noise_variance_, xp.zeros_like(self.noise_variance_), atol=0.0
+        ):
+            return xp.linalg.inv(self.get_covariance())

         # Get precision using matrix inversion lemma
         components_ = self.components_
         exp_var = self.explained_variance_
         if self.whiten:
-            components_ = components_ * np.sqrt(exp_var[:, np.newaxis])
-        exp_var_diff = np.maximum(exp_var - self.noise_variance_, 0.0)
-        precision = np.dot(components_, components_.T) / self.noise_variance_
-        precision.flat[:: len(precision) + 1] += 1.0 / exp_var_diff
-        precision = np.dot(components_.T, np.dot(linalg.inv(precision), components_))
+            components_ = components_ * xp.sqrt(exp_var[:, np.newaxis])
+        exp_var_diff = xp.maximum(
+            exp_var - self.noise_variance_, xp.zeros_like(exp_var)
+        )
+        precision = components_ @ components_.T / self.noise_variance_
+        # TODO use views instead?
+        precision.reshape(-1)[:: len(precision) + 1] += 1.0 / exp_var_diff
+        precision = components_.T @ xp.linalg.inv(precision) @ components_
         precision /= -(self.noise_variance_**2)
-        precision.flat[:: len(precision) + 1] += 1.0 / self.noise_variance_
+        precision.reshape(-1)[:: len(precision) + 1] += 1.0 / self.noise_variance_
         return precision

     @abstractmethod
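The "matrix inversion lemma" comment refers to the Woodbury identity: inv(s*I + W.T @ D @ W) = I/s - W.T @ inv(inv(D) + W @ W.T / s) @ W / s**2, with s the noise variance and D the (positive) variance differences. A NumPy sketch that checks the lemma-based route above against a direct inverse:

import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((2, 5))  # components, shape (n_components, n_features)
exp_var = np.array([3.0, 2.0])   # explained variance per component
noise = 0.5                      # noise_variance_
d = np.maximum(exp_var - noise, 0.0)

cov = (W.T * d) @ W + noise * np.eye(5)
direct = np.linalg.inv(cov)

# Lemma-based route, mirroring get_precision above.
M = W @ W.T / noise
M[np.diag_indices_from(M)] += 1.0 / d
lemma = -(W.T @ np.linalg.inv(M) @ W) / noise**2
lemma[np.diag_indices_from(lemma)] += 1.0 / noise
np.testing.assert_allclose(direct, lemma, atol=1e-10)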
@@ -115,14 +127,16 @@ def transform(self, X):
         Projection of X in the first principal components, where `n_samples`
         is the number of samples and `n_components` is the number of the components.
         """
+        xp, _ = get_namespace(X)
+
         check_is_fitted(self)

-        X = self._validate_data(X, dtype=[np.float64, np.float32], reset=False)
+        X = self._validate_data(X, dtype=[xp.float64, xp.float32], reset=False)
         if self.mean_ is not None:
             X = X - self.mean_
-        X_transformed = np.dot(X, self.components_.T)
+        X_transformed = X @ self.components_.T
         if self.whiten:
-            X_transformed /= np.sqrt(self.explained_variance_)
+            X_transformed /= xp.sqrt(self.explained_variance_)
         return X_transformed

     def inverse_transform(self, X):
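Aside from swapping `np.dot` for `@` and `np.sqrt` for `xp.sqrt`, the projection itself is unchanged. A NumPy sketch of what `transform` computes:

import numpy as np

def pca_transform(X, mean, W, explained_variance, whiten=False):
    X_t = (X - mean) @ W.T  # center, then project onto the components
    if whiten:
        X_t /= np.sqrt(explained_variance)  # unit variance per component
    return X_t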
@@ -147,16 +161,16 @@ def inverse_transform(self, X):
         If whitening is enabled, inverse_transform will compute the
         exact inverse operation, which includes reversing whitening.
         """
+        xp, _ = get_namespace(X)
+
         if self.whiten:
             return (
-                np.dot(
-                    X,
-                    np.sqrt(self.explained_variance_[:, np.newaxis]) * self.components_,
-                )
+                X
+                @ (np.sqrt(self.explained_variance_[:, np.newaxis]) * self.components_)
                 + self.mean_
             )
         else:
-            return np.dot(X, self.components_) + self.mean_
+            return X @ self.components_ + self.mean_

     @property
     def _n_features_out(self):
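A quick round-trip check of the two methods touched here; with `n_components` equal to `n_features`, reconstruction is exact up to floating point. This is a sketch against the public API, not part of this commit's tests:

import numpy as np
from sklearn.decomposition import PCA

X = np.random.RandomState(0).standard_normal((50, 4))
pca = PCA(n_components=4, whiten=True).fit(X)
X_back = pca.inverse_transform(pca.transform(X))
np.testing.assert_allclose(X, X_back, atol=1e-10)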

sklearn/decomposition/_pca.py (0 additions, 1 deletion)
@@ -489,7 +489,6 @@ def _fit(self, X):
                 "TruncatedSVD for a possible alternative."
             )
         # Raise an error for torch input and arpack or randomized solver.
-        # TODO support randomized solver for torch tensors
         if self.svd_solver in ["arpack", "randomized"] and _is_torch_namespace(xp):
             raise TypeError(self._pca_torch_arpack_solver_error_message)

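The guard kept above means torch inputs are still rejected for the other solvers. A hedged sketch of the expected failure mode; `_is_torch_namespace` and the error-message attribute are names specific to this branch, and direct torch input to `fit` is assumed to work as in the tests below:

import pytest
import torch
from sklearn.decomposition import PCA

X_torch = torch.randn(30, 4)
with pytest.raises(TypeError):
    PCA(n_components=2, svd_solver="randomized").fit(X_torch)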
sklearn/decomposition/tests/test_pca.py (15 additions, 6 deletions)
@@ -50,13 +50,22 @@ def test_pca_array_torch(device, dtype, svd_solver, n_components):
     X_transformed_torch = pca_torch.fit_transform(X_torch)
     X_transformed_np = pca_np.fit_transform(X_np)

-    assert type(X_transformed_np) == np.ndarray, "Invalid type"
-    assert type(X_transformed_torch) == torch.Tensor, "Invalid type"
-    assert_allclose(X_transformed_np, X_transformed_torch, atol=1e-3)
+    cov_np = pca_np.get_covariance()
+    cov_torch = pca_torch.get_covariance()

-    # TODO introduce pytorch support for below methods
-    # cov = pca.get_covariance()
-    # precision = pca.get_precision()
+    precision_np = pca_np.get_precision()
+    precision_torch = pca_torch.get_precision()
+
+    for name, arr_np, arr_torch in zip(
+        ["X", "cov", "prec"],
+        [X_transformed_np, cov_np, precision_np],
+        [X_transformed_torch, cov_torch, precision_torch],
+    ):
+        assert type(arr_np) == np.ndarray, f"Invalid type for {name}"
+        assert type(arr_torch) == torch.Tensor, f"Invalid type for {name}"
+        assert_allclose(
+            arr_np, arr_torch, atol=1e-3, err_msg=f"Divergent values for {name}"
+        )


 @pytest.mark.parametrize("svd_solver", PCA_SOLVERS)
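These assertions extend the existing torch test rather than adding a new one; in a development checkout with PyTorch installed they should run via `pytest sklearn/decomposition/tests/test_pca.py -k test_pca_array_torch`.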

0 commit comments
