From 295452da37c631480d44fb4423ea89ee98db0444 Mon Sep 17 00:00:00 2001 From: Dmitry Kobak Date: Fri, 11 Oct 2024 15:48:36 +0200 Subject: [PATCH 1/7] Formatting fix --- sklearn/neighbors/_classification.py | 96 ++++++++++++++++++++++++---- sklearn/neighbors/_regression.py | 12 ++-- 2 files changed, 93 insertions(+), 15 deletions(-) diff --git a/sklearn/neighbors/_classification.py b/sklearn/neighbors/_classification.py index b63381af84602..5f44a0ecca603 100644 --- a/sklearn/neighbors/_classification.py +++ b/sklearn/neighbors/_classification.py @@ -244,8 +244,10 @@ def predict(self, X): Parameters ---------- X : {array-like, sparse matrix} of shape (n_queries, n_features), \ - or (n_queries, n_indexed) if metric == 'precomputed' - Test samples. + or (n_queries, n_indexed) if metric == 'precomputed', or None + Test samples. If `None`, predictions for all indexed points are + returned; in this case, points are not considered their own + neighbors. Returns ------- @@ -281,7 +283,7 @@ def predict(self, X): classes_ = [self.classes_] n_outputs = len(classes_) - n_queries = _num_samples(X) + n_queries = _num_samples(self._fit_X if X is None else X) weights = _get_weights(neigh_dist, self.weights) if weights is not None and _all_with_any_reduction_axis_1(weights, value=0): raise ValueError( @@ -311,8 +313,10 @@ def predict_proba(self, X): Parameters ---------- X : {array-like, sparse matrix} of shape (n_queries, n_features), \ - or (n_queries, n_indexed) if metric == 'precomputed' - Test samples. + or (n_queries, n_indexed) if metric == 'precomputed', or None + Test samples. If `None`, predictions for all indexed points are + returned; in this case, points are not considered their own + neighbors. 
Returns ------- @@ -375,7 +379,7 @@ def predict_proba(self, X): _y = self._y.reshape((-1, 1)) classes_ = [self.classes_] - n_queries = _num_samples(X) + n_queries = _num_samples(self._fit_X if X is None else X) weights = _get_weights(neigh_dist, self.weights) if weights is None: @@ -408,6 +412,39 @@ def predict_proba(self, X): return probabilities + # This function is defined here only to modify the parent docstring + # and add information about X=None + def score(self, X, y, sample_weight=None): + """ + Return the mean accuracy on the given test data and labels. + + In multi-label classification, this is the subset accuracy + which is a harsh metric since you require for each sample that + each label set be correctly predicted. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features), or None + Test samples. If `None`, predictions for all indexed points are + used; in this case, points are not considered their own + neighbors. This means that `knn.fit(X, y).score(None, y)` + implicitly performs a leave-one-out cross-validation procedure + and is equivalent to `cross_val_score(knn, X, y, cv=LeaveOneOut())` + but typically much faster. + + y : array-like of shape (n_samples,) or (n_samples, n_outputs) + True labels for `X`. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. + + Returns + ------- + score : float + Mean accuracy of ``self.predict(X)`` w.r.t. `y`. + """ + return super().score(X, y, sample_weight) + def __sklearn_tags__(self): tags = super().__sklearn_tags__() tags.classifier_tags.multi_label = True @@ -692,8 +729,10 @@ def predict(self, X): Parameters ---------- X : {array-like, sparse matrix} of shape (n_queries, n_features), \ - or (n_queries, n_indexed) if metric == 'precomputed' - Test samples. + or (n_queries, n_indexed) if metric == 'precomputed', or None + Test samples. 
If `None`, predictions for all indexed points are + returned; in this case, points are not considered their own + neighbors. Returns ------- @@ -734,8 +773,10 @@ def predict_proba(self, X): Parameters ---------- X : {array-like, sparse matrix} of shape (n_queries, n_features), \ - or (n_queries, n_indexed) if metric == 'precomputed' - Test samples. + or (n_queries, n_indexed) if metric == 'precomputed', or None + Test samples. If `None`, predictions for all indexed points are + returned; in this case, points are not considered their own + neighbors. Returns ------- @@ -745,7 +786,7 @@ def predict_proba(self, X): by lexicographic order. """ check_is_fitted(self, "_fit_method") - n_queries = _num_samples(X) + n_queries = _num_samples(self._fit_X if X is None else X) metric, metric_kwargs = _adjusted_metric( metric=self.metric, metric_kwargs=self.metric_params, p=self.p @@ -846,6 +887,39 @@ def predict_proba(self, X): return probabilities + # This function is defined here only to modify the parent docstring + # and add information about X=None + def score(self, X, y, sample_weight=None): + """ + Return the mean accuracy on the given test data and labels. + + In multi-label classification, this is the subset accuracy + which is a harsh metric since you require for each sample that + each label set be correctly predicted. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features), or None + Test samples. If `None`, predictions for all indexed points are + used; in this case, points are not considered their own + neighbors. This means that `knn.fit(X, y).score(None, y)` + implicitly performs a leave-one-out cross-validation procedure + and is equivalent to `cross_val_score(knn, X, y, cv=LeaveOneOut())` + but typically much faster. + + y : array-like of shape (n_samples,) or (n_samples, n_outputs) + True labels for `X`. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. 
+ + Returns + ------- + score : float + Mean accuracy of ``self.predict(X)`` w.r.t. `y`. + """ + return super().score(X, y, sample_weight) + def __sklearn_tags__(self): tags = super().__sklearn_tags__() tags.classifier_tags.multi_label = True diff --git a/sklearn/neighbors/_regression.py b/sklearn/neighbors/_regression.py index 8410a140b9eb1..f324d3fb7e2f2 100644 --- a/sklearn/neighbors/_regression.py +++ b/sklearn/neighbors/_regression.py @@ -234,8 +234,10 @@ def predict(self, X): Parameters ---------- X : {array-like, sparse matrix} of shape (n_queries, n_features), \ - or (n_queries, n_indexed) if metric == 'precomputed' - Test samples. + or (n_queries, n_indexed) if metric == 'precomputed', or None + Test samples. If `None`, predictions for all indexed points are + returned; in this case, points are not considered their own + neighbors. Returns ------- @@ -464,8 +466,10 @@ def predict(self, X): Parameters ---------- X : {array-like, sparse matrix} of shape (n_queries, n_features), \ - or (n_queries, n_indexed) if metric == 'precomputed' - Test samples. + or (n_queries, n_indexed) if metric == 'precomputed', or None + Test samples. If `None`, predictions for all indexed points are + returned; in this case, points are not considered their own + neighbors. 
Returns ------- From 15447167be5670d7745503cd37bb801728d6c747 Mon Sep 17 00:00:00 2001 From: Dmitry Kobak Date: Fri, 11 Oct 2024 16:20:30 +0200 Subject: [PATCH 2/7] Fix formatting --- sklearn/neighbors/tests/test_neighbors.py | 43 ++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index cb6acb65cb1cc..f2af7497d0819 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -24,7 +24,12 @@ assert_compatible_argkmin_results, assert_compatible_radius_results, ) -from sklearn.model_selection import cross_val_score, train_test_split +from sklearn.model_selection import ( + LeaveOneOut, + cross_val_predict, + cross_val_score, + train_test_split, +) from sklearn.neighbors import ( VALID_METRICS_SPARSE, KNeighborsRegressor, @@ -2390,3 +2395,39 @@ def _weights(dist): with pytest.raises(ValueError, match=msg): est.predict_proba([[1.1, 1.1]]) + + +def test_neighbor_classifiers_loocv(): + """Check that `predict` and related functions work fine with X=None""" + X, y = datasets.make_blobs(n_samples=500, centers=5, n_features=2, random_state=0) + + models = [ + neighbors.KNeighborsClassifier(n_neighbors=10), + neighbors.RadiusNeighborsClassifier(radius=5.0), + ] + + for knn in models: + loocv = cross_val_score(knn, X, y, cv=LeaveOneOut()) + + knn.fit(X, y) + + assert np.all(loocv == (knn.predict(None) == y)) + assert np.mean(loocv) == knn.score(None, y) + assert knn.score(None, y) < knn.score(X, y) + + +def test_neighbor_regressors_loocv(): + """Check that `predict` and related functions work fine with X=None""" + X, y = datasets.load_diabetes(return_X_y=True) + + models = [ + neighbors.KNeighborsClassifier(n_neighbors=10), + neighbors.RadiusNeighborsClassifier(radius=0.5), + ] + + for knn in models: + loocv = cross_val_predict(knn, X, y, cv=LeaveOneOut()) + + knn.fit(X, y) + + assert np.all(loocv == 
knn.predict(None)) From cb07cb1f92c772cec686a02517a6f550714eae68 Mon Sep 17 00:00:00 2001 From: Dmitry Kobak Date: Fri, 11 Oct 2024 16:23:09 +0200 Subject: [PATCH 3/7] Add a comment --- sklearn/neighbors/tests/test_neighbors.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index f2af7497d0819..c1dfcacfc94bf 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -2426,6 +2426,8 @@ def test_neighbor_regressors_loocv(): ] for knn in models: + # Only checking cross_val_predict and not cross_val_score because + # cross_val_score does not work with LeaveOneOut() for a regressor loocv = cross_val_predict(knn, X, y, cv=LeaveOneOut()) knn.fit(X, y) From 933192a545052481501982764485757e55909736 Mon Sep 17 00:00:00 2001 From: Dmitry Kobak Date: Fri, 11 Oct 2024 17:52:24 +0200 Subject: [PATCH 4/7] Changelog --- doc/whats_new/v1.6.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index 1c87d7a4893a6..96be94a1c73bd 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -361,6 +361,13 @@ Changelog when duplicate values in the training data lead to inaccurate outlier detection. :pr:`28773` by :user:`Henrique Caroço `. +- |Enhancement| Make `predict()`, `predict_proba()`, and `score()` of + :class:`neighbors.KNeighborsClassifier` and + :class:`neighbors.RadiusNeighborsClassifier` accept `X=None` as input. In this case + predictions for all training set points are returned, and points are not included + into their own neighbors. + :pr:`30047` by :user:`Dmitry Kobak `. + :mod:`sklearn.preprocessing` ............................ 
From 87b58357f18924dcfcdd3e0e7665db27e05a972b Mon Sep 17 00:00:00 2001 From: Dmitry Kobak Date: Fri, 18 Oct 2024 12:43:46 +0200 Subject: [PATCH 5/7] Update doc/whats_new/v1.6.rst Co-authored-by: Omar Salman --- doc/whats_new/v1.6.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index 96be94a1c73bd..484e743a878bc 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -361,7 +361,7 @@ Changelog when duplicate values in the training data lead to inaccurate outlier detection. :pr:`28773` by :user:`Henrique Caroço `. -- |Enhancement| Make `predict()`, `predict_proba()`, and `score()` of +- |Enhancement| Make `predict`, `predict_proba`, and `score` of :class:`neighbors.KNeighborsClassifier` and :class:`neighbors.RadiusNeighborsClassifier` accept `X=None` as input. In this case predictions for all training set points are returned, and points are not included From e61f4f4b005c2236c42d06ca87a48c4a7d58f1a2 Mon Sep 17 00:00:00 2001 From: Dmitry Kobak Date: Fri, 18 Oct 2024 12:50:57 +0200 Subject: [PATCH 6/7] Fix the test --- sklearn/neighbors/tests/test_neighbors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index c1dfcacfc94bf..46ee82036fa7c 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -2421,8 +2421,8 @@ def test_neighbor_regressors_loocv(): X, y = datasets.load_diabetes(return_X_y=True) models = [ - neighbors.KNeighborsClassifier(n_neighbors=10), - neighbors.RadiusNeighborsClassifier(radius=0.5), + neighbors.KNeighborsRegressor(n_neighbors=10), + neighbors.RadiusNeighborsRegressor(radius=0.5), ] for knn in models: From 6fc430b921980243659daf572653601e8e801e35 Mon Sep 17 00:00:00 2001 From: Dmitry Kobak Date: Fri, 18 Oct 2024 14:29:09 +0200 Subject: [PATCH 7/7] Parameterize tests --- sklearn/neighbors/tests/test_neighbors.py
| 52 +++++++++++------------ 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index 46ee82036fa7c..b480847ed1f45 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -2397,39 +2397,39 @@ def _weights(dist): est.predict_proba([[1.1, 1.1]]) -def test_neighbor_classifiers_loocv(): - """Check that `predict` and related functions work fine with X=None""" - X, y = datasets.make_blobs(n_samples=500, centers=5, n_features=2, random_state=0) - - models = [ +@pytest.mark.parametrize( + "nn_model", + [ neighbors.KNeighborsClassifier(n_neighbors=10), neighbors.RadiusNeighborsClassifier(radius=5.0), - ] - - for knn in models: - loocv = cross_val_score(knn, X, y, cv=LeaveOneOut()) - - knn.fit(X, y) + ], +) +def test_neighbor_classifiers_loocv(nn_model): + """Check that `predict` and related functions work fine with X=None""" + X, y = datasets.make_blobs(n_samples=500, centers=5, n_features=2, random_state=0) - assert np.all(loocv == (knn.predict(None) == y)) - assert np.mean(loocv) == knn.score(None, y) - assert knn.score(None, y) < knn.score(X, y) + loocv = cross_val_score(nn_model, X, y, cv=LeaveOneOut()) + nn_model.fit(X, y) + assert np.all(loocv == (nn_model.predict(None) == y)) + assert np.mean(loocv) == nn_model.score(None, y) + assert nn_model.score(None, y) < nn_model.score(X, y) -def test_neighbor_regressors_loocv(): - """Check that `predict` and related functions work fine with X=None""" - X, y = datasets.load_diabetes(return_X_y=True) - models = [ +@pytest.mark.parametrize( + "nn_model", + [ neighbors.KNeighborsRegressor(n_neighbors=10), neighbors.RadiusNeighborsRegressor(radius=0.5), - ] - - for knn in models: - # Only checking cross_val_predict and not cross_val_score because - # cross_val_score does not work with LeaveOneOut() for a regressor - loocv = cross_val_predict(knn, X, y, cv=LeaveOneOut()) + ], +) +def 
test_neighbor_regressors_loocv(nn_model): + """Check that `predict` and related functions work fine with X=None""" + X, y = datasets.load_diabetes(return_X_y=True) - knn.fit(X, y) + # Only checking cross_val_predict and not cross_val_score because + # cross_val_score does not work with LeaveOneOut() for a regressor + loocv = cross_val_predict(nn_model, X, y, cv=LeaveOneOut()) + nn_model.fit(X, y) - assert np.all(loocv == knn.predict(None)) + assert np.all(loocv == nn_model.predict(None))