Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 054d156

Browse filesBrowse files
authored
ENH Checks n_features_in in neighbors (#18744)
1 parent 74f20ae commit 054d156
Copy full SHA for 054d156

File tree

6 files changed

+10
-13
lines changed
Filter options

6 files changed

+10
-13
lines changed

‎sklearn/neighbors/_base.py

Copy file name to clipboardExpand all lines: sklearn/neighbors/_base.py
+2-2Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -667,7 +667,7 @@ class from an array representing our data set and ask who's
667667
if self.effective_metric_ == 'precomputed':
668668
X = _check_precomputed(X)
669669
else:
670-
X = check_array(X, accept_sparse='csr')
670+
X = self._validate_data(X, accept_sparse='csr', reset=False)
671671
else:
672672
query_is_train = True
673673
X = self._fit_X
@@ -982,7 +982,7 @@ class from an array representing our data set and ask who's
982982
if self.effective_metric_ == 'precomputed':
983983
X = _check_precomputed(X)
984984
else:
985-
X = check_array(X, accept_sparse='csr')
985+
X = self._validate_data(X, accept_sparse='csr', reset=False)
986986
else:
987987
query_is_train = True
988988
X = self._fit_X

‎sklearn/neighbors/_classification.py

Copy file name to clipboardExpand all lines: sklearn/neighbors/_classification.py
+3-4Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from ._base import _check_weights, _get_weights
1818
from ._base import NeighborsBase, KNeighborsMixin, RadiusNeighborsMixin
1919
from ..base import ClassifierMixin
20-
from ..utils import check_array
2120
from ..utils.validation import _deprecate_positional_args
2221

2322

@@ -192,7 +191,7 @@ def predict(self, X):
192191
y : ndarray of shape (n_queries,) or (n_queries, n_outputs)
193192
Class labels for each data sample.
194193
"""
195-
X = check_array(X, accept_sparse='csr')
194+
X = self._validate_data(X, accept_sparse='csr', reset=False)
196195

197196
neigh_dist, neigh_ind = self.kneighbors(X)
198197
classes_ = self.classes_
@@ -236,7 +235,7 @@ def predict_proba(self, X):
236235
The class probabilities of the input samples. Classes are ordered
237236
by lexicographic order.
238237
"""
239-
X = check_array(X, accept_sparse='csr')
238+
X = self._validate_data(X, accept_sparse='csr', reset=False)
240239

241240
neigh_dist, neigh_ind = self.kneighbors(X)
242241

@@ -545,7 +544,7 @@ def predict_proba(self, X):
545544
by lexicographic order.
546545
"""
547546

548-
X = check_array(X, accept_sparse='csr')
547+
X = self._validate_data(X, accept_sparse='csr', reset=False)
549548
n_queries = _num_samples(X)
550549

551550
neigh_dist, neigh_ind = self.radius_neighbors(X)

‎sklearn/neighbors/_nca.py

Copy file name to clipboardExpand all lines: sklearn/neighbors/_nca.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def transform(self, X):
263263
"""
264264

265265
check_is_fitted(self)
266-
X = check_array(X)
266+
X = self._validate_data(X, reset=False)
267267

268268
return np.dot(X, self.components_.T)
269269

‎sklearn/neighbors/_nearest_centroid.py

Copy file name to clipboardExpand all lines: sklearn/neighbors/_nearest_centroid.py
+2-2Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ..base import BaseEstimator, ClassifierMixin
1616
from ..metrics.pairwise import pairwise_distances
1717
from ..preprocessing import LabelEncoder
18-
from ..utils.validation import check_array, check_is_fitted
18+
from ..utils.validation import check_is_fitted
1919
from ..utils.validation import _deprecate_positional_args
2020
from ..utils.sparsefuncs import csc_median_axis_0
2121
from ..utils.multiclass import check_classification_targets
@@ -201,6 +201,6 @@ def predict(self, X):
201201
"""
202202
check_is_fitted(self)
203203

204-
X = check_array(X, accept_sparse='csr')
204+
X = self._validate_data(X, accept_sparse='csr', reset=False)
205205
return self.classes_[pairwise_distances(
206206
X, self.centroids_, metric=self.metric).argmin(axis=1)]

‎sklearn/neighbors/_regression.py

Copy file name to clipboardExpand all lines: sklearn/neighbors/_regression.py
+2-3Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from ._base import _get_weights, _check_weights
1818
from ._base import NeighborsBase, KNeighborsMixin, RadiusNeighborsMixin
1919
from ..base import RegressorMixin
20-
from ..utils import check_array
2120
from ..utils.validation import _deprecate_positional_args
2221
from ..utils.deprecation import deprecated
2322

@@ -203,7 +202,7 @@ def predict(self, X):
203202
y : ndarray of shape (n_queries,) or (n_queries, n_outputs), dtype=int
204203
Target values.
205204
"""
206-
X = check_array(X, accept_sparse='csr')
205+
X = self._validate_data(X, accept_sparse='csr', reset=False)
207206

208207
neigh_dist, neigh_ind = self.kneighbors(X)
209208

@@ -392,7 +391,7 @@ def predict(self, X):
392391
dtype=double
393392
Target values.
394393
"""
395-
X = check_array(X, accept_sparse='csr')
394+
X = self._validate_data(X, accept_sparse='csr', reset=False)
396395

397396
neigh_dist, neigh_ind = self.radius_neighbors(X)
398397

‎sklearn/tests/test_common.py

Copy file name to clipboardExpand all lines: sklearn/tests/test_common.py
-1Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,6 @@ def test_search_cv(estimator, check, request):
279279
'multiclass',
280280
'multioutput',
281281
'naive_bayes',
282-
'neighbors',
283282
'pipeline',
284283
'random_projection',
285284
}

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.