Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 4fd851c

Browse filesBrowse files
authored
FIX Ensure determinism of SVD init in dict_learning (#18433)
1 parent 50d3aaa commit 4fd851c
Copy full SHA for 4fd851c

File tree

2 files changed

+9
-2
lines changed
Filter options

2 files changed

+9
-2
lines changed

‎doc/whats_new/v1.0.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.0.rst
+6-1Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,18 @@ Changelog
6565
- |API| In :class:`decomposition.DictionaryLearning`,
6666
:class:`decomposition.MiniBatchDictionaryLearning`,
6767
:func:`dict_learning` and :func:`dict_learning_online`,
68-
`transform_alpha` will be equal to `alpha` instead of 1.0 by default
68+
`transform_alpha` will be equal to `alpha` instead of 1.0 by default
6969
starting from version 1.2
7070
:pr:`19159` by :user:`Benoît Malézieux <bmalezieux>`.
7171

7272
- |Fix| Fixes incorrect multiple data-conversion warnings when clustering
7373
boolean data. :pr:`19046` by :user:`Surya Prakash <jdsurya>`.
7474

75+
- |Fix| Fixed :func:`dict_learning`, used by :class:`DictionaryLearning`, to
76+
ensure determinism of the output. Achieved by flipping signs of the SVD
77+
output which is used to initialize the code.
78+
:pr:`18433` by :user:`Bruno Charron <brcharron>`.
79+
7580
:mod:`sklearn.ensemble`
7681
.......................
7782

‎sklearn/decomposition/_dict_learning.py

Copy file name to clipboardExpand all lines: sklearn/decomposition/_dict_learning.py
+3-1Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from ..utils import deprecated
1919
from ..utils import (check_array, check_random_state, gen_even_slices,
2020
gen_batches)
21-
from ..utils.extmath import randomized_svd, row_norms
21+
from ..utils.extmath import randomized_svd, row_norms, svd_flip
2222
from ..utils.validation import check_is_fitted, _deprecate_positional_args
2323
from ..utils.fixes import delayed
2424
from ..linear_model import Lasso, orthogonal_mp_gram, LassoLars, Lars
@@ -567,6 +567,8 @@ def dict_learning(X, n_components, *, alpha, max_iter=100, tol=1e-8,
567567
dictionary = dict_init
568568
else:
569569
code, S, dictionary = linalg.svd(X, full_matrices=False)
570+
# flip the initial code's sign to enforce deterministic output
571+
code, dictionary = svd_flip(code, dictionary)
570572
dictionary = S[:, np.newaxis] * dictionary
571573
r = len(dictionary)
572574
if n_components <= r: # True even if n_components=None

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.