Commit 174d822

DictionaryLearning: Fix several issues in the dict update (scikit-learn#19198)

jeremiedbb and ogrisel authored and committed
Co-authored-by: Olivier Grisel <olivier.grisel@gmail.com>
1 parent cf94462 commit 174d822
3 files changed (+115, -84 lines)
doc/whats_new/v1.0.rst (+14, -0)
@@ -159,11 +159,25 @@ Changelog
 - |Fix| Fixes incorrect multiple data-conversion warnings when clustering
   boolean data. :pr:`19046` by :user:`Surya Prakash <jdsurya>`.
 
+:mod:`sklearn.decomposition`
+............................
+
 - |Fix| Fixed :func:`dict_learning`, used by :class:`DictionaryLearning`, to
   ensure determinism of the output. Achieved by flipping signs of the SVD
   output which is used to initialize the code.
   :pr:`18433` by :user:`Bruno Charron <brcharron>`.
 
+- |Fix| Fixed a bug in :class:`MiniBatchDictionaryLearning`,
+  :class:`MiniBatchSparsePCA` and :func:`dict_learning_online` where the
+  update of the dictionary was incorrect. :pr:`19198` by
+  :user:`Jérémie du Boisberranger <jeremiedbb>`.
+
+- |Fix| Fixed a bug in :class:`DictionaryLearning`, :class:`SparsePCA`,
+  :class:`MiniBatchDictionaryLearning`, :class:`MiniBatchSparsePCA`,
+  :func:`dict_learning` and :func:`dict_learning_online` where the restart of
+  unused atoms during the dictionary update was not working as expected.
+  :pr:`19198` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
+
 :mod:`sklearn.ensemble`
 .......................
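Editor's note for context (not part of the patch): the corrected step is the block coordinate descent update of Mairal et al. (2009). With data X in R^{n x m}, codes C in R^{n x p}, dictionary D in R^{p x m} (atoms as rows) and sufficient statistics A = C^T C, B = X^T C, each atom d_k is updated and renormalized as

    d_k \leftarrow d_k + \frac{B_{:,k} - A_{k,:}\,D}{A_{kk}}, \qquad
    d_k \leftarrow \frac{d_k}{\lVert d_k \rVert_2}

and an atom whose A_{kk} is close to zero (almost never used) is restarted from a randomly drawn data row plus small noise, as the diff below implements.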

sklearn/decomposition/_dict_learning.py (+73, -84)
@@ -355,28 +355,32 @@ def sparse_encode(X, dictionary, *, gram=None, cov=None,
     return code
 
 
-def _update_dict(dictionary, Y, code, verbose=False, return_r2=False,
+def _update_dict(dictionary, Y, code, A=None, B=None, verbose=False,
                  random_state=None, positive=False):
     """Update the dense dictionary factor in place.
 
     Parameters
     ----------
-    dictionary : ndarray of shape (n_features, n_components)
+    dictionary : ndarray of shape (n_components, n_features)
         Value of the dictionary at the previous iteration.
 
-    Y : ndarray of shape (n_features, n_samples)
+    Y : ndarray of shape (n_samples, n_features)
         Data matrix.
 
-    code : ndarray of shape (n_components, n_samples)
+    code : ndarray of shape (n_samples, n_components)
        Sparse coding of the data against which to optimize the dictionary.
 
+    A : ndarray of shape (n_components, n_components), default=None
+        Together with `B`, sufficient stats of the online model to update the
+        dictionary.
+
+    B : ndarray of shape (n_features, n_components), default=None
+        Together with `A`, sufficient stats of the online model to update the
+        dictionary.
+
     verbose: bool, default=False
         Degree of output the procedure will print.
 
-    return_r2 : bool, default=False
-        Whether to compute and return the residual sum of squares corresponding
-        to the computed solution.
-
     random_state : int, RandomState instance or None, default=None
         Used for randomly initializing the dictionary. Pass an int for
         reproducible results across multiple function calls.
@@ -386,54 +390,41 @@ def _update_dict(dictionary, Y, code, verbose=False, return_r2=False,
         Whether to enforce positivity when finding the dictionary.
 
         .. versionadded:: 0.20
-
-    Returns
-    -------
-    dictionary : ndarray of shape (n_features, n_components)
-        Updated dictionary.
     """
-    n_components = len(code)
-    n_features = Y.shape[0]
+    n_samples, n_components = code.shape
     random_state = check_random_state(random_state)
-    # Get BLAS functions
-    gemm, = linalg.get_blas_funcs(('gemm',), (dictionary, code, Y))
-    ger, = linalg.get_blas_funcs(('ger',), (dictionary, code))
-    nrm2, = linalg.get_blas_funcs(('nrm2',), (dictionary,))
-    # Residuals, computed with BLAS for speed and efficiency
-    # R <- -1.0 * U * V^T + 1.0 * Y
-    # Outputs R as Fortran array for efficiency
-    R = gemm(-1.0, dictionary, code, 1.0, Y)
+
+    if A is None:
+        A = code.T @ code
+    if B is None:
+        B = Y.T @ code
+
+    n_unused = 0
+
     for k in range(n_components):
-        # R <- 1.0 * U_k * V_k^T + R
-        R = ger(1.0, dictionary[:, k], code[k, :], a=R, overwrite_a=True)
-        dictionary[:, k] = np.dot(R, code[k, :])
-        if positive:
-            np.clip(dictionary[:, k], 0, None, out=dictionary[:, k])
-        # Scale k'th atom
-        # (U_k * U_k) ** 0.5
-        atom_norm = nrm2(dictionary[:, k])
-        if atom_norm < 1e-10:
-            if verbose == 1:
-                sys.stdout.write("+")
-                sys.stdout.flush()
-            elif verbose:
-                print("Adding new random atom")
-            dictionary[:, k] = random_state.randn(n_features)
-            if positive:
-                np.clip(dictionary[:, k], 0, None, out=dictionary[:, k])
-            # Setting corresponding coefs to 0
-            code[k, :] = 0.0
-            # (U_k * U_k) ** 0.5
-            atom_norm = nrm2(dictionary[:, k])
-            dictionary[:, k] /= atom_norm
+        if A[k, k] > 1e-6:
+            # 1e-6 is arbitrary but consistent with the spams implementation
+            dictionary[k] += (B[:, k] - A[k] @ dictionary) / A[k, k]
         else:
-            dictionary[:, k] /= atom_norm
-            # R <- -1.0 * U_k * V_k^T + R
-            R = ger(-1.0, dictionary[:, k], code[k, :], a=R, overwrite_a=True)
-    if return_r2:
-        R = nrm2(R) ** 2.0
-        return dictionary, R
-    return dictionary
+            # kth atom is almost never used -> sample a new one from the data
+            newd = Y[random_state.choice(n_samples)]
+
+            # add small noise to avoid making the sparse coding ill conditioned
+            noise_level = 0.01 * (newd.std() or 1)  # avoid 0 std
+            noise = random_state.normal(0, noise_level, size=len(newd))
+
+            dictionary[k] = newd + noise
+            code[:, k] = 0
+            n_unused += 1
+
+        if positive:
+            np.clip(dictionary[k], 0, None, out=dictionary[k])
+
+        # Projection on the constraint set ||V_k|| == 1
+        dictionary[k] /= linalg.norm(dictionary[k])
+
+    if verbose and n_unused > 0:
+        print(f"{n_unused} unused atoms resampled.")
 
 
 @_deprecate_positional_args
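For readers skimming the diff, a minimal self-contained NumPy sketch of the new update loop; the helper name is hypothetical, verbosity and positivity handling are simplified, and the real implementation is the `_update_dict` shown above:

    import numpy as np

    def update_dict_sketch(D, X, C, eps=1e-6, seed=0):
        # D: (n_components, n_features) dictionary, updated in place.
        # X: (n_samples, n_features) data; C: (n_samples, n_components) codes.
        rng = np.random.RandomState(seed)
        A = C.T @ C  # (n_components, n_components) sufficient statistic
        B = X.T @ C  # (n_features, n_components) sufficient statistic
        for k in range(D.shape[0]):
            if A[k, k] > eps:
                # block coordinate descent step on atom k
                D[k] += (B[:, k] - A[k] @ D) / A[k, k]
            else:
                # atom almost never used: restart from a data row plus noise
                new_atom = X[rng.choice(len(X))]
                noise_level = 0.01 * (new_atom.std() or 1)  # avoid 0 std
                D[k] = new_atom + rng.normal(0, noise_level, size=len(new_atom))
                C[:, k] = 0
            D[k] /= np.linalg.norm(D[k])  # project onto the unit sphere
        return D

With the batch statistics A = C.T @ C and B = X.T @ C this coincides with the in-place update performed by the patched function, which is exactly what the new test at the bottom of this commit checks.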
@@ -579,10 +570,9 @@ def dict_learning(X, n_components, *, alpha, max_iter=100, tol=1e-8,
         dictionary = np.r_[dictionary,
                            np.zeros((n_components - r, dictionary.shape[1]))]
 
-    # Fortran-order dict, as we are going to access its row vectors
-    dictionary = np.array(dictionary, order='F')
-
-    residuals = 0
+    # Fortran-order dict better suited for the sparse coding which is the
+    # bottleneck of this algorithm.
+    dictionary = np.asfortranarray(dictionary)
 
     errors = []
     current_cost = np.nan
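Side note on the replaced call (editorial): np.asfortranarray returns a column-major copy, or the input itself when it is already Fortran-contiguous, which suits the column-wise access pattern of the sparse-coding step. A tiny illustration with assumed shapes:

    import numpy as np

    D = np.zeros((15, 20))      # C-ordered by default
    D_f = np.asfortranarray(D)  # column-major copy (no-op if already F-ordered)
    print(D.flags['F_CONTIGUOUS'], D_f.flags['F_CONTIGUOUS'])  # False True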
@@ -607,15 +597,14 @@ def dict_learning(X, n_components, *, alpha, max_iter=100, tol=1e-8,
         code = sparse_encode(X, dictionary, algorithm=method, alpha=alpha,
                              init=code, n_jobs=n_jobs, positive=positive_code,
                              max_iter=method_max_iter, verbose=verbose)
-        # Update dictionary
-        dictionary, residuals = _update_dict(dictionary.T, X.T, code.T,
-                                             verbose=verbose, return_r2=True,
-                                             random_state=random_state,
-                                             positive=positive_dict)
-        dictionary = dictionary.T
+
+        # Update dictionary in place
+        _update_dict(dictionary, X, code, verbose=verbose,
+                     random_state=random_state, positive=positive_dict)
 
         # Cost function
-        current_cost = 0.5 * residuals + alpha * np.sum(np.abs(code))
+        current_cost = (0.5 * np.sum((X - code @ dictionary)**2)
+                        + alpha * np.sum(np.abs(code)))
         errors.append(current_cost)
 
         if ii > 0:
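Since `_update_dict` no longer returns residuals, the cost is now computed directly from its definition. A hedged helper showing the tracked quantity (illustrative only, not a library function):

    import numpy as np

    def lasso_dict_cost(X, code, dictionary, alpha):
        # 0.5 * ||X - code @ dictionary||_F^2 + alpha * ||code||_1
        residual = X - code @ dictionary
        return 0.5 * np.sum(residual ** 2) + alpha * np.abs(code).sum()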
@@ -807,7 +796,9 @@ def dict_learning_online(X, n_components=2, *, alpha=1, n_iter=100,
     else:
         X_train = X
 
-    dictionary = check_array(dictionary.T, order='F', dtype=np.float64,
+    # Fortran-order dict better suited for the sparse coding which is the
+    # bottleneck of this algorithm.
+    dictionary = check_array(dictionary, order='F', dtype=np.float64,
                              copy=False)
     dictionary = np.require(dictionary, requirements='W')
 
@@ -839,11 +830,11 @@ def dict_learning_online(X, n_components=2, *, alpha=1, n_iter=100,
             print("Iteration % 3i (elapsed time: % 3is, % 4.1fmn)"
                   % (ii, dt, dt / 60))
 
-        this_code = sparse_encode(this_X, dictionary.T, algorithm=method,
+        this_code = sparse_encode(this_X, dictionary, algorithm=method,
                                   alpha=alpha, n_jobs=n_jobs,
                                   check_input=False,
                                   positive=positive_code,
-                                  max_iter=method_max_iter, verbose=verbose).T
+                                  max_iter=method_max_iter, verbose=verbose)
 
         # Update the auxiliary variables
         if ii < batch_size - 1:
@@ -853,15 +844,13 @@ def dict_learning_online(X, n_components=2, *, alpha=1, n_iter=100,
             beta = (theta + 1 - batch_size) / (theta + 1)
 
         A *= beta
-        A += np.dot(this_code, this_code.T)
+        A += np.dot(this_code.T, this_code)
         B *= beta
-        B += np.dot(this_X.T, this_code.T)
+        B += np.dot(this_X.T, this_code)
 
-        # Update dictionary
-        dictionary = _update_dict(dictionary, B, A, verbose=verbose,
-                                  random_state=random_state,
-                                  positive=positive_dict)
-        # XXX: Can the residuals be of any use?
+        # Update dictionary in place
+        _update_dict(dictionary, this_X, this_code, A, B, verbose=verbose,
+                     random_state=random_state, positive=positive_dict)
 
         # Maybe we need a stopping criteria based on the amount of
         # modification in the dictionary
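The corrected accumulation keeps A and B consistent with the new (n_samples, n_components) code layout. A small self-contained sketch, with assumed shapes and an assumed constant forgetting factor beta (in the real code beta follows the theta schedule above):

    import numpy as np

    rng = np.random.RandomState(0)
    n_samples, n_features, n_components = 3, 8, 5
    beta = 0.9  # assumed forgetting factor

    A = np.zeros((n_components, n_components))
    B = np.zeros((n_features, n_components))

    for _ in range(10):  # stream of minibatches
        this_X = rng.randn(n_samples, n_features)
        this_code = rng.randn(n_samples, n_components)  # stand-in for sparse_encode
        A = beta * A + this_code.T @ this_code
        B = beta * B + this_X.T @ this_code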
@@ -870,30 +859,30 @@ def dict_learning_online(X, n_components=2, *, alpha=1, n_iter=100,
 
     if return_inner_stats:
         if return_n_iter:
-            return dictionary.T, (A, B), ii - iter_offset + 1
+            return dictionary, (A, B), ii - iter_offset + 1
         else:
-            return dictionary.T, (A, B)
+            return dictionary, (A, B)
     if return_code:
         if verbose > 1:
             print('Learning code...', end=' ')
         elif verbose == 1:
             print('|', end=' ')
-        code = sparse_encode(X, dictionary.T, algorithm=method, alpha=alpha,
+        code = sparse_encode(X, dictionary, algorithm=method, alpha=alpha,
                              n_jobs=n_jobs, check_input=False,
                              positive=positive_code, max_iter=method_max_iter,
                              verbose=verbose)
         if verbose > 1:
             dt = (time.time() - t0)
             print('done (total time: % 3is, % 4.1fmn)' % (dt, dt / 60))
         if return_n_iter:
-            return code, dictionary.T, ii - iter_offset + 1
+            return code, dictionary, ii - iter_offset + 1
         else:
-            return code, dictionary.T
+            return code, dictionary
 
     if return_n_iter:
-        return dictionary.T, ii - iter_offset + 1
+        return dictionary, ii - iter_offset + 1
     else:
-        return dictionary.T
+        return dictionary
 
 
 class _BaseSparseCoding(TransformerMixin):
@@ -1286,15 +1275,15 @@ class DictionaryLearning(_BaseSparseCoding, BaseEstimator):
     We can check the level of sparsity of `X_transformed`:
 
     >>> np.mean(X_transformed == 0)
-    0.88...
+    0.87...
 
     We can compare the average squared euclidean norm of the reconstruction
     error of the sparse coded signal relative to the squared euclidean norm of
     the original signal:
 
     >>> X_hat = X_transformed @ dict_learner.components_
     >>> np.mean(np.sum((X_hat - X) ** 2, axis=1) / np.sum(X ** 2, axis=1))
-    0.07...
+    0.08...
 
     Notes
     -----
@@ -1523,15 +1512,15 @@ class MiniBatchDictionaryLearning(_BaseSparseCoding, BaseEstimator):
     We can check the level of sparsity of `X_transformed`:
 
     >>> np.mean(X_transformed == 0)
-    0.87...
+    0.86...
 
     We can compare the average squared euclidean norm of the reconstruction
     error of the sparse coded signal relative to the squared euclidean norm of
     the original signal:
 
     >>> X_hat = X_transformed @ dict_learner.components_
     >>> np.mean(np.sum((X_hat - X) ** 2, axis=1) / np.sum(X ** 2, axis=1))
-    0.10...
+    0.07...
 
     Notes
     -----
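The doctest values above shift because the dictionary update changed; they are version-dependent. A hedged, runnable approximation of the docstring example, using random toy data instead of make_sparse_coded_signal, so the printed numbers will differ from the doctest:

    import numpy as np
    from sklearn.decomposition import DictionaryLearning

    rng = np.random.RandomState(42)
    X = rng.randn(100, 20)  # toy data: 100 samples, 20 features

    dict_learner = DictionaryLearning(n_components=15,
                                      transform_algorithm='lasso_lars',
                                      random_state=42)
    X_transformed = dict_learner.fit_transform(X)

    print(np.mean(X_transformed == 0))  # sparsity of the codes
    X_hat = X_transformed @ dict_learner.components_
    rel_err = np.mean(np.sum((X_hat - X) ** 2, axis=1) / np.sum(X ** 2, axis=1))
    print(rel_err)                      # relative reconstruction error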

sklearn/decomposition/tests/test_dict_learning.py (+28, -0)
@@ -10,6 +10,7 @@
 
 from sklearn.utils import check_array
 
+from sklearn.utils._testing import assert_allclose
 from sklearn.utils._testing import assert_array_almost_equal
 from sklearn.utils._testing import assert_array_equal
 from sklearn.utils._testing import ignore_warnings
@@ -25,6 +26,8 @@
 from sklearn.utils.estimator_checks import check_transformer_general
 from sklearn.utils.estimator_checks import check_transformers_unfitted
 
+from sklearn.decomposition._dict_learning import _update_dict
+
 
 rng_global = np.random.RandomState(0)
 n_samples, n_features = 10, 8
@@ -575,6 +578,31 @@ def test_sparse_coder_n_features_in():
     assert sc.n_features_in_ == d.shape[1]
 
 
+def test_update_dict():
+    # Check the dict update in batch mode vs online mode
+    # Non-regression test for #4866
+    rng = np.random.RandomState(0)
+
+    code = np.array([[0.5, -0.5],
+                     [0.1, 0.9]])
+    dictionary = np.array([[1., 0.],
+                           [0.6, 0.8]])
+
+    X = np.dot(code, dictionary) + rng.randn(2, 2)
+
+    # full batch update
+    newd_batch = dictionary.copy()
+    _update_dict(newd_batch, X, code)
+
+    # online update
+    A = np.dot(code.T, code)
+    B = np.dot(X.T, code)
+    newd_online = dictionary.copy()
+    _update_dict(newd_online, X, code, A, B)
+
+    assert_allclose(newd_batch, newd_online)
+
+
 @pytest.mark.parametrize("Estimator", [DictionaryLearning,
                                        MiniBatchDictionaryLearning])
 def test_warning_default_transform_alpha(Estimator):
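To run just the new test from a source checkout, the standard pytest node-id syntax should work:

    pytest sklearn/decomposition/tests/test_dict_learning.py::test_update_dict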
