Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit fa6ddba

Browse filesBrowse files
authored
TST make sure test_pca_sparse passes on all random seeds (#28861)
1 parent c35a719 commit fa6ddba
Copy full SHA for fa6ddba

File tree

1 file changed

+25
-19
lines changed
Filter options

1 file changed

+25
-19
lines changed

‎sklearn/decomposition/tests/test_pca.py

Copy file name to clipboardExpand all lines: sklearn/decomposition/tests/test_pca.py
+25-19Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,18 @@
3434
SPARSE_MAX_COMPONENTS = min(SPARSE_M, SPARSE_N)
3535

3636

37-
def _check_fitted_pca_close(pca1, pca2, rtol):
38-
assert_allclose(pca1.components_, pca2.components_, rtol=rtol)
39-
assert_allclose(pca1.explained_variance_, pca2.explained_variance_, rtol=rtol)
40-
assert_allclose(pca1.singular_values_, pca2.singular_values_, rtol=rtol)
41-
assert_allclose(pca1.mean_, pca2.mean_, rtol=rtol)
42-
assert_allclose(pca1.n_components_, pca2.n_components_, rtol=rtol)
43-
assert_allclose(pca1.n_samples_, pca2.n_samples_, rtol=rtol)
44-
assert_allclose(pca1.noise_variance_, pca2.noise_variance_, rtol=rtol)
45-
assert_allclose(pca1.n_features_in_, pca2.n_features_in_, rtol=rtol)
37+
def _check_fitted_pca_close(pca1, pca2, rtol=1e-7, atol=1e-12):
38+
assert_allclose(pca1.components_, pca2.components_, rtol=rtol, atol=atol)
39+
assert_allclose(
40+
pca1.explained_variance_, pca2.explained_variance_, rtol=rtol, atol=atol
41+
)
42+
assert_allclose(pca1.singular_values_, pca2.singular_values_, rtol=rtol, atol=atol)
43+
assert_allclose(pca1.mean_, pca2.mean_, rtol=rtol, atol=atol)
44+
assert_allclose(pca1.noise_variance_, pca2.noise_variance_, rtol=rtol, atol=atol)
45+
46+
assert pca1.n_components_ == pca2.n_components_
47+
assert pca1.n_samples_ == pca2.n_samples_
48+
assert pca1.n_features_in_ == pca2.n_features_in_
4649

4750

4851
@pytest.mark.parametrize("svd_solver", PCA_SOLVERS)
@@ -75,9 +78,12 @@ def test_pca(svd_solver, n_components):
7578
def test_pca_sparse(
7679
global_random_seed, svd_solver, sparse_container, n_components, density, scale
7780
):
78-
# Make sure any tolerance changes pass with SKLEARN_TESTS_GLOBAL_RANDOM_SEED="all"
79-
rtol = 5e-07
80-
transform_rtol = 3e-05
81+
"""Check that the results are the same for sparse and dense input."""
82+
83+
# Set atol in addition of the default rtol to account for the very wide range of
84+
# result values (1e-8 to 1e0).
85+
atol = 1e-12
86+
transform_atol = 1e-10
8187

8288
random_state = np.random.default_rng(global_random_seed)
8389
X = sparse_container(
@@ -108,7 +114,7 @@ def test_pca_sparse(
108114
pcad.fit(Xd)
109115

110116
# Fitted attributes equality
111-
_check_fitted_pca_close(pca, pcad, rtol=rtol)
117+
_check_fitted_pca_close(pca, pcad, atol=atol)
112118

113119
# Test transform
114120
X2 = sparse_container(
@@ -121,8 +127,8 @@ def test_pca_sparse(
121127
)
122128
X2d = X2.toarray()
123129

124-
assert_allclose(pca.transform(X2), pca.transform(X2d), rtol=transform_rtol)
125-
assert_allclose(pca.transform(X2), pcad.transform(X2d), rtol=transform_rtol)
130+
assert_allclose(pca.transform(X2), pca.transform(X2d), atol=transform_atol)
131+
assert_allclose(pca.transform(X2), pcad.transform(X2d), atol=transform_atol)
126132

127133

128134
@pytest.mark.parametrize("sparse_container", CSR_CONTAINERS + CSC_CONTAINERS)
@@ -153,10 +159,10 @@ def test_pca_sparse_fit_transform(global_random_seed, sparse_container):
153159
pca_fit.fit(X)
154160
transformed_X = pca_fit_transform.fit_transform(X)
155161

156-
_check_fitted_pca_close(pca_fit, pca_fit_transform, rtol=1e-10)
157-
assert_allclose(transformed_X, pca_fit_transform.transform(X), rtol=2e-9)
158-
assert_allclose(transformed_X, pca_fit.transform(X), rtol=2e-9)
159-
assert_allclose(pca_fit.transform(X2), pca_fit_transform.transform(X2), rtol=2e-9)
162+
_check_fitted_pca_close(pca_fit, pca_fit_transform)
163+
assert_allclose(transformed_X, pca_fit_transform.transform(X))
164+
assert_allclose(transformed_X, pca_fit.transform(X))
165+
assert_allclose(pca_fit.transform(X2), pca_fit_transform.transform(X2))
160166

161167

162168
@pytest.mark.parametrize("svd_solver", ["randomized", "full"])

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.