Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 52e89a4

Browse filesBrowse files
authored
MAINT Clean deprecations in MiniBatchDictionaryLearning (#25357)
1 parent f5ae73d commit 52e89a4
Copy full SHA for 52e89a4

File tree

Expand file treeCollapse file tree

2 files changed

+36
-162
lines changed
Filter options
Expand file treeCollapse file tree

2 files changed

+36
-162
lines changed

‎sklearn/decomposition/_dict_learning.py

Copy file name to clipboardExpand all lines: sklearn/decomposition/_dict_learning.py
+30-116Lines changed: 30 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
from ..base import BaseEstimator, TransformerMixin, ClassNamePrefixFeaturesOutMixin
1919
from ..utils import check_array, check_random_state, gen_even_slices, gen_batches
20-
from ..utils import deprecated
2120
from ..utils._param_validation import Hidden, Interval, StrOptions
2221
from ..utils._param_validation import validate_params
2322
from ..utils.extmath import randomized_svd, row_norms, svd_flip
@@ -627,7 +626,7 @@ def _dict_learning(
627626
def _check_warn_deprecated(param, name, default, additional_message=None):
628627
if param != "deprecated":
629628
msg = (
630-
f"'{name}' is deprecated in version 1.1 and will be removed in version 1.3."
629+
f"'{name}' is deprecated in version 1.1 and will be removed in version 1.4."
631630
)
632631
if additional_message:
633632
msg += f" {additional_message}"
@@ -647,7 +646,7 @@ def dict_learning_online(
647646
return_code=True,
648647
dict_init=None,
649648
callback=None,
650-
batch_size="warn",
649+
batch_size=256,
651650
verbose=False,
652651
shuffle=True,
653652
n_jobs=None,
@@ -717,11 +716,11 @@ def dict_learning_online(
717716
callback : callable, default=None
718717
A callable that gets invoked at the end of each iteration.
719718
720-
batch_size : int, default=3
719+
batch_size : int, default=256
721720
The number of samples to take in each batch.
722721
723722
.. versionchanged:: 1.3
724-
The default value of `batch_size` will change from 3 to 256 in version 1.3.
723+
The default value of `batch_size` changed from 3 to 256 in version 1.3.
725724
726725
verbose : bool, default=False
727726
To control the verbosity of the procedure.
@@ -747,7 +746,7 @@ def dict_learning_online(
747746
initialization.
748747
749748
.. deprecated:: 1.1
750-
`iter_offset` serves internal purpose only and will be removed in 1.3.
749+
`iter_offset` serves internal purpose only and will be removed in 1.4.
751750
752751
random_state : int, RandomState instance or None, default=None
753752
Used for initializing the dictionary when ``dict_init`` is not
@@ -763,7 +762,7 @@ def dict_learning_online(
763762
ignored.
764763
765764
.. deprecated:: 1.1
766-
`return_inner_stats` serves internal purpose only and will be removed in 1.3.
765+
`return_inner_stats` serves internal purpose only and will be removed in 1.4.
767766
768767
inner_stats : tuple of (A, B) ndarrays, default=None
769768
Inner sufficient statistics that are kept by the algorithm.
@@ -773,13 +772,13 @@ def dict_learning_online(
773772
`B` `(n_features, n_components)` is the data approximation matrix.
774773
775774
.. deprecated:: 1.1
776-
`inner_stats` serves internal purpose only and will be removed in 1.3.
775+
`inner_stats` serves internal purpose only and will be removed in 1.4.
777776
778777
return_n_iter : bool, default=False
779778
Whether or not to return the number of iterations.
780779
781780
.. deprecated:: 1.1
782-
`return_n_iter` will be removed in 1.3 and n_iter will always be returned.
781+
`return_n_iter` will be removed in 1.4 and n_iter will never be returned.
783782
784783
positive_dict : bool, default=False
785784
Whether to enforce positivity when finding the dictionary.
@@ -848,15 +847,15 @@ def dict_learning_online(
848847
return_inner_stats,
849848
"return_inner_stats",
850849
default=False,
851-
additional_message="From 1.3 inner_stats will never be returned.",
850+
additional_message="From 1.4 inner_stats will never be returned.",
852851
)
853852
inner_stats = _check_warn_deprecated(inner_stats, "inner_stats", default=None)
854853
return_n_iter = _check_warn_deprecated(
855854
return_n_iter,
856855
"return_n_iter",
857856
default=False,
858857
additional_message=(
859-
"From 1.3 'n_iter' will never be returned. Refer to the 'n_iter_' and "
858+
"From 1.4 'n_iter' will never be returned. Refer to the 'n_iter_' and "
860859
"'n_steps_' attributes of the MiniBatchDictionaryLearning object instead."
861860
),
862861
)
@@ -891,20 +890,13 @@ def dict_learning_online(
891890
code = est.transform(X)
892891
return code, est.components_
893892

894-
# TODO remove the whole old behavior in 1.3
893+
# TODO(1.4) remove the whole old behavior
895894
# Fallback to old behavior
896895

897896
n_iter = _check_warn_deprecated(
898897
n_iter, "n_iter", default=100, additional_message="Use 'max_iter' instead."
899898
)
900899

901-
if batch_size == "warn":
902-
warnings.warn(
903-
"The default value of batch_size will change from 3 to 256 in 1.3.",
904-
FutureWarning,
905-
)
906-
batch_size = 3
907-
908900
if n_components is None:
909901
n_components = X.shape[1]
910902

@@ -1903,11 +1895,11 @@ class MiniBatchDictionaryLearning(_BaseSparseCoding, BaseEstimator):
19031895
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
19041896
for more details.
19051897
1906-
batch_size : int, default=3
1898+
batch_size : int, default=256
19071899
Number of samples in each mini-batch.
19081900
19091901
.. versionchanged:: 1.3
1910-
The default value of `batch_size` will change from 3 to 256 in version 1.3.
1902+
The default value of `batch_size` changed from 3 to 256 in version 1.3.
19111903
19121904
shuffle : bool, default=True
19131905
Whether to shuffle the samples before forming batches.
@@ -2006,17 +1998,6 @@ class MiniBatchDictionaryLearning(_BaseSparseCoding, BaseEstimator):
20061998
components_ : ndarray of shape (n_components, n_features)
20071999
Components extracted from the data.
20082000
2009-
inner_stats_ : tuple of (A, B) ndarrays
2010-
Internal sufficient statistics that are kept by the algorithm.
2011-
Keeping them is useful in online settings, to avoid losing the
2012-
history of the evolution, but they shouldn't have any use for the
2013-
end user.
2014-
`A` `(n_components, n_components)` is the dictionary covariance matrix.
2015-
`B` `(n_features, n_components)` is the data approximation matrix.
2016-
2017-
.. deprecated:: 1.1
2018-
`inner_stats_` serves internal purpose only and will be removed in 1.3.
2019-
20202001
n_features_in_ : int
20212002
Number of features seen during :term:`fit`.
20222003
@@ -2031,19 +2012,6 @@ class MiniBatchDictionaryLearning(_BaseSparseCoding, BaseEstimator):
20312012
n_iter_ : int
20322013
Number of iterations over the full dataset.
20332014
2034-
iter_offset_ : int
2035-
The number of iteration on data batches that has been performed before.
2036-
2037-
.. deprecated:: 1.1
2038-
`iter_offset_` has been renamed `n_steps_` and will be removed in 1.3.
2039-
2040-
random_state_ : RandomState instance
2041-
RandomState instance that is generated either from a seed, the random
2042-
number generattor or by `np.random`.
2043-
2044-
.. deprecated:: 1.1
2045-
`random_state_` serves internal purpose only and will be removed in 1.3.
2046-
20472015
n_steps_ : int
20482016
Number of mini-batches processed.
20492017
@@ -2100,10 +2068,7 @@ class MiniBatchDictionaryLearning(_BaseSparseCoding, BaseEstimator):
21002068
"max_iter": [Interval(Integral, 0, None, closed="left"), None],
21012069
"fit_algorithm": [StrOptions({"cd", "lars"})],
21022070
"n_jobs": [None, Integral],
2103-
"batch_size": [
2104-
Interval(Integral, 1, None, closed="left"),
2105-
Hidden(StrOptions({"warn"})),
2106-
],
2071+
"batch_size": [Interval(Integral, 1, None, closed="left")],
21072072
"shuffle": ["boolean"],
21082073
"dict_init": [None, np.ndarray],
21092074
"transform_algorithm": [
@@ -2131,7 +2096,7 @@ def __init__(
21312096
max_iter=None,
21322097
fit_algorithm="lars",
21332098
n_jobs=None,
2134-
batch_size="warn",
2099+
batch_size=256,
21352100
shuffle=True,
21362101
dict_init=None,
21372102
transform_algorithm="omp",
@@ -2173,27 +2138,6 @@ def __init__(
21732138
self.max_no_improvement = max_no_improvement
21742139
self.tol = tol
21752140

2176-
@deprecated( # type: ignore
2177-
"The attribute `iter_offset_` is deprecated in 1.1 and will be removed in 1.3."
2178-
)
2179-
@property
2180-
def iter_offset_(self):
2181-
return self.n_iter_
2182-
2183-
@deprecated( # type: ignore
2184-
"The attribute `random_state_` is deprecated in 1.1 and will be removed in 1.3."
2185-
)
2186-
@property
2187-
def random_state_(self):
2188-
return self._random_state
2189-
2190-
@deprecated( # type: ignore
2191-
"The attribute `inner_stats_` is deprecated in 1.1 and will be removed in 1.3."
2192-
)
2193-
@property
2194-
def inner_stats_(self):
2195-
return self._inner_stats
2196-
21972141
def _check_params(self, X):
21982142
# n_components
21992143
self._n_components = self.n_components
@@ -2205,8 +2149,7 @@ def _check_params(self, X):
22052149
self._fit_algorithm = "lasso_" + self.fit_algorithm
22062150

22072151
# batch_size
2208-
if hasattr(self, "_batch_size"):
2209-
self._batch_size = min(self._batch_size, X.shape[0])
2152+
self._batch_size = min(self.batch_size, X.shape[0])
22102153

22112154
def _initialize_dict(self, X, random_state):
22122155
"""Initialization of the dictionary."""
@@ -2245,11 +2188,10 @@ def _update_inner_stats(self, X, code, batch_size, step):
22452188
theta = batch_size**2 + step + 1 - batch_size
22462189
beta = (theta + 1 - batch_size) / (theta + 1)
22472190

2248-
A, B = self._inner_stats
2249-
A *= beta
2250-
A += code.T @ code
2251-
B *= beta
2252-
B += X.T @ code
2191+
self._A *= beta
2192+
self._A += code.T @ code
2193+
self._B *= beta
2194+
self._B += X.T @ code
22532195

22542196
def _minibatch_step(self, X, dictionary, random_state, step):
22552197
"""Perform the update on the dictionary for one minibatch."""
@@ -2277,13 +2219,12 @@ def _minibatch_step(self, X, dictionary, random_state, step):
22772219
self._update_inner_stats(X, code, batch_size, step)
22782220

22792221
# Update dictionary
2280-
A, B = self._inner_stats
22812222
_update_dict(
22822223
dictionary,
22832224
X,
22842225
code,
2285-
A,
2286-
B,
2226+
self._A,
2227+
self._B,
22872228
verbose=self.verbose,
22882229
random_state=random_state,
22892230
positive=self.positive_dict,
@@ -2378,14 +2319,6 @@ def fit(self, X, y=None):
23782319
"""
23792320
self._validate_params()
23802321

2381-
self._batch_size = self.batch_size
2382-
if self.batch_size == "warn":
2383-
warnings.warn(
2384-
"The default value of batch_size will change from 3 to 256 in 1.3.",
2385-
FutureWarning,
2386-
)
2387-
self._batch_size = 3
2388-
23892322
X = self._validate_data(
23902323
X, dtype=[np.float64, np.float32], order="C", copy=False
23912324
)
@@ -2419,10 +2352,10 @@ def fit(self, X, y=None):
24192352
print("[dict_learning]")
24202353

24212354
# Inner stats
2422-
self._inner_stats = (
2423-
np.zeros((self._n_components, self._n_components), dtype=X_train.dtype),
2424-
np.zeros((n_features, self._n_components), dtype=X_train.dtype),
2355+
self._A = np.zeros(
2356+
(self._n_components, self._n_components), dtype=X_train.dtype
24252357
)
2358+
self._B = np.zeros((n_features, self._n_components), dtype=X_train.dtype)
24262359

24272360
if self.max_iter is not None:
24282361

@@ -2483,7 +2416,7 @@ def fit(self, X, y=None):
24832416

24842417
return self
24852418

2486-
def partial_fit(self, X, y=None, iter_offset="deprecated"):
2419+
def partial_fit(self, X, y=None):
24872420
"""Update the model using the data in X as a mini-batch.
24882421
24892422
Parameters
@@ -2495,15 +2428,6 @@ def partial_fit(self, X, y=None, iter_offset="deprecated"):
24952428
y : Ignored
24962429
Not used, present for API consistency by convention.
24972430
2498-
iter_offset : int, default=None
2499-
The number of iteration on data batches that has been
2500-
performed before this call to `partial_fit`. This is optional:
2501-
if no number is passed, the memory of the object is
2502-
used.
2503-
2504-
.. deprecated:: 1.1
2505-
``iter_offset`` will be removed in 1.3.
2506-
25072431
Returns
25082432
-------
25092433
self : object
@@ -2518,27 +2442,17 @@ def partial_fit(self, X, y=None, iter_offset="deprecated"):
25182442
X, dtype=[np.float64, np.float32], order="C", reset=not has_components
25192443
)
25202444

2521-
if iter_offset != "deprecated":
2522-
warnings.warn(
2523-
"'iter_offset' is deprecated in version 1.1 and "
2524-
"will be removed in version 1.3",
2525-
FutureWarning,
2526-
)
2527-
self.n_steps_ = iter_offset
2528-
else:
2529-
self.n_steps_ = getattr(self, "n_steps_", 0)
2530-
25312445
if not has_components:
25322446
# This instance has not been fitted yet (fit or partial_fit)
25332447
self._check_params(X)
25342448
self._random_state = check_random_state(self.random_state)
25352449

25362450
dictionary = self._initialize_dict(X, self._random_state)
25372451

2538-
self._inner_stats = (
2539-
np.zeros((self._n_components, self._n_components), dtype=X.dtype),
2540-
np.zeros((X.shape[1], self._n_components), dtype=X.dtype),
2541-
)
2452+
self.n_steps_ = 0
2453+
2454+
self._A = np.zeros((self._n_components, self._n_components), dtype=X.dtype)
2455+
self._B = np.zeros((X.shape[1], self._n_components), dtype=X.dtype)
25422456
else:
25432457
dictionary = self.components_
25442458

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.