@@ -27,10 +27,10 @@
 from .preprocessing import binarize
 from .preprocessing import LabelBinarizer
 from .preprocessing import label_binarize
-from .utils import check_X_y, check_array, deprecated
+from .utils import deprecated
 from .utils.extmath import safe_sparse_dot
 from .utils.multiclass import _check_partial_fit_first_call
-from .utils.validation import check_is_fitted, check_non_negative, column_or_1d
+from .utils.validation import check_is_fitted, check_non_negative
 from .utils.validation import _check_sample_weight
 from .utils.validation import _deprecate_positional_args
 
@@ -55,7 +55,10 @@ def _joint_log_likelihood(self, X):
 
     @abstractmethod
     def _check_X(self, X):
-        """To be overridden in subclasses with the actual checks."""
+        """To be overridden in subclasses with the actual checks.
+
+        Only used in predict* methods.
+        """
 
     def predict(self, X):
         """
@@ -214,12 +217,12 @@ def fit(self, X, y, sample_weight=None):
         self : object
         """
         X, y = self._validate_data(X, y)
-        y = column_or_1d(y, warn=True)
         return self._partial_fit(X, y, np.unique(y), _refit=True,
                                  sample_weight=sample_weight)
 
     def _check_X(self, X):
-        return check_array(X)
+        """Validate X, used only in predict* methods."""
+        return self._validate_data(X, reset=False)
 
     @staticmethod
     def _update_mean_variance(n_past, mu, var, X, sample_weight=None):
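Note (not part of the diff): the `reset` flag used in `_check_X` above follows the contract of `BaseEstimator._validate_data`. Below is a minimal sketch of that contract; `_ResetSketch` is a hypothetical stand-in for the real implementation in `sklearn/base.py`.

import numpy as np

class _ResetSketch:
    """Simplified stand-in for BaseEstimator._validate_data."""

    def _validate_data(self, X, reset=True):
        X = np.asarray(X)
        if reset:
            # fit path: record how many features the model was trained on
            self.n_features_in_ = X.shape[1]
        elif X.shape[1] != self.n_features_in_:
            # predict path: reject input inconsistent with fit
            raise ValueError(
                f"X has {X.shape[1]} features, but this estimator is "
                f"expecting {self.n_features_in_} features as input.")
        return X

With `reset=False`, `_check_X` therefore gets the feature-count check for free in every predict* method.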
@@ -367,7 +370,11 @@ def _partial_fit(self, X, y, classes=None, _refit=False,
         -------
         self : object
         """
-        X, y = check_X_y(X, y)
+        if _refit:
+            self.classes_ = None
+
+        first_call = _check_partial_fit_first_call(self, classes)
+        X, y = self._validate_data(X, y, reset=first_call)
         if sample_weight is not None:
             sample_weight = _check_sample_weight(sample_weight, X)
 
@@ -377,10 +384,7 @@ def _partial_fit(self, X, y, classes=None, _refit=False,
         # deviation of the largest dimension.
         self.epsilon_ = self.var_smoothing * np.var(X, axis=0).max()
 
-        if _refit:
-            self.classes_ = None
-
-        if _check_partial_fit_first_call(self, classes):
+        if first_call:
             # This is the first call to partial_fit:
             # initialize various cumulative counters
             n_features = X.shape[1]
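Usage sketch (not part of the diff, assuming the patch is applied): the second `partial_fit` call runs validation with `reset=False`, so a feature-count mismatch now raises the standard `n_features_in_` error up front.

import numpy as np
from sklearn.naive_bayes import GaussianNB

rng = np.random.RandomState(0)
clf = GaussianNB()
# First call: reset=True, n_features_in_ is recorded as 3.
clf.partial_fit(rng.rand(6, 3), [0, 1, 0, 1, 0, 1], classes=[0, 1])
# Second call with 4 features: reset=False, so validation raises.
try:
    clf.partial_fit(rng.rand(6, 4), [0, 1, 0, 1, 0, 1])
except ValueError as exc:
    print(exc)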
@@ -488,10 +492,12 @@ class _BaseDiscreteNB(_BaseNB):
     """
 
     def _check_X(self, X):
-        return check_array(X, accept_sparse='csr')
+        """Validate X, used only in predict* methods."""
+        return self._validate_data(X, accept_sparse='csr', reset=False)
 
-    def _check_X_y(self, X, y):
-        return self._validate_data(X, y, accept_sparse='csr')
+    def _check_X_y(self, X, y, reset=True):
+        """Validate X and y in fit methods."""
+        return self._validate_data(X, y, accept_sparse='csr', reset=reset)
 
     def _update_class_log_prior(self, class_prior=None):
         n_classes = len(self.classes_)
@@ -518,7 +524,7 @@ def _check_alpha(self):
             raise ValueError('Smoothing parameter alpha = %.1e. '
                              'alpha should be > 0.' % np.min(self.alpha))
         if isinstance(self.alpha, np.ndarray):
-            if not self.alpha.shape[0] == self.n_features_:
+            if not self.alpha.shape[0] == self.n_features_in_:
                 raise ValueError("alpha should be a scalar or a numpy array "
                                  "with shape [n_features]")
         if np.min(self.alpha) < _ALPHA_MIN:
@@ -563,18 +569,15 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
         -------
         self : object
         """
-        X, y = self._check_X_y(X, y)
+        first_call = not hasattr(self, "classes_")
+        X, y = self._check_X_y(X, y, reset=first_call)
         _, n_features = X.shape
 
         if _check_partial_fit_first_call(self, classes):
             # This is the first call to partial_fit:
             # initialize various cumulative counters
             n_classes = len(classes)
             self._init_counters(n_classes, n_features)
-            self.n_features_ = n_features
-        elif n_features != self.n_features_:
-            msg = "Number of features %d does not match previous data %d."
-            raise ValueError(msg % (n_features, self.n_features_))
 
         Y = label_binarize(y, classes=self.classes_)
         if Y.shape[1] == 1:
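A note on the `hasattr` test above (illustrative, not part of the diff): `_check_partial_fit_first_call` is what creates `classes_`, but validation must know whether this is the first call *before* that helper runs, so the absence of `classes_` itself serves as the first-call marker. A minimal sketch of the ordering, using the hypothetical names `_OrderingSketch` and `partial_fit_skeleton`:

class _OrderingSketch:
    def _check_X_y(self, X, y, reset=True):
        # Stand-in for the real validator: reset=True records the width.
        if reset:
            self.n_features_in_ = len(X[0])
        return X, y

    def partial_fit_skeleton(self, X, y, classes=None):
        # classes_ does not exist before the very first call, so this
        # flag is True exactly once -- and it is computed *before*
        # _check_partial_fit_first_call (sketched below) creates classes_.
        first_call = not hasattr(self, "classes_")
        X, y = self._check_X_y(X, y, reset=first_call)
        if first_call:
            self.classes_ = sorted(set(classes))
        return self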
@@ -631,7 +634,6 @@ def fit(self, X, y, sample_weight=None):
         """
         X, y = self._check_X_y(X, y)
         _, n_features = X.shape
-        self.n_features_ = n_features
 
         labelbin = LabelBinarizer()
         Y = labelbin.fit_transform(y)
@@ -687,6 +689,16 @@ def intercept_(self):
     def _more_tags(self):
         return {'poor_score': True}
 
+    # TODO: Remove in 1.2
+    # mypy error: Decorated property not supported
+    @deprecated(  # type: ignore
+        "Attribute n_features_ was deprecated in version 1.0 and will be "
+        "removed in 1.2. Use 'n_features_in_' instead."
+    )
+    @property
+    def n_features_(self):
+        return self.n_features_in_
+
 
 class MultinomialNB(_BaseDiscreteNB):
     """
@@ -753,6 +765,10 @@ class MultinomialNB(_BaseDiscreteNB):
     n_features_ : int
         Number of features of each sample.
 
+        .. deprecated:: 1.0
+            Attribute `n_features_` was deprecated in version 1.0 and will be
+            removed in 1.2. Use `n_features_in_` instead.
+
     Examples
     --------
     >>> import numpy as np
@@ -879,6 +895,10 @@ class ComplementNB(_BaseDiscreteNB):
     n_features_ : int
         Number of features of each sample.
 
+        .. deprecated:: 1.0
+            Attribute `n_features_` was deprecated in version 1.0 and will be
+            removed in 1.2. Use `n_features_in_` instead.
+
     Examples
     --------
     >>> import numpy as np
@@ -996,6 +1016,10 @@ class BernoulliNB(_BaseDiscreteNB):
     n_features_ : int
         Number of features of each sample.
 
+        .. deprecated:: 1.0
+            Attribute `n_features_` was deprecated in version 1.0 and will be
+            removed in 1.2. Use `n_features_in_` instead.
+
     Examples
     --------
     >>> import numpy as np
@@ -1032,13 +1056,14 @@ def __init__(self, *, alpha=1.0, binarize=.0, fit_prior=True,
         self.class_prior = class_prior
 
     def _check_X(self, X):
+        """Validate X, used only in predict* methods."""
         X = super()._check_X(X)
         if self.binarize is not None:
             X = binarize(X, threshold=self.binarize)
         return X
 
-    def _check_X_y(self, X, y):
-        X, y = super()._check_X_y(X, y)
+    def _check_X_y(self, X, y, reset=True):
+        X, y = super()._check_X_y(X, y, reset=reset)
         if self.binarize is not None:
             X = binarize(X, threshold=self.binarize)
         return X, y
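Usage sketch (illustrative, not part of the diff): because `_check_X` and `_check_X_y` apply the same `binarize` threshold, fit and predict see identically binarized data.

import numpy as np
from sklearn.naive_bayes import BernoulliNB

X = np.array([[0.1, 0.9], [0.8, 0.2], [0.9, 0.7], [0.2, 0.1]])
y = [0, 1, 1, 0]
# Values above 0.5 become 1, the rest 0, in fit and predict alike.
clf = BernoulliNB(binarize=0.5).fit(X, y)
print(clf.predict(np.array([[0.3, 0.6]])))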
@@ -1133,6 +1158,10 @@ class CategoricalNB(_BaseDiscreteNB):
     n_features_ : int
         Number of features of each sample.
 
+        .. deprecated:: 1.0
+            Attribute `n_features_` was deprecated in version 1.0 and will be
+            removed in 1.2. Use `n_features_in_` instead.
+
     n_categories_ : ndarray of shape (n_features,), dtype=np.int64
         Number of categories for each feature. This value is
         inferred from the data or set by the minimum number of categories.
@@ -1235,14 +1264,15 @@ def _more_tags(self):
         return {'requires_positive_X': True}
 
     def _check_X(self, X):
-        X = check_array(X, dtype='int', accept_sparse=False,
-                        force_all_finite=True)
+        """Validate X, used only in predict* methods."""
+        X = self._validate_data(X, dtype='int', accept_sparse=False,
+                                force_all_finite=True, reset=False)
         check_non_negative(X, "CategoricalNB (input X)")
         return X
 
-    def _check_X_y(self, X, y):
+    def _check_X_y(self, X, y, reset=True):
         X, y = self._validate_data(X, y, dtype='int', accept_sparse=False,
-                                   force_all_finite=True)
+                                   force_all_finite=True, reset=reset)
         check_non_negative(X, "CategoricalNB (input X)")
         return X, y
 
@@ -1297,7 +1327,7 @@ def _update_cat_count(X_feature, Y, cat_count, n_classes):
         self.class_count_ += Y.sum(axis=0)
         self.n_categories_ = self._validate_n_categories(
             X, self.min_categories)
-        for i in range(self.n_features_):
+        for i in range(self.n_features_in_):
             X_feature = X[:, i]
             self.category_count_[i] = _update_cat_count_dims(
                 self.category_count_[i], self.n_categories_[i] - 1)
@@ -1307,7 +1337,7 @@ def _update_cat_count(X_feature, Y, cat_count, n_classes):
 
     def _update_feature_log_prob(self, alpha):
         feature_log_prob = []
-        for i in range(self.n_features_):
+        for i in range(self.n_features_in_):
             smoothed_cat_count = self.category_count_[i] + alpha
             smoothed_class_count = smoothed_cat_count.sum(axis=1)
             feature_log_prob.append(
@@ -1316,11 +1346,9 @@ def _update_feature_log_prob(self, alpha):
         self.feature_log_prob_ = feature_log_prob
 
     def _joint_log_likelihood(self, X):
-        if not X.shape[1] == self.n_features_:
-            raise ValueError("Expected input with %d features, got %d instead"
-                             % (self.n_features_, X.shape[1]))
+        self._check_n_features(X, reset=False)
         jll = np.zeros((X.shape[0], self.class_count_.shape[0]))
-        for i in range(self.n_features_):
+        for i in range(self.n_features_in_):
             indices = X[:, i]
             jll += self.feature_log_prob_[i][:, indices].T
         total_ll = jll + self.class_log_prior_
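For reference, a simplified sketch of what `self._check_n_features(X, reset=False)` does (the real helper lives in `sklearn/base.py`; `_check_n_features_sketch` is a hypothetical name). It replaces the hand-rolled shape check deleted above with the uniform error message used across estimators.

def _check_n_features_sketch(estimator, X, reset):
    # Simplified: the real method also handles the not-yet-fitted case.
    n_features = X.shape[1]
    if reset:
        estimator.n_features_in_ = n_features
    elif n_features != estimator.n_features_in_:
        raise ValueError(
            f"X has {n_features} features, but "
            f"{type(estimator).__name__} is expecting "
            f"{estimator.n_features_in_} features as input.")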