MNT Refactor center initialization in KMeans #17928
sklearn/cluster/_kmeans.py
Outdated
Squared euclidean norm of each data point. Pass it if you have it
at hands already to avoid it being recomputed here.

init : {'k-means++', 'random', ndarray, callable}
- init : {'k-means++', 'random', ndarray, callable}
+ init : {'k-means++', 'random'}, callable or ndarray of shape (n_clusters, n_features)
LGTM
sklearn/cluster/_kmeans.py
Outdated
Returns
-------
centers : ndarray of shape(n_clusters, n_features)
- centers : ndarray of shape(n_clusters, n_features)
+ centers : ndarray of shape (n_clusters, n_features)
@@ -624,7 +554,6 @@ def _labels_inertia(X, sample_weight, x_squared_norms, centers,

    n_threads = _openmp_effective_n_threads(n_threads)

    sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
This has been done in the main method (fit, predict, etc.)
Right. The goal is to do all validation in the public methods (fit, predict, etc.) and have all private helpers assume validation is already done, so that nothing is validated twice.
LGTM
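For readers skimming this thread, the convention discussed above can be sketched roughly as follows. This is a minimal pure-Python illustration, not the scikit-learn code: `validate_weights`, `_assign_labels` and `TinyClusterer` are hypothetical names invented for the example. The public methods validate once, and the private helper trusts its inputs.

```python
def validate_weights(sample_weight, n_samples):
    """Hypothetical validator: returns one float weight per sample."""
    if sample_weight is None:
        return [1.0] * n_samples
    weights = [float(w) for w in sample_weight]
    if len(weights) != n_samples:
        raise ValueError("sample_weight must have one entry per sample")
    return weights


def _assign_labels(X, sample_weight, centers):
    """Private helper: assumes X and sample_weight are already validated."""
    labels = []
    inertia = 0.0
    for x, w in zip(X, sample_weight):
        # Squared euclidean distance from this point to every center.
        dists = [sum((a - b) ** 2 for a, b in zip(x, c)) for c in centers]
        label = min(range(len(centers)), key=dists.__getitem__)
        labels.append(label)
        inertia += w * dists[label]
    return labels, inertia


class TinyClusterer:
    """Illustrative class with fixed centers (hypothetical)."""

    def __init__(self, centers):
        self.centers = centers

    def predict(self, X, sample_weight=None):
        # Validation happens here, once, in the public method.
        sample_weight = validate_weights(sample_weight, len(X))
        labels, _ = _assign_labels(X, sample_weight, self.centers)
        return labels

    def score(self, X, sample_weight=None):
        sample_weight = validate_weights(sample_weight, len(X))
        _, inertia = _assign_labels(X, sample_weight, self.centers)
        return -inertia
```

The trade-off is that a private helper called directly with unvalidated input will no longer be checked, which is acceptable as long as it is only reached through the public methods.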
* refactor center init in KMeans
* address comments
Do the center initialization in a single place, i.e. in fit.
Avoid repeated validations.
Make _init_centroids a method of KMeans to make cleaner use of KMeans attributes.
Extracted from #17622 to facilitate the reviews.
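As a rough illustration of the structure described above, here is a minimal pure-Python sketch. It is not the code in this PR: `SketchKMeans`, its `seed` parameter and the `initial_centers_` attribute are hypothetical, and the 'k-means++' option is left out for brevity. The point is only that fit validates once and then calls an `_init_centroids` method that reads the estimator's own attributes.

```python
import random


class SketchKMeans:
    """Illustrative only; not scikit-learn's KMeans."""

    def __init__(self, n_clusters, init="random", seed=0):
        self.n_clusters = n_clusters
        self.init = init
        self.seed = seed

    def _init_centroids(self, X, rng):
        # As a method, this reads self.init and self.n_clusters directly
        # and assumes fit has already validated them.
        if callable(self.init):
            centers = self.init(X, self.n_clusters, rng)
        elif self.init == "random":
            centers = rng.sample(X, self.n_clusters)
        else:
            centers = self.init  # explicit array-like of centers
        return [[float(v) for v in c] for c in centers]

    def fit(self, X):
        # All validation happens here, once.
        X = [[float(v) for v in row] for row in X]
        if len(X) < self.n_clusters:
            raise ValueError("need at least n_clusters samples")
        if isinstance(self.init, str):
            if self.init != "random":
                raise ValueError("unsupported init string in this sketch")
        elif not callable(self.init) and len(self.init) != self.n_clusters:
            raise ValueError("init must provide n_clusters centers")
        rng = random.Random(self.seed)
        # Single initialization site; the iterative updates are omitted.
        self.initial_centers_ = self._init_centroids(X, rng)
        return self
```

In this sketch a callable init receives the data, the number of clusters and a random generator, and returns the initial centers.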
ping @glemaitre