Commit 15d5a06
TST Improve tests for neighbor models with X=None (#30101)
Parent: 6eb2ef3
1 file changed: 50 additions and 12 deletions

sklearn/neighbors/tests/test_neighbors.py
@@ -2401,35 +2401,73 @@ def _weights(dist):
     "nn_model",
     [
         neighbors.KNeighborsClassifier(n_neighbors=10),
-        neighbors.RadiusNeighborsClassifier(radius=5.0),
+        neighbors.RadiusNeighborsClassifier(),
     ],
 )
-def test_neighbor_classifiers_loocv(nn_model):
-    """Check that `predict` and related functions work fine with X=None"""
-    X, y = datasets.make_blobs(n_samples=500, centers=5, n_features=2, random_state=0)
+@pytest.mark.parametrize("algorithm", ALGORITHMS)
+def test_neighbor_classifiers_loocv(nn_model, algorithm):
+    """Check that `predict` and related functions work fine with X=None
+
+    Calling predict with X=None computes a prediction for each training point
+    from the labels of its neighbors (without the label of the data point being
+    predicted upon). This is therefore mathematically equivalent to
+    leave-one-out cross-validation without having to do any retraining
+    (rebuilding a KD-tree or Ball-tree index) or any data reshuffling.
+    """
+    X, y = datasets.make_blobs(n_samples=15, centers=5, n_features=2, random_state=0)
+
+    nn_model = clone(nn_model).set_params(algorithm=algorithm)
+
+    # Set the radius for RadiusNeighborsClassifier to some percentile of the
+    # empirical pairwise distances to avoid trivial test cases and warnings for
+    # predictions with no neighbors within the radius.
+    if "radius" in nn_model.get_params():
+        dists = pairwise_distances(X).ravel()
+        dists = dists[dists > 0]
+        nn_model.set_params(radius=np.percentile(dists, 80))
 
     loocv = cross_val_score(nn_model, X, y, cv=LeaveOneOut())
     nn_model.fit(X, y)
 
-    assert np.all(loocv == (nn_model.predict(None) == y))
-    assert np.mean(loocv) == nn_model.score(None, y)
+    assert_allclose(loocv, nn_model.predict(None) == y)
+    assert np.mean(loocv) == pytest.approx(nn_model.score(None, y))
+
+    # Evaluating `nn_model` on its "training" set should lead to a higher
+    # accuracy value than leaving out each data point in turn because the
+    # former can overfit while the latter cannot by construction.
     assert nn_model.score(None, y) < nn_model.score(X, y)
 
 
 @pytest.mark.parametrize(
     "nn_model",
     [
         neighbors.KNeighborsRegressor(n_neighbors=10),
-        neighbors.RadiusNeighborsRegressor(radius=0.5),
+        neighbors.RadiusNeighborsRegressor(),
     ],
 )
-def test_neighbor_regressors_loocv(nn_model):
+@pytest.mark.parametrize("algorithm", ALGORITHMS)
+def test_neighbor_regressors_loocv(nn_model, algorithm):
     """Check that `predict` and related functions work fine with X=None"""
-    X, y = datasets.load_diabetes(return_X_y=True)
+    X, y = datasets.make_regression(n_samples=15, n_features=2, random_state=0)
 
     # Only checking cross_val_predict and not cross_val_score because
-    # cross_val_score does not work with LeaveOneOut() for a regressor
+    # cross_val_score does not work with LeaveOneOut() for a regressor: the
+    # default score method implements R2 score which is not well defined for a
+    # single data point.
+    #
+    # TODO: if score is refactored to evaluate models for other scoring
+    # functions, then this test can be extended to check cross_val_score as
+    # well.
+    nn_model = clone(nn_model).set_params(algorithm=algorithm)
+
+    # Set the radius for RadiusNeighborsRegressor to some percentile of the
+    # empirical pairwise distances to avoid trivial test cases and warnings for
+    # predictions with no neighbors within the radius.
+    if "radius" in nn_model.get_params():
+        dists = pairwise_distances(X).ravel()
+        dists = dists[dists > 0]
+        nn_model.set_params(radius=np.percentile(dists, 80))
+
     loocv = cross_val_predict(nn_model, X, y, cv=LeaveOneOut())
     nn_model.fit(X, y)
-
-    assert np.all(loocv == nn_model.predict(None))
+    assert_allclose(loocv, nn_model.predict(None))
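For context, below is a minimal sketch (not part of the commit) of the equivalence these tests exercise, assuming a scikit-learn version in which neighbor models accept X=None in predict and score. The dataset size, estimator choice, and the 80th-percentile radius mirror the test above but are otherwise illustrative.

import numpy as np
from sklearn.datasets import make_blobs
from sklearn.metrics import pairwise_distances
from sklearn.model_selection import LeaveOneOut, cross_val_score
from sklearn.neighbors import RadiusNeighborsClassifier

X, y = make_blobs(n_samples=15, centers=5, n_features=2, random_state=0)

# Choose a non-trivial radius from the empirical pairwise distances, as the
# test does, to avoid predictions with no neighbors within the radius.
dists = pairwise_distances(X).ravel()
radius = np.percentile(dists[dists > 0], 80)

model = RadiusNeighborsClassifier(radius=radius).fit(X, y)

# Explicit leave-one-out cross-validation: refits the estimator n_samples times.
loocv_accuracy = cross_val_score(model, X, y, cv=LeaveOneOut()).mean()

# With X=None, each training point is predicted from its neighbors while its
# own label is excluded, so no refitting or index rebuilding is needed.
loo_accuracy_via_none = model.score(None, y)

assert np.isclose(loocv_accuracy, loo_accuracy_via_none)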
