Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 5e5a46a

Browse files
thomasjpfan authored and jayzed82 committed
ENH Uses _validate_data in other methods in the neural_network module (scikit-learn#18514)
1 parent 47ae20e commit 5e5a46a
Copy full SHA for 5e5a46a

File tree

Expand file treeCollapse file tree

6 files changed

+151
-16
lines changed
Filter options
Expand file treeCollapse file tree

6 files changed

+151
-16
lines changed

‎sklearn/base.py

Copy file name to clipboardExpand all lines: sklearn/base.py
+23-4Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,10 @@ def _check_n_features(self, X, reset):
360360
If True, the `n_features_in_` attribute is set to `X.shape[1]`.
361361
Else, the attribute must already exist and the function checks
362362
that it is equal to `X.shape[1]`.
363+
.. note::
364+
It is recommended to pass reset=True in `fit` and in the first
365+
call to `partial_fit`. All other methods that validate `X`
366+
should set `reset=False`.
363367
"""
364368
n_features = X.shape[1]
365369

@@ -378,7 +382,7 @@ def _check_n_features(self, X, reset):
378382
self.n_features_in_)
379383
)
380384

381-
def _validate_data(self, X, y=None, reset=True,
385+
def _validate_data(self, X, y='no_validation', reset=True,
382386
validate_separately=False, **check_params):
383387
"""Validate input data and set or check the `n_features_in_` attribute.
384388
@@ -387,13 +391,25 @@ def _validate_data(self, X, y=None, reset=True,
387391
X : {array-like, sparse matrix, dataframe} of shape \
388392
(n_samples, n_features)
389393
The input samples.
390-
y : array-like of shape (n_samples,), default=None
391-
The targets. If None, `check_array` is called on `X` and
392-
`check_X_y` is called otherwise.
394+
y : array-like of shape (n_samples,), default='no_validation'
395+
The targets.
396+
397+
- If `None`, `check_array` is called on `X`. If the estimator's
398+
requires_y tag is True, then an error will be raised.
399+
- If `'no_validation'`, `check_array` is called on `X` and the
400+
estimator's requires_y tag is ignored. This is a default
401+
placeholder and is never meant to be explicitly set.
402+
- Otherwise, both `X` and `y` are checked with either `check_array`
403+
or `check_X_y` depending on `validate_separately`.
404+
393405
reset : bool, default=True
394406
Whether to reset the `n_features_in_` attribute.
395407
If False, the input will be checked for consistency with data
396408
provided when reset was last True.
409+
.. note::
410+
It is recommended to pass reset=True in `fit` and in the first
411+
call to `partial_fit`. All other methods that validate `X`
412+
should set `reset=False`.
397413
validate_separately : False or tuple of dicts, default=False
398414
Only used if y is neither None nor 'no_validation'.
399415
If False, call check_X_y(). Else, it must be a tuple of kwargs
@@ -417,6 +433,9 @@ def _validate_data(self, X, y=None, reset=True,
417433
)
418434
X = check_array(X, **check_params)
419435
out = X
436+
elif isinstance(y, str) and y == 'no_validation':
437+
X = check_array(X, **check_params)
438+
out = X
420439
else:
421440
if validate_separately:
422441
# We need this because some estimators validate X and y

‎sklearn/neural_network/_multilayer_perceptron.py

Copy file name to clipboardExpand all lines: sklearn/neural_network/_multilayer_perceptron.py
+12-9Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from ..utils import gen_batches, check_random_state
2323
from ..utils import shuffle
2424
from ..utils import _safe_indexing
25-
from ..utils import check_array, column_or_1d
25+
from ..utils import column_or_1d
2626
from ..exceptions import ConvergenceWarning
2727
from ..utils.extmath import safe_sparse_dot
2828
from ..utils.validation import check_is_fitted, _deprecate_positional_args
@@ -131,7 +131,7 @@ def _forward_pass_fast(self, X):
131131
y_pred : ndarray of shape (n_samples,) or (n_samples, n_outputs)
132132
The decision function of the samples for each class in the model.
133133
"""
134-
X = check_array(X, accept_sparse=['csr', 'csc'])
134+
X = self._validate_data(X, accept_sparse=['csr', 'csc'], reset=False)
135135

136136
# Initialize first layer
137137
activation = X
@@ -358,8 +358,10 @@ def _fit(self, X, y, incremental=False):
358358
if np.any(np.array(hidden_layer_sizes) <= 0):
359359
raise ValueError("hidden_layer_sizes must be > 0, got %s." %
360360
hidden_layer_sizes)
361+
first_pass = (not hasattr(self, 'coefs_') or
362+
(not self.warm_start and not incremental))
361363

362-
X, y = self._validate_input(X, y, incremental)
364+
X, y = self._validate_input(X, y, incremental, reset=first_pass)
363365

364366
n_samples, n_features = X.shape
365367

@@ -375,8 +377,7 @@ def _fit(self, X, y, incremental=False):
375377
# check random state
376378
self._random_state = check_random_state(self.random_state)
377379

378-
if not hasattr(self, 'coefs_') or (not self.warm_start and not
379-
incremental):
380+
if first_pass:
380381
# First time training the model
381382
self._initialize(y, layer_units, X.dtype)
382383

@@ -963,10 +964,11 @@ def __init__(self, hidden_layer_sizes=(100,), activation="relu", *,
963964
beta_1=beta_1, beta_2=beta_2, epsilon=epsilon,
964965
n_iter_no_change=n_iter_no_change, max_fun=max_fun)
965966

966-
def _validate_input(self, X, y, incremental):
967+
def _validate_input(self, X, y, incremental, reset):
967968
X, y = self._validate_data(X, y, accept_sparse=['csr', 'csc'],
968969
multi_output=True,
969-
dtype=(np.float64, np.float32))
970+
dtype=(np.float64, np.float32),
971+
reset=reset)
970972
if y.ndim == 2 and y.shape[1] == 1:
971973
y = column_or_1d(y, warn=True)
972974

@@ -1409,10 +1411,11 @@ def predict(self, X):
14091411
return y_pred.ravel()
14101412
return y_pred
14111413

1412-
def _validate_input(self, X, y, incremental):
1414+
def _validate_input(self, X, y, incremental, reset):
14131415
X, y = self._validate_data(X, y, accept_sparse=['csr', 'csc'],
14141416
multi_output=True, y_numeric=True,
1415-
dtype=(np.float64, np.float32))
1417+
dtype=(np.float64, np.float32),
1418+
reset=reset)
14161419
if y.ndim == 2 and y.shape[1] == 1:
14171420
y = column_or_1d(y, warn=True)
14181421
return X, y

‎sklearn/neural_network/_rbm.py

Copy file name to clipboardExpand all lines: sklearn/neural_network/_rbm.py
+5-2Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,8 @@ def transform(self, X):
131131
"""
132132
check_is_fitted(self)
133133

134-
X = check_array(X, accept_sparse='csr', dtype=(np.float64, np.float32))
134+
X = self._validate_data(X, accept_sparse='csr', reset=False,
135+
dtype=(np.float64, np.float32))
135136
return self._mean_hiddens(X)
136137

137138
def _mean_hiddens(self, v):
@@ -243,7 +244,9 @@ def partial_fit(self, X, y=None):
243244
self : BernoulliRBM
244245
The fitted model.
245246
"""
246-
X = check_array(X, accept_sparse='csr', dtype=np.float64)
247+
first_pass = not hasattr(self, 'components_')
248+
X = self._validate_data(X, accept_sparse='csr', dtype=np.float64,
249+
reset=first_pass)
247250
if not hasattr(self, 'random_state_'):
248251
self.random_state_ = check_random_state(self.random_state)
249252
if not hasattr(self, 'components_'):

‎sklearn/neural_network/tests/test_mlp.py

Copy file name to clipboardExpand all lines: sklearn/neural_network/tests/test_mlp.py
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def test_fit():
9494
mlp.intercepts_[1] = np.array([1.0])
9595
mlp._coef_grads = [] * 2
9696
mlp._intercept_grads = [] * 2
97+
mlp.n_features_in_ = 3
9798

9899
# Initialize parameters
99100
mlp.n_iter_ = 0

‎sklearn/tests/test_common.py

Copy file name to clipboardExpand all lines: sklearn/tests/test_common.py
+56-1Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@
3737
_set_checking_parameters,
3838
_get_check_estimator_ids,
3939
check_class_weight_balanced_linear_classifier,
40-
parametrize_with_checks)
40+
parametrize_with_checks,
41+
check_n_features_in_after_fitting)
4142

4243

4344
def test_all_estimator_no_base_class():
@@ -270,3 +271,57 @@ def test_strict_mode_check_estimator():
270271
def test_strict_mode_parametrize_with_checks(estimator, check):
271272
# Ideally we should assert that the strict checks are Xfailed...
272273
check(estimator)
274+
275+
276+
# TODO: When more modules get added, we can remove it from this list to make
277+
# sure it gets tested. After we finish each module we can move the checks
278+
# into sklearn.utils.estimator_checks.check_n_features_in.
279+
#
280+
# check_estimators_partial_fit_n_features can either be removed or updated
281+
# with two more assertions:
282+
# 1. `n_features_in_` is set during the first call to `partial_fit`.
283+
# 2. More strict when it comes to the error message.
284+
#
285+
# check_classifiers_train would need to be updated with the error message
286+
N_FEATURES_IN_AFTER_FIT_MODULES_TO_IGNORE = {
287+
'calibration',
288+
'cluster',
289+
'compose',
290+
'covariance',
291+
'cross_decomposition',
292+
'decomposition',
293+
'discriminant_analysis',
294+
'ensemble',
295+
'feature_extraction',
296+
'feature_selection',
297+
'gaussian_process',
298+
'impute',
299+
'isotonic',
300+
'kernel_approximation',
301+
'kernel_ridge',
302+
'linear_model',
303+
'manifold',
304+
'mixture',
305+
'model_selection',
306+
'multiclass',
307+
'multioutput',
308+
'naive_bayes',
309+
'neighbors',
310+
'pipeline',
311+
'preprocessing',
312+
'random_projection',
313+
'semi_supervised',
314+
'svm',
315+
'tree',
316+
}
317+
318+
N_FEATURES_IN_AFTER_FIT_ESTIMATORS = [
319+
est for est in _tested_estimators() if est.__module__.split('.')[1] not in
320+
N_FEATURES_IN_AFTER_FIT_MODULES_TO_IGNORE
321+
]
322+
323+
324+
@pytest.mark.parametrize("estimator", N_FEATURES_IN_AFTER_FIT_ESTIMATORS,
325+
ids=_get_check_estimator_ids)
326+
def test_check_n_features_in_after_fitting(estimator):
327+
check_n_features_in_after_fitting(estimator.__class__.__name__, estimator)

‎sklearn/utils/estimator_checks.py

Copy file name to clipboardExpand all lines: sklearn/utils/estimator_checks.py
+54Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3111,6 +3111,60 @@ def check_requires_y_none(name, estimator_orig, strict_mode=True):
31113111
warnings.warn(warning_msg, FutureWarning)
31123112

31133113

3114+
def check_n_features_in_after_fitting(name, estimator_orig, strict_mode=True):
3115+
# Make sure that n_features_in_ is checked after fitting
3116+
tags = estimator_orig._get_tags()
3117+
3118+
if "2darray" not in tags["X_types"] or tags["no_validation"]:
3119+
return
3120+
3121+
rng = np.random.RandomState(0)
3122+
3123+
estimator = clone(estimator_orig)
3124+
set_random_state(estimator)
3125+
if 'warm_start' in estimator.get_params():
3126+
estimator.set_params(warm_start=False)
3127+
3128+
n_samples = 100
3129+
X = rng.normal(loc=100, size=(n_samples, 2))
3130+
X = _pairwise_estimator_convert_X(X, estimator)
3131+
if is_regressor(estimator):
3132+
y = rng.normal(size=n_samples)
3133+
else:
3134+
y = rng.randint(low=0, high=2, size=n_samples)
3135+
y = _enforce_estimator_tags_y(estimator, y)
3136+
3137+
estimator.fit(X, y)
3138+
assert estimator.n_features_in_ == X.shape[1]
3139+
3140+
# check methods will check n_features_in_
3141+
check_methods = ["predict", "transform", "decision_function",
3142+
"predict_proba"]
3143+
X_bad = X[:, [1]]
3144+
3145+
msg = (f"X has 1 features, but {name} is expecting {X.shape[1]} "
3146+
"features as input")
3147+
for method in check_methods:
3148+
if not hasattr(estimator, method):
3149+
continue
3150+
with raises(ValueError, match=msg):
3151+
getattr(estimator, method)(X_bad)
3152+
3153+
# partial_fit will check in the second call
3154+
if not hasattr(estimator, "partial_fit"):
3155+
return
3156+
3157+
estimator = clone(estimator_orig)
3158+
if is_classifier(estimator):
3159+
estimator.partial_fit(X, y, classes=np.unique(y))
3160+
else:
3161+
estimator.partial_fit(X, y)
3162+
assert estimator.n_features_in_ == X.shape[1]
3163+
3164+
with raises(ValueError, match=msg):
3165+
estimator.partial_fit(X_bad, y)
3166+
3167+
31143168
# set of checks that are completely strict, i.e. they have no non-strict part
31153169
_FULLY_STRICT_CHECKS = set([
31163170
'check_n_features_in',

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.