
Performance Degradation in MeanShift When Data Has No Variance #28926

Closed
@akikuno


Describe the bug

When the data passed to MeanShift has no variance within each cluster (for example, two clusters made up entirely of exact 0s and exact 1s), fitting becomes extremely slow.

I am unsure whether this is a bug or an unavoidable aspect of the algorithm's design. Any clarification would be appreciated.
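A likely contributing factor, based on a reading of scikit-learn rather than a confirmed diagnosis: when `bandwidth` is left at its default of `None`, MeanShift calls `sklearn.cluster.estimate_bandwidth`. With the default `quantile=0.3`, that function averages each sample's distance to the farthest of its `int(n_samples * 0.3)` nearest neighbors (60 neighbors for 200 samples). Each cluster here holds 100 identical points, so every one of those distances is 0 and the estimated bandwidth is 0. The sketch below checks this for both inputs from the report.

```python
import numpy as np
from sklearn.cluster import estimate_bandwidth

# Input from the report: two clusters of exactly repeated values.
x_flat = np.concatenate([np.ones(100), np.zeros(100)]).reshape(-1, 1)

# Control input from the report: the same clusters with a tiny amount of spread.
rng = np.random.default_rng(1)
x_noisy = np.concatenate(
    [rng.uniform(0.0, 0.001, 100), rng.uniform(0.999, 1.0, 100)]
).reshape(-1, 1)

# Default quantile=0.3: distance to the 60th nearest neighbor, averaged over samples.
bw_flat = estimate_bandwidth(x_flat)
bw_noisy = estimate_bandwidth(x_noisy)
print(bw_flat)   # 0.0: every point's 60 nearest neighbors are exact duplicates
print(bw_noisy)  # small but positive
```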

Steps/Code to Reproduce

import numpy as np
from sklearn.cluster import MeanShift

x = np.concatenate([np.ones(100), np.zeros(100)])
_ = MeanShift().fit_predict(x.reshape(-1, 1)) # Slow

rng = np.random.default_rng(1)
x = np.concatenate([rng.uniform(0.0, 0.001, 100), rng.uniform(0.999, 1.0, 100)])
_ = MeanShift().fit_predict(x.reshape(-1, 1)) # Fast

Link to Google Colab: https://colab.research.google.com/drive/1hlqhtaD8T40hwcleUKoI4uzrW1XtSRA4?usp=sharing#scrollTo=6g5qI45KUW_i

Expected Results

MeanShift should fit data with no within-cluster variance about as quickly as data with a small amount of variance.
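A possible workaround, assuming the slowdown comes from the automatically estimated bandwidth being 0: pass a positive `bandwidth` explicitly so that `estimate_bandwidth` is not called. The value `0.5` below is an arbitrary choice for this toy input (half the gap between the two clusters), not a recommended default.

```python
import numpy as np
from sklearn.cluster import MeanShift

x = np.concatenate([np.ones(100), np.zeros(100)]).reshape(-1, 1)

# Explicit bandwidth (assumed value for this toy data). It skips bandwidth
# estimation, so the bandwidth used by the algorithm is positive.
labels = MeanShift(bandwidth=0.5).fit_predict(x)

print(np.unique(labels))
```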

Actual Results

When MeanShift receives a 1D array whose clusters have no variance, fitting is roughly 37 times slower (24.9 s versus 665 ms in the timings below).

import numpy as np
from sklearn.cluster import MeanShift

# Example where input has no variance
x = np.concatenate([np.ones(100), np.zeros(100)])
%timeit _ = MeanShift().fit_predict(x.reshape(-1, 1))
# Output: 24.9 s ± 340 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Below is a control example, where the input has some variance:

import numpy as np
from sklearn.cluster import MeanShift

# Example with minimal variance
rng = np.random.default_rng(1)
x = np.concatenate([rng.uniform(0.0, 0.001, 100), rng.uniform(0.999, 1.0, 100)])
%timeit _ = MeanShift().fit_predict(x.reshape(-1, 1))
# Output: 665 ms ± 101 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
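The size of the gap is consistent with the convergence test never firing. As far as can be told from the scikit-learn source, the per-seed loop (the private helper `_mean_shift_single_seed`; details may differ between versions) stops when the mean moves less than `1e-3 * bandwidth`, using a strict `<`, or when `max_iter` (default 300) is reached. With a bandwidth of 0 that threshold is 0, a shift of 0 never satisfies it, and each of the 200 seeds runs all 300 iterations, about 60,000 neighbor queries in total. The sketch below is a simplified NumPy re-implementation of that loop under these assumptions, not scikit-learn's code; `shift_single_seed` is a hypothetical name.

```python
import numpy as np


def shift_single_seed(seed, X, bandwidth, max_iter=300):
    """Simplified flat-kernel mean shift for one seed (illustrative only).

    Returns the final mean and the number of completed iterations.
    """
    # Assumed convergence rule: stop when the mean moves less than 1e-3 * bandwidth.
    stop_thresh = 1e-3 * bandwidth
    mean = np.asarray(seed, dtype=float)
    iterations = 0
    while True:
        # Flat kernel: all points within `bandwidth` (inclusive) of the current mean.
        within = X[np.linalg.norm(X - mean, axis=1) <= bandwidth]
        if len(within) == 0:
            break
        old_mean = mean
        mean = within.mean(axis=0)
        # Strict "<": with bandwidth == 0 the threshold is 0, and a shift of 0
        # never passes, so only the max_iter check can end the loop.
        if np.linalg.norm(mean - old_mean) < stop_thresh or iterations == max_iter:
            break
        iterations += 1
    return mean, iterations


X = np.concatenate([np.ones(100), np.zeros(100)]).reshape(-1, 1)
_, n_iter_zero_bw = shift_single_seed(X[0], X, bandwidth=0.0)
_, n_iter_positive_bw = shift_single_seed(X[0], X, bandwidth=0.5)
print(n_iter_zero_bw, n_iter_positive_bw)  # 300 0
```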

Versions

1.2.2
