@@ -1021,74 +1021,14 @@ def non_negative_factorization(X, W=None, H=None, n_components=None, *,
1021
1021
"""
1022
1022
X = check_array (X , accept_sparse = ('csr' , 'csc' ),
1023
1023
dtype = [np .float64 , np .float32 ])
1024
- check_non_negative (X , "NMF (input X)" )
1025
- beta_loss = _check_string_param (solver , regularization , beta_loss , init )
1026
1024
1027
- if X . min () == 0 and beta_loss <= 0 :
1028
- raise ValueError ( "When beta_loss <= 0 and X contains zeros, "
1029
- "the solver may diverge. Please add small values to "
1030
- "X, or use a positive beta_loss." )
1025
+ est = NMF ( n_components = n_components , init = init , solver = solver ,
1026
+ beta_loss = beta_loss , tol = tol , max_iter = max_iter ,
1027
+ random_state = random_state , alpha = alpha , l1_ratio = l1_ratio ,
1028
+ verbose = verbose , shuffle = shuffle , regularization = regularization )
1031
1029
1032
- n_samples , n_features = X .shape
1033
- if n_components is None :
1034
- n_components = n_features
1035
-
1036
- if not isinstance (n_components , numbers .Integral ) or n_components <= 0 :
1037
- raise ValueError ("Number of components must be a positive integer;"
1038
- " got (n_components=%r)" % n_components )
1039
- if not isinstance (max_iter , numbers .Integral ) or max_iter < 0 :
1040
- raise ValueError ("Maximum number of iterations must be a positive "
1041
- "integer; got (max_iter=%r)" % max_iter )
1042
- if not isinstance (tol , numbers .Number ) or tol < 0 :
1043
- raise ValueError ("Tolerance for stopping criteria must be "
1044
- "positive; got (tol=%r)" % tol )
1045
-
1046
- # check W and H, or initialize them
1047
- if init == 'custom' and update_H :
1048
- _check_init (H , (n_components , n_features ), "NMF (input H)" )
1049
- _check_init (W , (n_samples , n_components ), "NMF (input W)" )
1050
- if H .dtype != X .dtype or W .dtype != X .dtype :
1051
- raise TypeError ("H and W should have the same dtype as X. Got "
1052
- "H.dtype = {} and W.dtype = {}."
1053
- .format (H .dtype , W .dtype ))
1054
- elif not update_H :
1055
- _check_init (H , (n_components , n_features ), "NMF (input H)" )
1056
- if H .dtype != X .dtype :
1057
- raise TypeError ("H should have the same dtype as X. Got H.dtype = "
1058
- "{}." .format (H .dtype ))
1059
- # 'mu' solver should not be initialized by zeros
1060
- if solver == 'mu' :
1061
- avg = np .sqrt (X .mean () / n_components )
1062
- W = np .full ((n_samples , n_components ), avg , dtype = X .dtype )
1063
- else :
1064
- W = np .zeros ((n_samples , n_components ), dtype = X .dtype )
1065
- else :
1066
- W , H = _initialize_nmf (X , n_components , init = init ,
1067
- random_state = random_state )
1068
-
1069
- l1_reg_W , l1_reg_H , l2_reg_W , l2_reg_H = _compute_regularization (
1070
- alpha , l1_ratio , regularization )
1071
-
1072
- if solver == 'cd' :
1073
- W , H , n_iter = _fit_coordinate_descent (X , W , H , tol , max_iter ,
1074
- l1_reg_W , l1_reg_H ,
1075
- l2_reg_W , l2_reg_H ,
1076
- update_H = update_H ,
1077
- verbose = verbose ,
1078
- shuffle = shuffle ,
1079
- random_state = random_state )
1080
- elif solver == 'mu' :
1081
- W , H , n_iter = _fit_multiplicative_update (X , W , H , beta_loss , max_iter ,
1082
- tol , l1_reg_W , l1_reg_H ,
1083
- l2_reg_W , l2_reg_H , update_H ,
1084
- verbose )
1085
-
1086
- else :
1087
- raise ValueError ("Invalid solver parameter '%s'." % solver )
1088
-
1089
- if n_iter == max_iter and tol > 0 :
1090
- warnings .warn ("Maximum number of iterations %d reached. Increase it to"
1091
- " improve convergence." % max_iter , ConvergenceWarning )
1030
+ with config_context (assume_finite = True ):
1031
+ W , H , n_iter = est ._fit_transform (X , W = W , H = H , update_H = update_H )
1092
1032
1093
1033
return W , H , n_iter
1094
1034
@@ -1281,6 +1221,52 @@ def __init__(self, n_components=None, *, init='warn', solver='cd',
1281
1221
def _more_tags (self ):
1282
1222
return {'requires_positive_X' : True }
1283
1223
1224
+ def _check_params (self , X ):
1225
+ self ._n_components = self .n_components
1226
+ if self ._n_components is None :
1227
+ self ._n_components = X .shape [1 ]
1228
+ if not isinstance (
1229
+ self ._n_components , numbers .Integral
1230
+ ) or self ._n_components <= 0 :
1231
+ raise ValueError ("Number of components must be a positive integer;"
1232
+ " got (n_components=%r)" % self ._n_components )
1233
+ if not isinstance (
1234
+ self .max_iter , numbers .Integral
1235
+ ) or self .max_iter < 0 :
1236
+ raise ValueError ("Maximum number of iterations must be a positive "
1237
+ "integer; got (max_iter=%r)" % self .max_iter )
1238
+ if not isinstance (self .tol , numbers .Number ) or self .tol < 0 :
1239
+ raise ValueError ("Tolerance for stopping criteria must be "
1240
+ "positive; got (tol=%r)" % self .tol )
1241
+ return self
1242
+
1243
def _check_w_h(self, X, W, H, update_H):
    """Validate user-supplied factors W and H, or build initial ones.

    Three cases, mirroring how the solver is being used:

    * ``init='custom'`` and ``update_H=True`` -- both W and H come from
      the caller: check their shapes and dtypes against X.
    * ``update_H=False`` (transform-only) -- H is a fixed input; check
      it, and build a fresh W of the right shape/dtype.
    * otherwise -- initialize both factors with ``_initialize_nmf``.

    Returns
    -------
    W : ndarray of shape (n_samples, n_components)
    H : ndarray of shape (n_components, n_features)

    Raises
    ------
    TypeError
        If a user-supplied factor's dtype does not match ``X.dtype``.
    """
    n_samples, n_features = X.shape
    w_shape = (n_samples, self._n_components)
    h_shape = (self._n_components, n_features)
    custom_init = (self.init == 'custom')

    if custom_init and update_H:
        # Both factors supplied by the caller: validate them.
        _check_init(H, h_shape, "NMF (input H)")
        _check_init(W, w_shape, "NMF (input W)")
        if H.dtype != X.dtype or W.dtype != X.dtype:
            raise TypeError("H and W should have the same dtype as X. Got "
                            "H.dtype = {} and W.dtype = {}."
                            .format(H.dtype, W.dtype))
    elif not update_H:
        # Transform-only call: H is fixed, W must be (re)built here.
        _check_init(H, h_shape, "NMF (input H)")
        if H.dtype != X.dtype:
            raise TypeError("H should have the same dtype as X. Got "
                            "H.dtype = {}.".format(H.dtype))
        if self.solver == 'mu':
            # 'mu' solver should not be initialized by zeros
            avg = np.sqrt(X.mean() / self._n_components)
            W = np.full(w_shape, avg, dtype=X.dtype)
        else:
            W = np.zeros(w_shape, dtype=X.dtype)
    else:
        W, H = _initialize_nmf(X, self._n_components, init=self.init,
                               random_state=self.random_state)
    return W, H
1269
+
1284
1270
def fit_transform (self , X , y = None , W = None , H = None ):
1285
1271
"""Learn a NMF model for the data X and returns the transformed data.
1286
1272
@@ -1308,23 +1294,92 @@ def fit_transform(self, X, y=None, W=None, H=None):
1308
1294
dtype = [np .float64 , np .float32 ])
1309
1295
1310
1296
with config_context (assume_finite = True ):
1311
- W , H , n_iter_ = non_negative_factorization (
1312
- X = X , W = W , H = H , n_components = self .n_components , init = self .init ,
1313
- update_H = True , solver = self .solver , beta_loss = self .beta_loss ,
1314
- tol = self .tol , max_iter = self .max_iter , alpha = self .alpha ,
1315
- l1_ratio = self .l1_ratio , regularization = self .regularization ,
1316
- random_state = self .random_state , verbose = self .verbose ,
1317
- shuffle = self .shuffle )
1318
-
1319
- self .reconstruction_err_ = _beta_divergence (X , W , H , self .beta_loss ,
1297
+ W , H , n_iter = self ._fit_transform (X , W = W , H = H )
1298
+
1299
+ self .reconstruction_err_ = _beta_divergence (X , W , H , self ._beta_loss ,
1320
1300
square_root = True )
1321
1301
1322
1302
self .n_components_ = H .shape [0 ]
1323
1303
self .components_ = H
1324
- self .n_iter_ = n_iter_
1304
+ self .n_iter_ = n_iter
1325
1305
1326
1306
return W
1327
1307
1308
def _fit_transform(self, X, y=None, W=None, H=None, update_H=True):
    """Learn a NMF model for the data X and returns the transformed data.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Data matrix to be decomposed. Must be non-negative.

    y : Ignored

    W : array-like of shape (n_samples, n_components), default=None
        If init='custom', it is used as initial guess for the solution.

    H : array-like of shape (n_components, n_features), default=None
        If init='custom', it is used as initial guess for the solution.
        If update_H=False, it is used as a constant, to solve for W only.

    update_H : bool, default=True
        If True, both W and H will be estimated from initial guesses,
        this corresponds to a call to the 'fit_transform' method.
        If False, only W will be estimated, this corresponds to a call
        to the 'transform' method.

    Returns
    -------
    W : ndarray of shape (n_samples, n_components)
        Transformed data.

    H : ndarray of shape (n_components, n_features)
        Factorization matrix, sometimes called 'dictionary'.

    n_iter_ : int
        Actual number of iterations.

    Raises
    ------
    ValueError
        If X contains negative entries, if beta_loss <= 0 while X
        contains zeros, or if ``solver`` is not 'cd' or 'mu'.
    """
    check_non_negative(X, "NMF (input X)")
    # Resolve string aliases of beta_loss to a numeric value and check
    # it is compatible with the chosen solver/regularization/init.
    self._beta_loss = _check_string_param(self.solver, self.regularization,
                                          self.beta_loss, self.init)

    if X.min() == 0 and self._beta_loss <= 0:
        raise ValueError("When beta_loss <= 0 and X contains zeros, "
                         "the solver may diverge. Please add small values "
                         "to X, or use a positive beta_loss.")

    # check parameters (also resolves self._n_components)
    self._check_params(X)

    # initialize or check W and H
    W, H = self._check_w_h(X, W, H, update_H)

    # Split the single (alpha, l1_ratio, regularization) specification
    # into per-factor L1/L2 penalty strengths.
    l1_reg_W, l1_reg_H, l2_reg_W, l2_reg_H = _compute_regularization(
        self.alpha, self.l1_ratio, self.regularization)

    if self.solver == 'cd':
        W, H, n_iter = _fit_coordinate_descent(
            X, W, H, self.tol, self.max_iter, l1_reg_W, l1_reg_H,
            l2_reg_W, l2_reg_H, update_H=update_H,
            verbose=self.verbose, shuffle=self.shuffle,
            random_state=self.random_state)
    elif self.solver == 'mu':
        W, H, n_iter = _fit_multiplicative_update(
            X, W, H, self._beta_loss, self.max_iter, self.tol,
            l1_reg_W, l1_reg_H, l2_reg_W, l2_reg_H,
            update_H=update_H, verbose=self.verbose)
    else:
        raise ValueError("Invalid solver parameter '%s'." % self.solver)

    # tol == 0 means the caller asked for exactly max_iter iterations,
    # so reaching it is not a convergence failure.
    if n_iter == self.max_iter and self.tol > 0:
        warnings.warn("Maximum number of iterations %d reached. Increase "
                      "it to improve convergence." % self.max_iter,
                      ConvergenceWarning)

    return W, H, n_iter
1382
+
1328
1383
def fit (self , X , y = None , ** params ):
1329
1384
"""Learn a NMF model for the data X.
1330
1385
@@ -1361,15 +1416,7 @@ def transform(self, X):
1361
1416
reset = False )
1362
1417
1363
1418
with config_context (assume_finite = True ):
1364
- W , _ , n_iter_ = non_negative_factorization (
1365
- X = X , W = None , H = self .components_ ,
1366
- n_components = self .n_components_ ,
1367
- init = self .init , update_H = False , solver = self .solver ,
1368
- beta_loss = self .beta_loss , tol = self .tol , max_iter = self .max_iter ,
1369
- alpha = self .alpha , l1_ratio = self .l1_ratio ,
1370
- regularization = self .regularization ,
1371
- random_state = self .random_state ,
1372
- verbose = self .verbose , shuffle = self .shuffle )
1419
+ W , * _ = self ._fit_transform (X , H = self .components_ , update_H = False )
1373
1420
1374
1421
return W
1375
1422
0 commit comments