Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 5c24622

Browse files
authored
ENH Adds n_features_in_ to naive_bayes (#19485)
1 parent 26c5530 commit 5c24622
Copy full SHA for 5c24622

File tree

3 files changed

+91
-77
lines changed
Filter options

3 files changed

+91
-77
lines changed

‎sklearn/naive_bayes.py

Copy file name to clipboard | Expand all lines: sklearn/naive_bayes.py
+60 -32 | Lines changed: 60 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -27,10 +27,10 @@
2727
from .preprocessing import binarize
2828
from .preprocessing import LabelBinarizer
2929
from .preprocessing import label_binarize
30-
from .utils import check_X_y, check_array, deprecated
30+
from .utils import deprecated
3131
from .utils.extmath import safe_sparse_dot
3232
from .utils.multiclass import _check_partial_fit_first_call
33-
from .utils.validation import check_is_fitted, check_non_negative, column_or_1d
33+
from .utils.validation import check_is_fitted, check_non_negative
3434
from .utils.validation import _check_sample_weight
3535
from .utils.validation import _deprecate_positional_args
3636

@@ -55,7 +55,10 @@ def _joint_log_likelihood(self, X):
5555

5656
@abstractmethod
5757
def _check_X(self, X):
58-
"""To be overridden in subclasses with the actual checks."""
58+
"""To be overridden in subclasses with the actual checks.
59+
60+
Only used in predict* methods.
61+
"""
5962

6063
def predict(self, X):
6164
"""
@@ -214,12 +217,12 @@ def fit(self, X, y, sample_weight=None):
214217
self : object
215218
"""
216219
X, y = self._validate_data(X, y)
217-
y = column_or_1d(y, warn=True)
218220
return self._partial_fit(X, y, np.unique(y), _refit=True,
219221
sample_weight=sample_weight)
220222

221223
def _check_X(self, X):
222-
return check_array(X)
224+
"""Validate X, used only in predict* methods."""
225+
return self._validate_data(X, reset=False)
223226

224227
@staticmethod
225228
def _update_mean_variance(n_past, mu, var, X, sample_weight=None):
@@ -367,7 +370,11 @@ def _partial_fit(self, X, y, classes=None, _refit=False,
367370
-------
368371
self : object
369372
"""
370-
X, y = check_X_y(X, y)
373+
if _refit:
374+
self.classes_ = None
375+
376+
first_call = _check_partial_fit_first_call(self, classes)
377+
X, y = self._validate_data(X, y, reset=first_call)
371378
if sample_weight is not None:
372379
sample_weight = _check_sample_weight(sample_weight, X)
373380

@@ -377,10 +384,7 @@ def _partial_fit(self, X, y, classes=None, _refit=False,
377384
# deviation of the largest dimension.
378385
self.epsilon_ = self.var_smoothing * np.var(X, axis=0).max()
379386

380-
if _refit:
381-
self.classes_ = None
382-
383-
if _check_partial_fit_first_call(self, classes):
387+
if first_call:
384388
# This is the first call to partial_fit:
385389
# initialize various cumulative counters
386390
n_features = X.shape[1]
@@ -488,10 +492,12 @@ class _BaseDiscreteNB(_BaseNB):
488492
"""
489493

490494
def _check_X(self, X):
491-
return check_array(X, accept_sparse='csr')
495+
"""Validate X, used only in predict* methods."""
496+
return self._validate_data(X, accept_sparse='csr', reset=False)
492497

493-
def _check_X_y(self, X, y):
494-
return self._validate_data(X, y, accept_sparse='csr')
498+
def _check_X_y(self, X, y, reset=True):
499+
"""Validate X and y in fit methods."""
500+
return self._validate_data(X, y, accept_sparse='csr', reset=reset)
495501

496502
def _update_class_log_prior(self, class_prior=None):
497503
n_classes = len(self.classes_)
@@ -518,7 +524,7 @@ def _check_alpha(self):
518524
raise ValueError('Smoothing parameter alpha = %.1e. '
519525
'alpha should be > 0.' % np.min(self.alpha))
520526
if isinstance(self.alpha, np.ndarray):
521-
if not self.alpha.shape[0] == self.n_features_:
527+
if not self.alpha.shape[0] == self.n_features_in_:
522528
raise ValueError("alpha should be a scalar or a numpy array "
523529
"with shape [n_features]")
524530
if np.min(self.alpha) < _ALPHA_MIN:
@@ -563,18 +569,15 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
563569
-------
564570
self : object
565571
"""
566-
X, y = self._check_X_y(X, y)
572+
first_call = not hasattr(self, "classes_")
573+
X, y = self._check_X_y(X, y, reset=first_call)
567574
_, n_features = X.shape
568575

569576
if _check_partial_fit_first_call(self, classes):
570577
# This is the first call to partial_fit:
571578
# initialize various cumulative counters
572579
n_classes = len(classes)
573580
self._init_counters(n_classes, n_features)
574-
self.n_features_ = n_features
575-
elif n_features != self.n_features_:
576-
msg = "Number of features %d does not match previous data %d."
577-
raise ValueError(msg % (n_features, self.n_features_))
578581

579582
Y = label_binarize(y, classes=self.classes_)
580583
if Y.shape[1] == 1:
@@ -631,7 +634,6 @@ def fit(self, X, y, sample_weight=None):
631634
"""
632635
X, y = self._check_X_y(X, y)
633636
_, n_features = X.shape
634-
self.n_features_ = n_features
635637

636638
labelbin = LabelBinarizer()
637639
Y = labelbin.fit_transform(y)
@@ -687,6 +689,16 @@ def intercept_(self):
687689
def _more_tags(self):
688690
return {'poor_score': True}
689691

692+
# TODO: Remove in 1.2
693+
# mypy error: Decorated property not supported
694+
@deprecated( # type: ignore
695+
"Attribute n_features_ was deprecated in version 1.0 and will be "
696+
"removed in 1.2. Use 'n_features_in_' instead."
697+
)
698+
@property
699+
def n_features_(self):
700+
return self.n_features_in_
701+
690702

691703
class MultinomialNB(_BaseDiscreteNB):
692704
"""
@@ -753,6 +765,10 @@ class MultinomialNB(_BaseDiscreteNB):
753765
n_features_ : int
754766
Number of features of each sample.
755767
768+
.. deprecated:: 1.0
769+
Attribute `n_features_` was deprecated in version 1.0 and will be
770+
removed in 1.2. Use `n_features_in_` instead.
771+
756772
Examples
757773
--------
758774
>>> import numpy as np
@@ -879,6 +895,10 @@ class ComplementNB(_BaseDiscreteNB):
879895
n_features_ : int
880896
Number of features of each sample.
881897
898+
.. deprecated:: 1.0
899+
Attribute `n_features_` was deprecated in version 1.0 and will be
900+
removed in 1.2. Use `n_features_in_` instead.
901+
882902
Examples
883903
--------
884904
>>> import numpy as np
@@ -996,6 +1016,10 @@ class BernoulliNB(_BaseDiscreteNB):
9961016
n_features_ : int
9971017
Number of features of each sample.
9981018
1019+
.. deprecated:: 1.0
1020+
Attribute `n_features_` was deprecated in version 1.0 and will be
1021+
removed in 1.2. Use `n_features_in_` instead.
1022+
9991023
Examples
10001024
--------
10011025
>>> import numpy as np
@@ -1032,13 +1056,14 @@ def __init__(self, *, alpha=1.0, binarize=.0, fit_prior=True,
10321056
self.class_prior = class_prior
10331057

10341058
def _check_X(self, X):
1059+
"""Validate X, used only in predict* methods."""
10351060
X = super()._check_X(X)
10361061
if self.binarize is not None:
10371062
X = binarize(X, threshold=self.binarize)
10381063
return X
10391064

1040-
def _check_X_y(self, X, y):
1041-
X, y = super()._check_X_y(X, y)
1065+
def _check_X_y(self, X, y, reset=True):
1066+
X, y = super()._check_X_y(X, y, reset=reset)
10421067
if self.binarize is not None:
10431068
X = binarize(X, threshold=self.binarize)
10441069
return X, y
@@ -1133,6 +1158,10 @@ class CategoricalNB(_BaseDiscreteNB):
11331158
n_features_ : int
11341159
Number of features of each sample.
11351160
1161+
.. deprecated:: 1.0
1162+
Attribute `n_features_` was deprecated in version 1.0 and will be
1163+
removed in 1.2. Use `n_features_in_` instead.
1164+
11361165
n_categories_ : ndarray of shape (n_features,), dtype=np.int64
11371166
Number of categories for each feature. This value is
11381167
inferred from the data or set by the minimum number of categories.
@@ -1235,14 +1264,15 @@ def _more_tags(self):
12351264
return {'requires_positive_X': True}
12361265

12371266
def _check_X(self, X):
1238-
X = check_array(X, dtype='int', accept_sparse=False,
1239-
force_all_finite=True)
1267+
"""Validate X, used only in predict* methods."""
1268+
X = self._validate_data(X, dtype='int', accept_sparse=False,
1269+
force_all_finite=True, reset=False)
12401270
check_non_negative(X, "CategoricalNB (input X)")
12411271
return X
12421272

1243-
def _check_X_y(self, X, y):
1273+
def _check_X_y(self, X, y, reset=True):
12441274
X, y = self._validate_data(X, y, dtype='int', accept_sparse=False,
1245-
force_all_finite=True)
1275+
force_all_finite=True, reset=reset)
12461276
check_non_negative(X, "CategoricalNB (input X)")
12471277
return X, y
12481278

@@ -1297,7 +1327,7 @@ def _update_cat_count(X_feature, Y, cat_count, n_classes):
12971327
self.class_count_ += Y.sum(axis=0)
12981328
self.n_categories_ = self._validate_n_categories(
12991329
X, self.min_categories)
1300-
for i in range(self.n_features_):
1330+
for i in range(self.n_features_in_):
13011331
X_feature = X[:, i]
13021332
self.category_count_[i] = _update_cat_count_dims(
13031333
self.category_count_[i], self.n_categories_[i] - 1)
@@ -1307,7 +1337,7 @@ def _update_cat_count(X_feature, Y, cat_count, n_classes):
13071337

13081338
def _update_feature_log_prob(self, alpha):
13091339
feature_log_prob = []
1310-
for i in range(self.n_features_):
1340+
for i in range(self.n_features_in_):
13111341
smoothed_cat_count = self.category_count_[i] + alpha
13121342
smoothed_class_count = smoothed_cat_count.sum(axis=1)
13131343
feature_log_prob.append(
@@ -1316,11 +1346,9 @@ def _update_feature_log_prob(self, alpha):
13161346
self.feature_log_prob_ = feature_log_prob
13171347

13181348
def _joint_log_likelihood(self, X):
1319-
if not X.shape[1] == self.n_features_:
1320-
raise ValueError("Expected input with %d features, got %d instead"
1321-
% (self.n_features_, X.shape[1]))
1349+
self._check_n_features(X, reset=False)
13221350
jll = np.zeros((X.shape[0], self.class_count_.shape[0]))
1323-
for i in range(self.n_features_):
1351+
for i in range(self.n_features_in_):
13241352
indices = X[:, i]
13251353
jll += self.feature_log_prob_[i][:, indices].T
13261354
total_ll = jll + self.class_log_prior_

‎sklearn/tests/test_common.py

Copy file name to clipboard | Expand all lines: sklearn/tests/test_common.py
-1 | Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -273,7 +273,6 @@ def test_search_cv(estimator, check, request):
273273
'model_selection',
274274
'multiclass',
275275
'multioutput',
276-
'naive_bayes',
277276
'pipeline',
278277
'random_projection',
279278
}

‎sklearn/tests/test_naive_bayes.py

Copy file name to clipboard | Expand all lines: sklearn/tests/test_naive_bayes.py
+31 -44 | Lines changed: 31 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,11 @@ def test_gnb():
5757
# Test whether label mismatch between target y and classes raises
5858
# an Error
5959
# FIXME Remove this test once the more general partial_fit tests are merged
60-
assert_raises(ValueError, GaussianNB().partial_fit, X, y, classes=[0, 1])
60+
with pytest.raises(
61+
ValueError,
62+
match="The target label.* in y do not exist in the initial classes"
63+
):
64+
GaussianNB().partial_fit(X, y, classes=[0, 1])
6165

6266

6367
# TODO remove in 1.2 once sigma_ attribute is removed (GH #18842)
@@ -74,7 +78,7 @@ def test_gnb_prior():
7478
clf = GaussianNB().fit(X, y)
7579
assert_array_almost_equal(np.array([3, 3]) / 6.0,
7680
clf.class_prior_, 8)
77-
clf.fit(X1, y1)
81+
clf = GaussianNB().fit(X1, y1)
7882
# Check that the class priors sum to 1
7983
assert_array_almost_equal(clf.class_prior_.sum(), 1)
8084

@@ -171,16 +175,6 @@ def test_gnb_check_update_with_no_data():
171175
assert tvar == var
172176

173177

174-
def test_gnb_pfit_wrong_nb_features():
175-
"""Test whether an error is raised when the number of feature changes
176-
between two partial fit"""
177-
clf = GaussianNB()
178-
# Fit for the first time the GNB
179-
clf.fit(X, y)
180-
# Partial fit a second time with an incoherent X
181-
assert_raises(ValueError, clf.partial_fit, np.hstack((X, X)), y)
182-
183-
184178
def test_gnb_partial_fit():
185179
clf = GaussianNB().fit(X, y)
186180
clf_pf = GaussianNB().partial_fit(X, y, np.unique(y))
@@ -272,37 +266,22 @@ def test_discretenb_partial_fit(DiscreteNaiveBayes):
272266

273267

274268
@pytest.mark.parametrize('NaiveBayes', ALL_NAIVE_BAYES_CLASSES)
275-
def test_naive_bayes_input_check_fit(NaiveBayes):
276-
# Test input checks for the fit method
277-
278-
# check shape consistency for number of samples at fit time
279-
assert_raises(ValueError, NaiveBayes().fit, X2, y2[:-1])
280-
281-
# check shape consistency for number of input features at predict time
282-
clf = NaiveBayes().fit(X2, y2)
283-
assert_raises(ValueError, clf.predict, X2[:, :-1])
284-
285-
286-
@pytest.mark.parametrize('DiscreteNaiveBayes', DISCRETE_NAIVE_BAYES_CLASSES)
287-
def test_discretenb_input_check_partial_fit(DiscreteNaiveBayes):
288-
# check shape consistency
289-
assert_raises(ValueError, DiscreteNaiveBayes().partial_fit, X2, y2[:-1],
290-
classes=np.unique(y2))
291-
269+
def test_NB_partial_fit_no_first_classes(NaiveBayes):
292270
# classes is required for first call to partial fit
293-
assert_raises(ValueError, DiscreteNaiveBayes().partial_fit, X2, y2)
271+
with pytest.raises(
272+
ValueError,
273+
match="classes must be passed on the first call to partial_fit."
274+
):
275+
NaiveBayes().partial_fit(X2, y2)
294276

295277
# check consistency of consecutive classes values
296-
clf = DiscreteNaiveBayes()
278+
clf = NaiveBayes()
297279
clf.partial_fit(X2, y2, classes=np.unique(y2))
298-
assert_raises(ValueError, clf.partial_fit, X2, y2,
299-
classes=np.arange(42))
300-
301-
# check consistency of input shape for partial_fit
302-
assert_raises(ValueError, clf.partial_fit, X2[:, :-1], y2)
303-
304-
# check consistency of input shape for predict
305-
assert_raises(ValueError, clf.predict, X2[:, :-1])
280+
with pytest.raises(
281+
ValueError,
282+
match="is not the same as on last call to partial_fit"
283+
):
284+
clf.partial_fit(X2, y2, classes=np.arange(42))
306285

307286

308287
# TODO: Remove in version 1.1
@@ -725,11 +704,6 @@ def test_categoricalnb():
725704
assert_raise_message(ValueError, error_msg, clf.predict, X)
726705
assert_raise_message(ValueError, error_msg, clf.fit, X, y)
727706

728-
# Check error is raised for incorrect X
729-
X = np.array([[1, 4, 1], [2, 5, 6]])
730-
msg = "Expected input with 2 features, got 3 instead"
731-
assert_raise_message(ValueError, msg, clf.predict, X)
732-
733707
# Test alpha
734708
X3_test = np.array([[2, 5]])
735709
# alpha=1 increases the count of all categories by one so the final
@@ -941,3 +915,16 @@ def test_check_accuracy_on_digits():
941915

942916
scores = cross_val_score(GaussianNB(), X_3v8, y_3v8, cv=10)
943917
assert scores.mean() > 0.86
918+
919+
920+
# FIXME: remove in 1.2
921+
@pytest.mark.parametrize("Estimator", DISCRETE_NAIVE_BAYES_CLASSES)
922+
def test_n_features_deprecation(Estimator):
923+
# Check that we raise the proper deprecation warning if accessing
924+
# `n_features_`.
925+
X = np.array([[1, 2], [3, 4]])
926+
y = np.array([1, 0])
927+
est = Estimator().fit(X, y)
928+
929+
with pytest.warns(FutureWarning, match="n_features_ was deprecated"):
930+
est.n_features_

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.