Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit bcc6430

Browse files
authored
ENH Make KNeighborsClassifier.predict handle X=None (scikit-learn#30047)
1 parent c08b433 commit bcc6430
Copy full SHA for bcc6430

File tree

Expand file treeCollapse file tree

4 files changed

+143
-16
lines changed
Filter options
Expand file treeCollapse file tree

4 files changed

+143
-16
lines changed
+6Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
- Make `predict`, `predict_proba`, and `score` of
2+
:class:`neighbors.KNeighborsClassifier` and
3+
:class:`neighbors.RadiusNeighborsClassifier` accept `X=None` as input. In this case
4+
predictions for all training set points are returned, and points are not included
5+
among their own neighbors.
6+
:pr:`30047` by :user:`Dmitry Kobak <dkobak>`.

‎sklearn/neighbors/_classification.py

Copy file name to clipboardExpand all lines: sklearn/neighbors/_classification.py
+85-11Lines changed: 85 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,10 @@ def predict(self, X):
244244
Parameters
245245
----------
246246
X : {array-like, sparse matrix} of shape (n_queries, n_features), \
247-
or (n_queries, n_indexed) if metric == 'precomputed'
248-
Test samples.
247+
or (n_queries, n_indexed) if metric == 'precomputed', or None
248+
Test samples. If `None`, predictions for all indexed points are
249+
returned; in this case, points are not considered their own
250+
neighbors.
249251
250252
Returns
251253
-------
@@ -281,7 +283,7 @@ def predict(self, X):
281283
classes_ = [self.classes_]
282284

283285
n_outputs = len(classes_)
284-
n_queries = _num_samples(X)
286+
n_queries = _num_samples(self._fit_X if X is None else X)
285287
weights = _get_weights(neigh_dist, self.weights)
286288
if weights is not None and _all_with_any_reduction_axis_1(weights, value=0):
287289
raise ValueError(
@@ -311,8 +313,10 @@ def predict_proba(self, X):
311313
Parameters
312314
----------
313315
X : {array-like, sparse matrix} of shape (n_queries, n_features), \
314-
or (n_queries, n_indexed) if metric == 'precomputed'
315-
Test samples.
316+
or (n_queries, n_indexed) if metric == 'precomputed', or None
317+
Test samples. If `None`, predictions for all indexed points are
318+
returned; in this case, points are not considered their own
319+
neighbors.
316320
317321
Returns
318322
-------
@@ -375,7 +379,7 @@ def predict_proba(self, X):
375379
_y = self._y.reshape((-1, 1))
376380
classes_ = [self.classes_]
377381

378-
n_queries = _num_samples(X)
382+
n_queries = _num_samples(self._fit_X if X is None else X)
379383

380384
weights = _get_weights(neigh_dist, self.weights)
381385
if weights is None:
@@ -408,6 +412,39 @@ def predict_proba(self, X):
408412

409413
return probabilities
410414

415+
# This function is defined here only to modify the parent docstring
416+
# and add information about X=None
417+
def score(self, X, y, sample_weight=None):
418+
"""
419+
Return the mean accuracy on the given test data and labels.
420+
421+
In multi-label classification, this is the subset accuracy
422+
which is a harsh metric since you require for each sample that
423+
each label set be correctly predicted.
424+
425+
Parameters
426+
----------
427+
X : array-like of shape (n_samples, n_features), or None
428+
Test samples. If `None`, predictions for all indexed points are
429+
used; in this case, points are not considered their own
430+
neighbors. This means that `knn.fit(X, y).score(None, y)`
431+
implicitly performs a leave-one-out cross-validation procedure
432+
and is equivalent to `cross_val_score(knn, X, y, cv=LeaveOneOut()).mean()`
433+
but typically much faster.
434+
435+
y : array-like of shape (n_samples,) or (n_samples, n_outputs)
436+
True labels for `X`.
437+
438+
sample_weight : array-like of shape (n_samples,), default=None
439+
Sample weights.
440+
441+
Returns
442+
-------
443+
score : float
444+
Mean accuracy of ``self.predict(X)`` w.r.t. `y`.
445+
"""
446+
return super().score(X, y, sample_weight)
447+
411448
def __sklearn_tags__(self):
412449
tags = super().__sklearn_tags__()
413450
tags.classifier_tags.multi_label = True
@@ -692,8 +729,10 @@ def predict(self, X):
692729
Parameters
693730
----------
694731
X : {array-like, sparse matrix} of shape (n_queries, n_features), \
695-
or (n_queries, n_indexed) if metric == 'precomputed'
696-
Test samples.
732+
or (n_queries, n_indexed) if metric == 'precomputed', or None
733+
Test samples. If `None`, predictions for all indexed points are
734+
returned; in this case, points are not considered their own
735+
neighbors.
697736
698737
Returns
699738
-------
@@ -734,8 +773,10 @@ def predict_proba(self, X):
734773
Parameters
735774
----------
736775
X : {array-like, sparse matrix} of shape (n_queries, n_features), \
737-
or (n_queries, n_indexed) if metric == 'precomputed'
738-
Test samples.
776+
or (n_queries, n_indexed) if metric == 'precomputed', or None
777+
Test samples. If `None`, predictions for all indexed points are
778+
returned; in this case, points are not considered their own
779+
neighbors.
739780
740781
Returns
741782
-------
@@ -745,7 +786,7 @@ def predict_proba(self, X):
745786
by lexicographic order.
746787
"""
747788
check_is_fitted(self, "_fit_method")
748-
n_queries = _num_samples(X)
789+
n_queries = _num_samples(self._fit_X if X is None else X)
749790

750791
metric, metric_kwargs = _adjusted_metric(
751792
metric=self.metric, metric_kwargs=self.metric_params, p=self.p
@@ -846,6 +887,39 @@ def predict_proba(self, X):
846887

847888
return probabilities
848889

890+
# This function is defined here only to modify the parent docstring
891+
# and add information about X=None
892+
def score(self, X, y, sample_weight=None):
893+
"""
894+
Return the mean accuracy on the given test data and labels.
895+
896+
In multi-label classification, this is the subset accuracy
897+
which is a harsh metric since you require for each sample that
898+
each label set be correctly predicted.
899+
900+
Parameters
901+
----------
902+
X : array-like of shape (n_samples, n_features), or None
903+
Test samples. If `None`, predictions for all indexed points are
904+
used; in this case, points are not considered their own
905+
neighbors. This means that `knn.fit(X, y).score(None, y)`
906+
implicitly performs a leave-one-out cross-validation procedure
907+
and is equivalent to `cross_val_score(knn, X, y, cv=LeaveOneOut()).mean()`
908+
but typically much faster.
909+
910+
y : array-like of shape (n_samples,) or (n_samples, n_outputs)
911+
True labels for `X`.
912+
913+
sample_weight : array-like of shape (n_samples,), default=None
914+
Sample weights.
915+
916+
Returns
917+
-------
918+
score : float
919+
Mean accuracy of ``self.predict(X)`` w.r.t. `y`.
920+
"""
921+
return super().score(X, y, sample_weight)
922+
849923
def __sklearn_tags__(self):
850924
tags = super().__sklearn_tags__()
851925
tags.classifier_tags.multi_label = True

‎sklearn/neighbors/_regression.py

Copy file name to clipboardExpand all lines: sklearn/neighbors/_regression.py
+8-4Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,10 @@ def predict(self, X):
234234
Parameters
235235
----------
236236
X : {array-like, sparse matrix} of shape (n_queries, n_features), \
237-
or (n_queries, n_indexed) if metric == 'precomputed'
238-
Test samples.
237+
or (n_queries, n_indexed) if metric == 'precomputed', or None
238+
Test samples. If `None`, predictions for all indexed points are
239+
returned; in this case, points are not considered their own
240+
neighbors.
239241
240242
Returns
241243
-------
@@ -464,8 +466,10 @@ def predict(self, X):
464466
Parameters
465467
----------
466468
X : {array-like, sparse matrix} of shape (n_queries, n_features), \
467-
or (n_queries, n_indexed) if metric == 'precomputed'
468-
Test samples.
469+
or (n_queries, n_indexed) if metric == 'precomputed', or None
470+
Test samples. If `None`, predictions for all indexed points are
471+
returned; in this case, points are not considered their own
472+
neighbors.
469473
470474
Returns
471475
-------

‎sklearn/neighbors/tests/test_neighbors.py

Copy file name to clipboardExpand all lines: sklearn/neighbors/tests/test_neighbors.py
+44-1Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@
2424
assert_compatible_argkmin_results,
2525
assert_compatible_radius_results,
2626
)
27-
from sklearn.model_selection import cross_val_score, train_test_split
27+
from sklearn.model_selection import (
28+
LeaveOneOut,
29+
cross_val_predict,
30+
cross_val_score,
31+
train_test_split,
32+
)
2833
from sklearn.neighbors import (
2934
VALID_METRICS_SPARSE,
3035
KNeighborsRegressor,
@@ -2390,3 +2395,41 @@ def _weights(dist):
23902395

23912396
with pytest.raises(ValueError, match=msg):
23922397
est.predict_proba([[1.1, 1.1]])
2398+
2399+
2400+
@pytest.mark.parametrize(
2401+
"nn_model",
2402+
[
2403+
neighbors.KNeighborsClassifier(n_neighbors=10),
2404+
neighbors.RadiusNeighborsClassifier(radius=5.0),
2405+
],
2406+
)
2407+
def test_neighbor_classifiers_loocv(nn_model):
2408+
"""Check that `predict` and related functions work fine with X=None"""
2409+
X, y = datasets.make_blobs(n_samples=500, centers=5, n_features=2, random_state=0)
2410+
2411+
loocv = cross_val_score(nn_model, X, y, cv=LeaveOneOut())
2412+
nn_model.fit(X, y)
2413+
2414+
assert np.all(loocv == (nn_model.predict(None) == y))
2415+
assert np.mean(loocv) == nn_model.score(None, y)
2416+
assert nn_model.score(None, y) < nn_model.score(X, y)
2417+
2418+
2419+
@pytest.mark.parametrize(
2420+
"nn_model",
2421+
[
2422+
neighbors.KNeighborsRegressor(n_neighbors=10),
2423+
neighbors.RadiusNeighborsRegressor(radius=0.5),
2424+
],
2425+
)
2426+
def test_neighbor_regressors_loocv(nn_model):
2427+
"""Check that `predict` and related functions work fine with X=None"""
2428+
X, y = datasets.load_diabetes(return_X_y=True)
2429+
2430+
# Only checking cross_val_predict and not cross_val_score because
2431+
# cross_val_score does not work with LeaveOneOut() for a regressor (the default R^2 score is undefined on single-sample test folds)
2432+
loocv = cross_val_predict(nn_model, X, y, cv=LeaveOneOut())
2433+
nn_model.fit(X, y)
2434+
2435+
assert np.all(loocv == nn_model.predict(None))

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.