Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 5946f8b

Browse files
thomasjpfan and ogrisel
authored
ENH Adds n_features_in_ checks to linear and svm modules (#18578)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Olivier Grisel <olivier.grisel@gmail.com>
1 parent 8f72c2a commit 5946f8b
Copy full SHA for 5946f8b

File tree

5 files changed

+21
-32
lines changed
Filter options

5 files changed

+21
-32
lines changed

‎sklearn/linear_model/_base.py

Copy file name to clipboard. Expand all lines: sklearn/linear_model/_base.py
+3 −8 — Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,8 @@ def fit(self, X, y):
217217
def _decision_function(self, X):
218218
check_is_fitted(self)
219219

220-
X = check_array(X, accept_sparse=['csr', 'csc', 'coo'])
220+
X = self._validate_data(X, accept_sparse=['csr', 'csc', 'coo'],
221+
reset=False)
221222
return safe_sparse_dot(X, self.coef_.T,
222223
dense_output=True) + self.intercept_
223224

@@ -281,13 +282,7 @@ class would be predicted.
281282
"""
282283
check_is_fitted(self)
283284

284-
X = check_array(X, accept_sparse='csr')
285-
286-
n_features = self.coef_.shape[1]
287-
if X.shape[1] != n_features:
288-
raise ValueError("X has %d features per sample; expecting %d"
289-
% (X.shape[1], n_features))
290-
285+
X = self._validate_data(X, accept_sparse='csr', reset=False)
291286
scores = safe_sparse_dot(X, self.coef_.T,
292287
dense_output=True) + self.intercept_
293288
return scores.ravel() if scores.shape[1] == 1 else scores

‎sklearn/linear_model/_glm/glm.py

Copy file name to clipboard. Expand all lines: sklearn/linear_model/_glm/glm.py
+6 −7 — Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import scipy.optimize
1313

1414
from ...base import BaseEstimator, RegressorMixin
15-
from ...utils import check_array, check_X_y
1615
from ...utils.optimize import _check_optimize_result
1716
from ...utils.validation import check_is_fitted, _check_sample_weight
1817
from ..._loss.glm_distribution import (
@@ -221,9 +220,9 @@ def fit(self, X, y, sample_weight=None):
221220
family = self._family_instance
222221
link = self._link_instance
223222

224-
X, y = check_X_y(X, y, accept_sparse=['csc', 'csr'],
225-
dtype=[np.float64, np.float32],
226-
y_numeric=True, multi_output=False)
223+
X, y = self._validate_data(X, y, accept_sparse=['csc', 'csr'],
224+
dtype=[np.float64, np.float32],
225+
y_numeric=True, multi_output=False)
227226

228227
weights = _check_sample_weight(sample_weight, X)
229228

@@ -311,9 +310,9 @@ def _linear_predictor(self, X):
311310
Returns predicted values of linear predictor.
312311
"""
313312
check_is_fitted(self)
314-
X = check_array(X, accept_sparse=['csr', 'csc', 'coo'],
315-
dtype=[np.float64, np.float32], ensure_2d=True,
316-
allow_nd=False)
313+
X = self._validate_data(X, accept_sparse=['csr', 'csc', 'coo'],
314+
dtype=[np.float64, np.float32], ensure_2d=True,
315+
allow_nd=False, reset=False)
317316
return X @ self.coef_ + self.intercept_
318317

319318
def predict(self, X):

‎sklearn/linear_model/_stochastic_gradient.py

Copy file name to clipboard. Expand all lines: sklearn/linear_model/_stochastic_gradient.py
+9 −9 — Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ._base import LinearClassifierMixin, SparseCoefMixin
1616
from ._base import make_dataset
1717
from ..base import BaseEstimator, RegressorMixin
18-
from ..utils import check_array, check_random_state, check_X_y
18+
from ..utils import check_random_state
1919
from ..utils.extmath import safe_sparse_dot
2020
from ..utils.multiclass import _check_partial_fit_first_call
2121
from ..utils.validation import check_is_fitted, _check_sample_weight
@@ -491,8 +491,10 @@ def _partial_fit(self, X, y, alpha, C,
491491
loss, learning_rate, max_iter,
492492
classes, sample_weight,
493493
coef_init, intercept_init):
494-
X, y = check_X_y(X, y, accept_sparse='csr', dtype=np.float64,
495-
order="C", accept_large_sparse=False)
494+
first_call = not hasattr(self, "classes_")
495+
X, y = self._validate_data(X, y, accept_sparse='csr', dtype=np.float64,
496+
order="C", accept_large_sparse=False,
497+
reset=first_call)
496498

497499
n_samples, n_features = X.shape
498500

@@ -1138,22 +1140,20 @@ def __init__(self, loss="squared_loss", *, penalty="l2", alpha=0.0001,
11381140

11391141
def _partial_fit(self, X, y, alpha, C, loss, learning_rate,
11401142
max_iter, sample_weight, coef_init, intercept_init):
1143+
first_call = getattr(self, "coef_", None) is None
11411144
X, y = self._validate_data(X, y, accept_sparse="csr", copy=False,
11421145
order='C', dtype=np.float64,
1143-
accept_large_sparse=False)
1146+
accept_large_sparse=False, reset=first_call)
11441147
y = y.astype(np.float64, copy=False)
11451148

11461149
n_samples, n_features = X.shape
11471150

11481151
sample_weight = _check_sample_weight(sample_weight, X)
11491152

11501153
# Allocate datastructures from input arguments
1151-
if getattr(self, "coef_", None) is None:
1154+
if first_call:
11521155
self._allocate_parameter_mem(1, n_features, coef_init,
11531156
intercept_init)
1154-
elif n_features != self.coef_.shape[-1]:
1155-
raise ValueError("Number of features %d does not match previous "
1156-
"data %d." % (n_features, self.coef_.shape[-1]))
11571157
if self.average > 0 and getattr(self, "_average_coef", None) is None:
11581158
self._average_coef = np.zeros(n_features,
11591159
dtype=np.float64,
@@ -1269,7 +1269,7 @@ def _decision_function(self, X):
12691269
"""
12701270
check_is_fitted(self)
12711271

1272-
X = check_array(X, accept_sparse='csr')
1272+
X = self._validate_data(X, accept_sparse='csr', reset=False)
12731273

12741274
scores = safe_sparse_dot(X, self.coef_.T,
12751275
dense_output=True) + self.intercept_

‎sklearn/svm/_base.py

Copy file name to clipboard. Expand all lines: sklearn/svm/_base.py
+3 −6 — Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -471,8 +471,9 @@ def _validate_for_predict(self, X):
471471
check_is_fitted(self)
472472

473473
if not callable(self.kernel):
474-
X = check_array(X, accept_sparse='csr', dtype=np.float64,
475-
order="C", accept_large_sparse=False)
474+
X = self._validate_data(X, accept_sparse='csr', dtype=np.float64,
475+
order="C", accept_large_sparse=False,
476+
reset=False)
476477

477478
if self._sparse and not sp.isspmatrix(X):
478479
X = sp.csr_matrix(X)
@@ -489,10 +490,6 @@ def _validate_for_predict(self, X):
489490
raise ValueError("X.shape[1] = %d should be equal to %d, "
490491
"the number of samples at training time" %
491492
(X.shape[1], self.shape_fit_[0]))
492-
elif not callable(self.kernel) and X.shape[1] != self.shape_fit_[1]:
493-
raise ValueError("X.shape[1] = %d should be equal to %d, "
494-
"the number of features at training time" %
495-
(X.shape[1], self.shape_fit_[1]))
496493
return X
497494

498495
@property

‎sklearn/tests/test_common.py

Copy file name to clipboard. Expand all lines: sklearn/tests/test_common.py
−2 — Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,6 @@ def test_search_cv(estimator, check, request):
273273
'feature_extraction',
274274
'feature_selection',
275275
'isotonic',
276-
'linear_model',
277276
'manifold',
278277
'mixture',
279278
'model_selection',
@@ -284,7 +283,6 @@ def test_search_cv(estimator, check, request):
284283
'pipeline',
285284
'random_projection',
286285
'semi_supervised',
287-
'svm',
288286
}
289287

290288
N_FEATURES_IN_AFTER_FIT_ESTIMATORS = [

0 commit comments

Comments: 0
Morty Proxy This is a proxified and sanitized view of the page, visit original site.