Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit ac52198

Browse filesBrowse files
akikunoogriseljeremiedbb
committed
FIX convergence criterion of MeanShift (scikit-learn#28951)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>
1 parent 601e8f8 commit ac52198
Copy full SHA for ac52198

File tree

3 files changed

+13
-1
lines changed
Filter options

3 files changed

+13
-1
lines changed

‎doc/whats_new/v1.5.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.5.rst
+3Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,9 @@ Changelog
183183
:mod:`sklearn.cluster`
184184
......................
185185

186+
- |Fix| The :class:`cluster.MeanShift` class now properly converges for constant data.
187+
:pr:`28951` by :user:`Akihiro Kuno <akikuno>`.
188+
186189
- |FIX| Create copy of precomputed sparse matrix within the `fit` method of
187190
:class:`~cluster.OPTICS` to avoid in-place modification of the sparse matrix.
188191
:pr:`28491` by :user:`Thanh Lam Dang <lamdang2k>`.

‎sklearn/cluster/_mean_shift.py

Copy file name to clipboardExpand all lines: sklearn/cluster/_mean_shift.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def _mean_shift_single_seed(my_mean, X, nbrs, max_iter):
122122
my_mean = np.mean(points_within, axis=0)
123123
# If converged or at max_iter, adds the cluster
124124
if (
125-
np.linalg.norm(my_mean - my_old_mean) < stop_thresh
125+
np.linalg.norm(my_mean - my_old_mean) <= stop_thresh
126126
or completed_iterations == max_iter
127127
):
128128
break

‎sklearn/cluster/tests/test_mean_shift.py

Copy file name to clipboardExpand all lines: sklearn/cluster/tests/test_mean_shift.py
+9Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@
2525
)
2626

2727

28+
def test_convergence_of_1d_constant_data():
29+
# Test convergence using 1D constant data
30+
# Non-regression test for:
31+
# https://github.com/scikit-learn/scikit-learn/issues/28926
32+
model = MeanShift()
33+
n_iter = model.fit(np.ones(10).reshape(-1, 1)).n_iter_
34+
assert n_iter < model.max_iter
35+
36+
2837
def test_estimate_bandwidth():
2938
# Test estimate_bandwidth
3039
bandwidth = estimate_bandwidth(X, n_samples=200)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.