Commit 437b9ea

ogrisel authored and glemaitre committed
FIX check_decision_proba_consistency random failure (scikit-learn#19225)
* FIX more deterministic check_decision_proba_consistency
* Trigger [cd build]
* Re-add rounding
* Trigger [cd build]
* Avoid redundant phrasing in comment [ci skip]
1 parent aff72a0 commit 437b9ea
sklearn/utils/estimator_checks.py
1 file changed: 6 additions & 3 deletions

@@ -2907,16 +2907,19 @@ def check_decision_proba_consistency(name, estimator_orig):
     centers = [(2, 2), (4, 4)]
     X, y = make_blobs(n_samples=100, random_state=0, n_features=4,
                       centers=centers, cluster_std=1.0, shuffle=True)
-    X_test = np.random.randn(20, 2) + 4
+    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
+                                                        random_state=0)
     estimator = clone(estimator_orig)
 
     if (hasattr(estimator, "decision_function") and
             hasattr(estimator, "predict_proba")):
 
-        estimator.fit(X, y)
+        estimator.fit(X_train, y_train)
         # Since the link function from decision_function() to predict_proba()
         # is sometimes not precise enough (typically expit), we round to the
-        # 10th decimal to avoid numerical issues.
+        # 10th decimal to avoid numerical issues: we compare the rank
+        # with deterministic ties rather than get platform specific rank
+        # inversions in case of machine level differences.
         a = estimator.predict_proba(X_test)[:, 1].round(decimals=10)
         b = estimator.decision_function(X_test).round(decimals=10)
         assert_array_equal(rankdata(a), rankdata(b))
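
The updated comment says that rounding turns platform-specific rank inversions into deterministic ties. The sketch below illustrates that idea only and is not code from the commit. The score values are made up, and `average_ranks` is a hypothetical pure-Python stand-in for `scipy.stats.rankdata` with its default "average" tie handling.

```python
def average_ranks(values):
    """Hypothetical stand-in for scipy.stats.rankdata(method="average")."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        # Extend the run while the next sorted value ties with the current one.
        while (end + 1 < len(order)
               and values[order[end + 1]] == values[order[start]]):
            end += 1
        # Tied entries share the mean of the 1-based positions they occupy.
        shared = (start + end) / 2 + 1
        for k in range(start, end + 1):
            ranks[order[k]] = shared
        start = end + 1
    return ranks


# Made-up scores: samples 1 and 2 are "equal" up to 1e-13 of noise, and the
# noise lands on opposite sides in the two outputs (a simulated inversion).
proba = [0.2, 0.7000000000001, 0.7, 0.9]
decision = [-1.4, 0.85, 0.8500000000001, 2.2]

# Without rounding, the noise decides the order and the ranks disagree.
raw_proba_ranks = average_ranks(proba)
raw_decision_ranks = average_ranks(decision)

# Rounding to the 10th decimal collapses the noisy pair into an exact tie,
# so both outputs get the same average rank for those two samples.
rounded_proba_ranks = average_ranks([round(v, 10) for v in proba])
rounded_decision_ranks = average_ranks([round(v, 10) for v in decision])

print(raw_proba_ranks, raw_decision_ranks)
print(rounded_proba_ranks, rounded_decision_ranks)
```

In the check itself the same effect comes from calling `.round(decimals=10)` on both arrays before `rankdata`, as shown in the diff.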
