@@ -355,28 +355,32 @@ def sparse_encode(X, dictionary, *, gram=None, cov=None,
     return code
 
 
-def _update_dict(dictionary, Y, code, verbose=False, return_r2=False,
+def _update_dict(dictionary, Y, code, A=None, B=None, verbose=False,
                  random_state=None, positive=False):
     """Update the dense dictionary factor in place.
 
     Parameters
     ----------
-    dictionary : ndarray of shape (n_features, n_components)
+    dictionary : ndarray of shape (n_components, n_features)
         Value of the dictionary at the previous iteration.
 
-    Y : ndarray of shape (n_features, n_samples)
+    Y : ndarray of shape (n_samples, n_features)
         Data matrix.
 
-    code : ndarray of shape (n_components, n_samples)
+    code : ndarray of shape (n_samples, n_components)
         Sparse coding of the data against which to optimize the dictionary.
 
+    A : ndarray of shape (n_components, n_components), default=None
+        Together with `B`, sufficient stats of the online model to update the
+        dictionary.
+
+    B : ndarray of shape (n_features, n_components), default=None
+        Together with `A`, sufficient stats of the online model to update the
+        dictionary.
+
     verbose: bool, default=False
         Degree of output the procedure will print.
 
-    return_r2 : bool, default=False
-        Whether to compute and return the residual sum of squares corresponding
-        to the computed solution.
-
     random_state : int, RandomState instance or None, default=None
         Used for randomly initializing the dictionary. Pass an int for
         reproducible results across multiple function calls.
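Note: `A` and `B` are the running sufficient statistics of the online dictionary-learning model (Mairal et al., 2009): `A` accumulates `code.T @ code` and `B` accumulates `Y.T @ code`. A minimal shape sketch with made-up data, just to illustrate the layout the new signature documents (all names and sizes below are illustrative, not part of the patch):

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features, n_components = 6, 8, 4

Y = rng.normal(size=(n_samples, n_features))        # a data (mini-)batch
code = rng.normal(size=(n_samples, n_components))   # its sparse codes

# Sufficient statistics in the documented shapes
A = code.T @ code   # (n_components, n_components)
B = Y.T @ code      # (n_features, n_components)

assert A.shape == (n_components, n_components)
assert B.shape == (n_features, n_components)
```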
@@ -386,54 +390,41 @@ def _update_dict(dictionary, Y, code, verbose=False, return_r2=False,
         Whether to enforce positivity when finding the dictionary.
 
         .. versionadded:: 0.20
-
-    Returns
-    -------
-    dictionary : ndarray of shape (n_features, n_components)
-        Updated dictionary.
     """
-    n_components = len(code)
-    n_features = Y.shape[0]
+    n_samples, n_components = code.shape
     random_state = check_random_state(random_state)
-    # Get BLAS functions
-    gemm, = linalg.get_blas_funcs(('gemm',), (dictionary, code, Y))
-    ger, = linalg.get_blas_funcs(('ger',), (dictionary, code))
-    nrm2, = linalg.get_blas_funcs(('nrm2',), (dictionary,))
-    # Residuals, computed with BLAS for speed and efficiency
-    # R <- -1.0 * U * V^T + 1.0 * Y
-    # Outputs R as Fortran array for efficiency
-    R = gemm(-1.0, dictionary, code, 1.0, Y)
+
+    if A is None:
+        A = code.T @ code
+    if B is None:
+        B = Y.T @ code
+
+    n_unused = 0
+
     for k in range(n_components):
-        # R <- 1.0 * U_k * V_k^T + R
-        R = ger(1.0, dictionary[:, k], code[k, :], a=R, overwrite_a=True)
-        dictionary[:, k] = np.dot(R, code[k, :])
-        if positive:
-            np.clip(dictionary[:, k], 0, None, out=dictionary[:, k])
-        # Scale k'th atom
-        # (U_k * U_k) ** 0.5
-        atom_norm = nrm2(dictionary[:, k])
-        if atom_norm < 1e-10:
-            if verbose == 1:
-                sys.stdout.write("+")
-                sys.stdout.flush()
-            elif verbose:
-                print("Adding new random atom")
-            dictionary[:, k] = random_state.randn(n_features)
-            if positive:
-                np.clip(dictionary[:, k], 0, None, out=dictionary[:, k])
-            # Setting corresponding coefs to 0
-            code[k, :] = 0.0
-            # (U_k * U_k) ** 0.5
-            atom_norm = nrm2(dictionary[:, k])
-            dictionary[:, k] /= atom_norm
+        if A[k, k] > 1e-6:
+            # 1e-6 is arbitrary but consistent with the spams implementation
+            dictionary[k] += (B[:, k] - A[k] @ dictionary) / A[k, k]
         else:
-            dictionary[:, k] /= atom_norm
-            # R <- -1.0 * U_k * V_k^T + R
-            R = ger(-1.0, dictionary[:, k], code[k, :], a=R, overwrite_a=True)
-    if return_r2:
-        R = nrm2(R) ** 2.0
-        return dictionary, R
-    return dictionary
+            # kth atom is almost never used -> sample a new one from the data
+            newd = Y[random_state.choice(n_samples)]
+
+            # add small noise to avoid making the sparse coding ill conditioned
+            noise_level = 0.01 * (newd.std() or 1)  # avoid 0 std
+            noise = random_state.normal(0, noise_level, size=len(newd))
+
+            dictionary[k] = newd + noise
+            code[:, k] = 0
+            n_unused += 1
+
+        if positive:
+            np.clip(dictionary[k], 0, None, out=dictionary[k])
+
+        # Projection on the constraint set ||V_k|| == 1
+        dictionary[k] /= linalg.norm(dictionary[k])
+
+    if verbose and n_unused > 0:
+        print(f"{n_unused} unused atoms resampled.")
 
 
 @_deprecate_positional_args
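Note: the new per-atom step is the block coordinate descent update of Mairal et al.: with the other atoms fixed, the minimizer of `||Y - code @ dictionary||_F**2` over row `k` is `(B[:, k] - sum_{j != k} A[k, j] * dictionary[j]) / A[k, k]`, and the in-place form `dictionary[k] += (B[:, k] - A[k] @ dictionary) / A[k, k]` reproduces it before the unit-norm projection. A small numerical check on synthetic data (names and sizes are made up for illustration only):

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features, n_components = 20, 7, 5
Y = rng.normal(size=(n_samples, n_features))
code = rng.normal(size=(n_samples, n_components))
D = rng.normal(size=(n_components, n_features))

A = code.T @ code
B = Y.T @ code

k = 2
# In-place style update used in the new loop (before normalization)
d_new = D[k] + (B[:, k] - A[k] @ D) / A[k, k]

# Closed-form optimum for atom k with the other atoms held fixed
residual_k = Y - code @ D + np.outer(code[:, k], D[k])  # data minus the other atoms
d_opt = residual_k.T @ code[:, k] / (code[:, k] @ code[:, k])

assert np.allclose(d_new, d_opt)
```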
@@ -579,10 +570,9 @@ def dict_learning(X, n_components, *, alpha, max_iter=100, tol=1e-8,
         dictionary = np.r_[dictionary,
                            np.zeros((n_components - r, dictionary.shape[1]))]
 
-    # Fortran-order dict, as we are going to access its row vectors
-    dictionary = np.array(dictionary, order='F')
-
-    residuals = 0
+    # Fortran-order dict better suited for the sparse coding which is the
+    # bottleneck of this algorithm.
+    dictionary = np.asfortranarray(dictionary)
 
     errors = []
     current_cost = np.nan
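Note: `np.asfortranarray` keeps the same values but stores them column-major, which, per the new comment, is the layout the sparse-coding solvers prefer. A trivial illustration (array contents are arbitrary):

```python
import numpy as np

D = np.arange(12, dtype=np.float64).reshape(3, 4)  # C-contiguous by default
D_f = np.asfortranarray(D)                         # same values, column-major storage

assert D_f.flags['F_CONTIGUOUS']
assert np.array_equal(D, D_f)
```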
@@ -607,15 +597,14 @@ def dict_learning(X, n_components, *, alpha, max_iter=100, tol=1e-8,
         code = sparse_encode(X, dictionary, algorithm=method, alpha=alpha,
                              init=code, n_jobs=n_jobs, positive=positive_code,
                              max_iter=method_max_iter, verbose=verbose)
-        # Update dictionary
-        dictionary, residuals = _update_dict(dictionary.T, X.T, code.T,
-                                             verbose=verbose, return_r2=True,
-                                             random_state=random_state,
-                                             positive=positive_dict)
-        dictionary = dictionary.T
+
+        # Update dictionary in place
+        _update_dict(dictionary, X, code, verbose=verbose,
+                     random_state=random_state, positive=positive_dict)
 
         # Cost function
-        current_cost = 0.5 * residuals + alpha * np.sum(np.abs(code))
+        current_cost = (0.5 * np.sum((X - code @ dictionary)**2)
+                        + alpha * np.sum(np.abs(code)))
         errors.append(current_cost)
 
         if ii > 0:
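Note: since `_update_dict` no longer returns a residual, the objective is recomputed explicitly; it is the usual dictionary-learning cost `0.5 * ||X - code @ dictionary||_F**2 + alpha * ||code||_1`. A quick equivalence check on random arrays (purely illustrative data):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(10, 6))
code = rng.normal(size=(10, 4))
dictionary = rng.normal(size=(4, 6))
alpha = 1.0

current_cost = (0.5 * np.sum((X - code @ dictionary)**2)
                + alpha * np.sum(np.abs(code)))

# Same quantity written with an explicit Frobenius norm
same_cost = (0.5 * np.linalg.norm(X - code @ dictionary, 'fro')**2
             + alpha * np.abs(code).sum())

assert np.isclose(current_cost, same_cost)
```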
@@ -807,7 +796,9 @@ def dict_learning_online(X, n_components=2, *, alpha=1, n_iter=100,
     else:
         X_train = X
 
-    dictionary = check_array(dictionary.T, order='F', dtype=np.float64,
+    # Fortran-order dict better suited for the sparse coding which is the
+    # bottleneck of this algorithm.
+    dictionary = check_array(dictionary, order='F', dtype=np.float64,
                              copy=False)
     dictionary = np.require(dictionary, requirements='W')
 
@@ -839,11 +830,11 @@ def dict_learning_online(X, n_components=2, *, alpha=1, n_iter=100,
             print("Iteration % 3i (elapsed time: % 3is, % 4.1fmn)"
                   % (ii, dt, dt / 60))
 
-        this_code = sparse_encode(this_X, dictionary.T, algorithm=method,
+        this_code = sparse_encode(this_X, dictionary, algorithm=method,
                                   alpha=alpha, n_jobs=n_jobs,
                                   check_input=False,
                                   positive=positive_code,
-                                  max_iter=method_max_iter, verbose=verbose).T
+                                  max_iter=method_max_iter, verbose=verbose)
 
         # Update the auxiliary variables
         if ii < batch_size - 1:
@@ -853,15 +844,13 @@ def dict_learning_online(X, n_components=2, *, alpha=1, n_iter=100,
         beta = (theta + 1 - batch_size) / (theta + 1)
 
         A *= beta
-        A += np.dot(this_code, this_code.T)
+        A += np.dot(this_code.T, this_code)
         B *= beta
-        B += np.dot(this_X.T, this_code.T)
+        B += np.dot(this_X.T, this_code)
 
-        # Update dictionary
-        dictionary = _update_dict(dictionary, B, A, verbose=verbose,
-                                  random_state=random_state,
-                                  positive=positive_dict)
-        # XXX: Can the residuals be of any use?
+        # Update dictionary in place
+        _update_dict(dictionary, this_X, this_code, A, B, verbose=verbose,
+                     random_state=random_state, positive=positive_dict)
 
         # Maybe we need a stopping criteria based on the amount of
         # modification in the dictionary
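Note: with codes now stored as `(n_samples, n_components)`, the online statistics are accumulated without transposing them, and the full `(A, B, this_X, this_code)` set is passed to `_update_dict`. A sketch of one mini-batch accumulation; the `beta` value below is an arbitrary placeholder, whereas the real forgetting factor is derived from the iteration counter `theta`:

```python
import numpy as np

rng = np.random.default_rng(0)
n_features, n_components, batch_size = 8, 5, 16

A = np.zeros((n_components, n_components))
B = np.zeros((n_features, n_components))

this_X = rng.normal(size=(batch_size, n_features))
this_code = rng.normal(size=(batch_size, n_components))  # new layout: samples x components

beta = 0.9  # placeholder forgetting factor
A = beta * A + this_code.T @ this_code   # (n_components, n_components)
B = beta * B + this_X.T @ this_code      # (n_features, n_components)
```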
@@ -870,30 +859,30 @@ def dict_learning_online(X, n_components=2, *, alpha=1, n_iter=100,
 
     if return_inner_stats:
         if return_n_iter:
-            return dictionary.T, (A, B), ii - iter_offset + 1
+            return dictionary, (A, B), ii - iter_offset + 1
         else:
-            return dictionary.T, (A, B)
+            return dictionary, (A, B)
     if return_code:
         if verbose > 1:
             print('Learning code...', end=' ')
         elif verbose == 1:
             print('|', end=' ')
-        code = sparse_encode(X, dictionary.T, algorithm=method, alpha=alpha,
+        code = sparse_encode(X, dictionary, algorithm=method, alpha=alpha,
                              n_jobs=n_jobs, check_input=False,
                              positive=positive_code, max_iter=method_max_iter,
                              verbose=verbose)
         if verbose > 1:
             dt = (time.time() - t0)
             print('done (total time: % 3is, % 4.1fmn)' % (dt, dt / 60))
         if return_n_iter:
-            return code, dictionary.T, ii - iter_offset + 1
+            return code, dictionary, ii - iter_offset + 1
         else:
-            return code, dictionary.T
+            return code, dictionary
 
     if return_n_iter:
-        return dictionary.T, ii - iter_offset + 1
+        return dictionary, ii - iter_offset + 1
     else:
-        return dictionary.T
+        return dictionary
 
 
 class _BaseSparseCoding(TransformerMixin):
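Note: with the `.T` round-trips removed, `dict_learning_online` keeps and returns the dictionary in the `(n_components, n_features)` layout that `components_` and `sparse_encode` already use. A short usage sketch against the public API (random data, arbitrary parameters):

```python
import numpy as np
from sklearn.decomposition import sparse_encode

rng = np.random.default_rng(0)
X = rng.normal(size=(10, 8))           # (n_samples, n_features)
dictionary = rng.normal(size=(5, 8))   # (n_components, n_features)

code = sparse_encode(X, dictionary, algorithm='lasso_lars', alpha=0.1)
print(code.shape)  # (10, 5), i.e. (n_samples, n_components)
```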
@@ -1286,15 +1275,15 @@ class DictionaryLearning(_BaseSparseCoding, BaseEstimator):
     We can check the level of sparsity of `X_transformed`:
 
     >>> np.mean(X_transformed == 0)
-    0.88...
+    0.87...
 
     We can compare the average squared euclidean norm of the reconstruction
     error of the sparse coded signal relative to the squared euclidean norm of
     the original signal:
 
     >>> X_hat = X_transformed @ dict_learner.components_
     >>> np.mean(np.sum((X_hat - X) ** 2, axis=1) / np.sum(X ** 2, axis=1))
-    0.07...
+    0.08...
 
     Notes
     -----
@@ -1523,15 +1512,15 @@ class MiniBatchDictionaryLearning(_BaseSparseCoding, BaseEstimator):
     We can check the level of sparsity of `X_transformed`:
 
     >>> np.mean(X_transformed == 0)
-    0.87...
+    0.86...
 
     We can compare the average squared euclidean norm of the reconstruction
     error of the sparse coded signal relative to the squared euclidean norm of
     the original signal:
 
     >>> X_hat = X_transformed @ dict_learner.components_
     >>> np.mean(np.sum((X_hat - X) ** 2, axis=1) / np.sum(X ** 2, axis=1))
-    0.10...
+    0.07...
 
     Notes
     -----