Commit 15fd026

RFC Make non_negative_factorization call NMF instead of the opposite (#19607)
1 parent 579e7de commit 15fd026
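
This refactor inverts the old call direction: the public function non_negative_factorization now builds an NMF estimator and delegates to a new private method NMF._fit_transform, rather than NMF.fit_transform calling the function. A minimal sketch of the two entry points, assuming a scikit-learn build that contains this commit; since both paths now share _fit_transform, identical parameters and random_state should produce the same factorization:

    import numpy as np
    from sklearn.decomposition import NMF, non_negative_factorization

    X = np.abs(np.random.RandomState(0).randn(6, 5))  # non-negative toy data

    # Estimator API: after this commit, the core logic lives in NMF._fit_transform.
    model = NMF(n_components=2, init='random', random_state=0)
    W_est = model.fit_transform(X)

    # Function API: now a thin wrapper that instantiates NMF and delegates.
    W_fn, H_fn, n_iter = non_negative_factorization(
        X, n_components=2, init='random', random_state=0)

    print(np.allclose(W_est, W_fn))  # expected: True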
File tree

1 file changed: +132, -85 lines changed

sklearn/decomposition/_nmf.py

@@ -1021,74 +1021,14 @@ def non_negative_factorization(X, W=None, H=None, n_components=None, *,
     """
     X = check_array(X, accept_sparse=('csr', 'csc'),
                     dtype=[np.float64, np.float32])
-    check_non_negative(X, "NMF (input X)")
-    beta_loss = _check_string_param(solver, regularization, beta_loss, init)
 
-    if X.min() == 0 and beta_loss <= 0:
-        raise ValueError("When beta_loss <= 0 and X contains zeros, "
-                         "the solver may diverge. Please add small values to "
-                         "X, or use a positive beta_loss.")
+    est = NMF(n_components=n_components, init=init, solver=solver,
+              beta_loss=beta_loss, tol=tol, max_iter=max_iter,
+              random_state=random_state, alpha=alpha, l1_ratio=l1_ratio,
+              verbose=verbose, shuffle=shuffle, regularization=regularization)
 
-    n_samples, n_features = X.shape
-    if n_components is None:
-        n_components = n_features
-
-    if not isinstance(n_components, numbers.Integral) or n_components <= 0:
-        raise ValueError("Number of components must be a positive integer;"
-                         " got (n_components=%r)" % n_components)
-    if not isinstance(max_iter, numbers.Integral) or max_iter < 0:
-        raise ValueError("Maximum number of iterations must be a positive "
-                         "integer; got (max_iter=%r)" % max_iter)
-    if not isinstance(tol, numbers.Number) or tol < 0:
-        raise ValueError("Tolerance for stopping criteria must be "
-                         "positive; got (tol=%r)" % tol)
-
-    # check W and H, or initialize them
-    if init == 'custom' and update_H:
-        _check_init(H, (n_components, n_features), "NMF (input H)")
-        _check_init(W, (n_samples, n_components), "NMF (input W)")
-        if H.dtype != X.dtype or W.dtype != X.dtype:
-            raise TypeError("H and W should have the same dtype as X. Got "
-                            "H.dtype = {} and W.dtype = {}."
-                            .format(H.dtype, W.dtype))
-    elif not update_H:
-        _check_init(H, (n_components, n_features), "NMF (input H)")
-        if H.dtype != X.dtype:
-            raise TypeError("H should have the same dtype as X. Got H.dtype = "
-                            "{}.".format(H.dtype))
-        # 'mu' solver should not be initialized by zeros
-        if solver == 'mu':
-            avg = np.sqrt(X.mean() / n_components)
-            W = np.full((n_samples, n_components), avg, dtype=X.dtype)
-        else:
-            W = np.zeros((n_samples, n_components), dtype=X.dtype)
-    else:
-        W, H = _initialize_nmf(X, n_components, init=init,
-                               random_state=random_state)
-
-    l1_reg_W, l1_reg_H, l2_reg_W, l2_reg_H = _compute_regularization(
-        alpha, l1_ratio, regularization)
-
-    if solver == 'cd':
-        W, H, n_iter = _fit_coordinate_descent(X, W, H, tol, max_iter,
-                                               l1_reg_W, l1_reg_H,
-                                               l2_reg_W, l2_reg_H,
-                                               update_H=update_H,
-                                               verbose=verbose,
-                                               shuffle=shuffle,
-                                               random_state=random_state)
-    elif solver == 'mu':
-        W, H, n_iter = _fit_multiplicative_update(X, W, H, beta_loss, max_iter,
-                                                  tol, l1_reg_W, l1_reg_H,
-                                                  l2_reg_W, l2_reg_H, update_H,
-                                                  verbose)
-
-    else:
-        raise ValueError("Invalid solver parameter '%s'." % solver)
-
-    if n_iter == max_iter and tol > 0:
-        warnings.warn("Maximum number of iterations %d reached. Increase it to"
-                      " improve convergence." % max_iter, ConvergenceWarning)
+    with config_context(assume_finite=True):
+        W, H, n_iter = est._fit_transform(X, W=W, H=H, update_H=update_H)
 
     return W, H, n_iter
 
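The rewritten wrapper validates X once with check_array and then enters config_context(assume_finite=True), so the estimator's own validation does not repeat the costly finiteness scan. A short, self-contained sketch of that pattern (variable names are illustrative, not from the commit):

    import numpy as np
    from sklearn import config_context
    from sklearn.utils import check_array

    X = np.random.rand(4, 3)
    X = check_array(X)  # full validation, including the np.isfinite scan

    with config_context(assume_finite=True):
        # Within this block, check_array skips the finiteness scan;
        # safe here because X was fully validated just above.
        X_again = check_array(X)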

@@ -1281,6 +1221,52 @@ def __init__(self, n_components=None, *, init='warn', solver='cd',
     def _more_tags(self):
         return {'requires_positive_X': True}
 
+    def _check_params(self, X):
+        self._n_components = self.n_components
+        if self._n_components is None:
+            self._n_components = X.shape[1]
+        if not isinstance(
+            self._n_components, numbers.Integral
+        ) or self._n_components <= 0:
+            raise ValueError("Number of components must be a positive integer;"
+                             " got (n_components=%r)" % self._n_components)
+        if not isinstance(
+            self.max_iter, numbers.Integral
+        ) or self.max_iter < 0:
+            raise ValueError("Maximum number of iterations must be a positive "
+                             "integer; got (max_iter=%r)" % self.max_iter)
+        if not isinstance(self.tol, numbers.Number) or self.tol < 0:
+            raise ValueError("Tolerance for stopping criteria must be "
+                             "positive; got (tol=%r)" % self.tol)
+        return self
+
+    def _check_w_h(self, X, W, H, update_H):
+        # check W and H, or initialize them
+        n_samples, n_features = X.shape
+        if self.init == 'custom' and update_H:
+            _check_init(H, (self._n_components, n_features), "NMF (input H)")
+            _check_init(W, (n_samples, self._n_components), "NMF (input W)")
+            if H.dtype != X.dtype or W.dtype != X.dtype:
+                raise TypeError("H and W should have the same dtype as X. Got "
+                                "H.dtype = {} and W.dtype = {}."
+                                .format(H.dtype, W.dtype))
+        elif not update_H:
+            _check_init(H, (self._n_components, n_features), "NMF (input H)")
+            if H.dtype != X.dtype:
+                raise TypeError("H should have the same dtype as X. Got "
+                                "H.dtype = {}.".format(H.dtype))
+            # 'mu' solver should not be initialized by zeros
+            if self.solver == 'mu':
+                avg = np.sqrt(X.mean() / self._n_components)
+                W = np.full((n_samples, self._n_components),
+                            avg, dtype=X.dtype)
+            else:
+                W = np.zeros((n_samples, self._n_components), dtype=X.dtype)
+        else:
+            W, H = _initialize_nmf(X, self._n_components, init=self.init,
+                                   random_state=self.random_state)
+        return W, H
+
     def fit_transform(self, X, y=None, W=None, H=None):
         """Learn a NMF model for the data X and returns the transformed data.
 
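The new _check_w_h helper centralizes the checks that the init='custom' path depends on: W and H must have the expected shapes and the same dtype as X, otherwise the ValueError/TypeError branches above fire. A sketch of that path from the caller's side, with illustrative shapes:

    import numpy as np
    from sklearn.decomposition import NMF

    rng = np.random.RandomState(0)
    X = np.abs(rng.randn(6, 5))    # (n_samples=6, n_features=5)
    W0 = np.abs(rng.randn(6, 2))   # (n_samples, n_components)
    H0 = np.abs(rng.randn(2, 5))   # (n_components, n_features)

    # init='custom' routes W0 and H0 through _check_init and the dtype
    # checks; a mismatched shape or dtype raises instead of fitting.
    model = NMF(n_components=2, init='custom', max_iter=500)
    W = model.fit_transform(X, W=W0, H=H0)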
@@ -1308,23 +1294,92 @@ def fit_transform(self, X, y=None, W=None, H=None):
                                 dtype=[np.float64, np.float32])
 
         with config_context(assume_finite=True):
-            W, H, n_iter_ = non_negative_factorization(
-                X=X, W=W, H=H, n_components=self.n_components, init=self.init,
-                update_H=True, solver=self.solver, beta_loss=self.beta_loss,
-                tol=self.tol, max_iter=self.max_iter, alpha=self.alpha,
-                l1_ratio=self.l1_ratio, regularization=self.regularization,
-                random_state=self.random_state, verbose=self.verbose,
-                shuffle=self.shuffle)
-
-        self.reconstruction_err_ = _beta_divergence(X, W, H, self.beta_loss,
+            W, H, n_iter = self._fit_transform(X, W=W, H=H)
+
+        self.reconstruction_err_ = _beta_divergence(X, W, H, self._beta_loss,
                                                     square_root=True)
 
         self.n_components_ = H.shape[0]
         self.components_ = H
-        self.n_iter_ = n_iter_
+        self.n_iter_ = n_iter
 
         return W
 
+    def _fit_transform(self, X, y=None, W=None, H=None, update_H=True):
+        """Learn a NMF model for the data X and returns the transformed data.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            Data matrix to be decomposed
+
+        y : Ignored
+
+        W : array-like of shape (n_samples, n_components)
+            If init='custom', it is used as initial guess for the solution.
+
+        H : array-like of shape (n_components, n_features)
+            If init='custom', it is used as initial guess for the solution.
+            If update_H=False, it is used as a constant, to solve for W only.
+
+        update_H : bool, default=True
+            If True, both W and H will be estimated from initial guesses,
+            this corresponds to a call to the 'fit_transform' method.
+            If False, only W will be estimated, this corresponds to a call
+            to the 'transform' method.
+
+        Returns
+        -------
+        W : ndarray of shape (n_samples, n_components)
+            Transformed data.
+
+        H : ndarray of shape (n_components, n_features)
+            Factorization matrix, sometimes called 'dictionary'.
+
+        n_iter_ : int
+            Actual number of iterations.
+        """
+        check_non_negative(X, "NMF (input X)")
+        self._beta_loss = _check_string_param(self.solver, self.regularization,
+                                              self.beta_loss, self.init)
+
+        if X.min() == 0 and self._beta_loss <= 0:
+            raise ValueError("When beta_loss <= 0 and X contains zeros, "
+                             "the solver may diverge. Please add small values "
+                             "to X, or use a positive beta_loss.")
+
+        n_samples, n_features = X.shape
+
+        # check parameters
+        self._check_params(X)
+
+        # initialize or check W and H
+        W, H = self._check_w_h(X, W, H, update_H)
+
+        l1_reg_W, l1_reg_H, l2_reg_W, l2_reg_H = _compute_regularization(
+            self.alpha, self.l1_ratio, self.regularization)
+
+        if self.solver == 'cd':
+            W, H, n_iter = _fit_coordinate_descent(
+                X, W, H, self.tol, self.max_iter, l1_reg_W, l1_reg_H,
+                l2_reg_W, l2_reg_H, update_H=update_H,
+                verbose=self.verbose, shuffle=self.shuffle,
+                random_state=self.random_state)
+        elif self.solver == 'mu':
+            W, H, n_iter = _fit_multiplicative_update(
+                X, W, H, self._beta_loss, self.max_iter, self.tol,
+                l1_reg_W, l1_reg_H, l2_reg_W, l2_reg_H,
+                update_H=update_H, verbose=self.verbose)
+        else:
+            raise ValueError("Invalid solver parameter '%s'." % self.solver)
+
+        if n_iter == self.max_iter and self.tol > 0:
+            warnings.warn("Maximum number of iterations %d reached. Increase "
+                          "it to improve convergence." % self.max_iter,
+                          ConvergenceWarning)
+
+        return W, H, n_iter
+
     def fit(self, X, y=None, **params):
         """Learn a NMF model for the data X.
 
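Note that fit_transform now reads self._beta_loss (set inside _fit_transform by _check_string_param) when computing reconstruction_err_ with square_root=True. For the default Frobenius loss (beta_loss=2) this stored error should match the Frobenius norm of the residual; a quick hedged check on toy data:

    import numpy as np
    from sklearn.decomposition import NMF

    X = np.abs(np.random.RandomState(0).randn(6, 5))
    model = NMF(n_components=2, init='nndsvda', random_state=0, max_iter=500)
    W = model.fit_transform(X)
    H = model.components_

    # For beta_loss=2, _beta_divergence(..., square_root=True) should
    # reduce to the Frobenius norm of X - W @ H.
    print(np.isclose(np.linalg.norm(X - W @ H), model.reconstruction_err_))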
@@ -1361,15 +1416,7 @@ def transform(self, X):
                                 reset=False)
 
         with config_context(assume_finite=True):
-            W, _, n_iter_ = non_negative_factorization(
-                X=X, W=None, H=self.components_,
-                n_components=self.n_components_,
-                init=self.init, update_H=False, solver=self.solver,
-                beta_loss=self.beta_loss, tol=self.tol, max_iter=self.max_iter,
-                alpha=self.alpha, l1_ratio=self.l1_ratio,
-                regularization=self.regularization,
-                random_state=self.random_state,
-                verbose=self.verbose, shuffle=self.shuffle)
+            W, *_ = self._fit_transform(X, H=self.components_, update_H=False)
 
         return W
 
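With this change, transform reuses the same private path with update_H=False: H is frozen at components_ and only W is solved for, which is what the W, *_ unpacking above reflects. A sketch of the resulting fit/transform asymmetry (shapes illustrative):

    import numpy as np
    from sklearn.decomposition import NMF

    rng = np.random.RandomState(0)
    X_train = np.abs(rng.randn(8, 5))
    X_new = np.abs(rng.randn(3, 5))

    model = NMF(n_components=2, init='nndsvda', random_state=0)
    W_train = model.fit_transform(X_train)  # update_H=True: learns W and H
    W_new = model.transform(X_new)          # update_H=False: H held fixed

    print(model.components_.shape)  # (2, 5), unchanged by transform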

0 commit comments
