Commit 11f9742

[ENH] Create public kmeans_plusplus including index output (#17937)

1 parent 0e6d415 · commit 11f9742

7 files changed (+229, -21 lines)

doc/modules/classes.rst

1 addition, 0 deletions

@@ -124,6 +124,7 @@ Functions
    cluster.dbscan
    cluster.estimate_bandwidth
    cluster.k_means
+   cluster.kmeans_plusplus
    cluster.mean_shift
    cluster.spectral_clustering
    cluster.ward_tree

doc/modules/clustering.rst

5 additions, 1 deletion

@@ -197,7 +197,11 @@ initializations of the centroids. One method to help address this issue is the
 k-means++ initialization scheme, which has been implemented in scikit-learn
 (use the ``init='k-means++'`` parameter). This initializes the centroids to be
 (generally) distant from each other, leading to provably better results than
-random initialization, as shown in the reference.
+random initialization, as shown in the reference.
+
+K-means++ can also be called independently to select seeds for other
+clustering algorithms, see :func:`sklearn.cluster.kmeans_plusplus` for details
+and example usage.
 
 The algorithm supports sample weights, which can be given by a parameter
 ``sample_weight``. This allows to assign more weight to some samples when
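
For orientation only, a minimal sketch (not part of the commit) of the independent-seeding workflow the added paragraph describes. The make_blobs data and the downstream KMeans estimator are illustrative assumptions; kmeans_plusplus and the array-valued init parameter are the documented API:

    # Sketch: pick seeds with kmeans_plusplus, then hand them to a clusterer.
    # make_blobs and KMeans are illustrative choices, not part of the diff.
    from sklearn.cluster import KMeans, kmeans_plusplus
    from sklearn.datasets import make_blobs

    X, _ = make_blobs(n_samples=500, centers=3, random_state=0)

    # Select three seeds only; no clustering happens at this point.
    centers, indices = kmeans_plusplus(X, n_clusters=3, random_state=0)

    # Reuse the seeds as an explicit init; n_init=1 since the init is fixed.
    km = KMeans(n_clusters=3, init=centers, n_init=1).fit(X)
    print(km.cluster_centers_)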

doc/whats_new/v0.24.rst

4 additions, 0 deletions

@@ -97,6 +97,10 @@ Changelog
   `init_size_`, are deprecated and will be removed in 0.26. :pr:`17864` by
   :user:`Jérémie du Boisberranger <jeremiedbb>`.
 
+- |Enhancement| Added :func:`cluster.kmeans_plusplus` as public function.
+  Initialization by KMeans++ can now be called separately to generate
+  initial cluster centroids. :pr:`17937` by :user:`g-walsh`
+
 :mod:`sklearn.compose`
 ......................
 

examples/cluster/plot_kmeans_plusplus.py (new file)
45 additions, 0 deletions

@@ -0,0 +1,45 @@
+"""
+===========================================================
+An example of K-Means++ initialization
+===========================================================
+
+An example to show the output of the :func:`sklearn.cluster.kmeans_plusplus`
+function for generating initial seeds for clustering.
+
+K-Means++ is used as the default initialization for :ref:`k_means`.
+
+"""
+print(__doc__)
+
+from sklearn.cluster import kmeans_plusplus
+from sklearn.datasets import make_blobs
+import matplotlib.pyplot as plt
+
+# Generate sample data
+n_samples = 4000
+n_components = 4
+
+X, y_true = make_blobs(n_samples=n_samples,
+                       centers=n_components,
+                       cluster_std=0.60,
+                       random_state=0)
+X = X[:, ::-1]
+
+# Calculate seeds from k-means++
+centers_init, indices = kmeans_plusplus(X, n_clusters=4,
+                                        random_state=0)
+
+# Plot init seeds alongside sample data
+plt.figure(1)
+colors = ['#4EACC5', '#FF9C34', '#4E9A06', 'm']
+
+for k, col in enumerate(colors):
+    cluster_data = y_true == k
+    plt.scatter(X[cluster_data, 0], X[cluster_data, 1],
+                c=col, marker='.', s=10)
+
+plt.scatter(centers_init[:, 0], centers_init[:, 1], c='b', s=50)
+plt.title("K-Means++ Initialization")
+plt.xticks([])
+plt.yticks([])
+plt.show()
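
A side note on the docstring's claim that K-Means++ is the default initialization for KMeans: a short hedged check of that claim, with the make_blobs call repeated so the snippet stands alone:

    # Sketch: the explicit kmeans_plusplus call above mirrors what
    # KMeans.fit() already does, since its default init is 'k-means++'.
    from sklearn.cluster import KMeans
    from sklearn.datasets import make_blobs

    X, _ = make_blobs(n_samples=4000, centers=4, cluster_std=0.60,
                      random_state=0)

    km = KMeans(n_clusters=4, random_state=0)
    assert km.get_params()["init"] == "k-means++"  # default scheme
    labels = km.fit_predict(X)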

sklearn/cluster/__init__.py

2 additions, 1 deletion

@@ -9,7 +9,7 @@
 from ._affinity_propagation import affinity_propagation, AffinityPropagation
 from ._agglomerative import (ward_tree, AgglomerativeClustering,
                              linkage_tree, FeatureAgglomeration)
-from ._kmeans import k_means, KMeans, MiniBatchKMeans
+from ._kmeans import k_means, KMeans, MiniBatchKMeans, kmeans_plusplus
 from ._dbscan import dbscan, DBSCAN
 from ._optics import (OPTICS, cluster_optics_dbscan, compute_optics_graph,
                       cluster_optics_xi)
@@ -34,6 +34,7 @@
     'estimate_bandwidth',
     'get_bin_seeds',
     'k_means',
+    'kmeans_plusplus',
     'linkage_tree',
     'mean_shift',
     'spectral_clustering',

sklearn/cluster/_kmeans.py

114 additions, 18 deletions

@@ -47,14 +47,15 @@
 # Initialization heuristic
 
 
-def _k_init(X, n_clusters, x_squared_norms, random_state, n_local_trials=None):
-    """Init n_clusters seeds according to k-means++
+def _kmeans_plusplus(X, n_clusters, x_squared_norms,
+                     random_state, n_local_trials=None):
+    """Computational component for initialization of n_clusters by
+    k-means++. Prior validation of data is assumed.
 
     Parameters
     ----------
     X : {ndarray, sparse matrix} of shape (n_samples, n_features)
-        The data to pick seeds for. To avoid memory copy, the input data
-        should be double precision (dtype=np.float64).
+        The data to pick seeds for.
 
     n_clusters : int
         The number of seeds to choose.
@@ -72,35 +73,34 @@ def _k_init(X, n_clusters, x_squared_norms, random_state, n_local_trials=None):
         Set to None to make the number of trials depend logarithmically
         on the number of seeds (2+log(k)); this is the default.
 
-    Notes
-    -----
-    Selects initial cluster centers for k-mean clustering in a smart way
-    to speed up convergence. see: Arthur, D. and Vassilvitskii, S.
-    "k-means++: the advantages of careful seeding". ACM-SIAM symposium
-    on Discrete algorithms. 2007
+    Returns
+    -------
+    centers : ndarray of shape (n_clusters, n_features)
+        The initial centers for k-means.
 
-    Version ported from http://www.stanford.edu/~darthur/kMeansppTest.zip,
-    which is the implementation used in the aforementioned paper.
+    indices : ndarray of shape (n_clusters,)
+        The index location of the chosen centers in the data array X. For a
+        given index and center, X[index] = center.
     """
     n_samples, n_features = X.shape
 
     centers = np.empty((n_clusters, n_features), dtype=X.dtype)
 
-    assert x_squared_norms is not None, 'x_squared_norms None in _k_init'
-
     # Set the number of local seeding trials if none is given
     if n_local_trials is None:
         # This is what Arthur/Vassilvitskii tried, but did not report
         # specific results for other than mentioning in the conclusion
         # that it helped.
         n_local_trials = 2 + int(np.log(n_clusters))
 
-    # Pick first center randomly
+    # Pick first center randomly and track index of point
    center_id = random_state.randint(n_samples)
+    indices = np.full(n_clusters, -1, dtype=int)
     if sp.issparse(X):
         centers[0] = X[center_id].toarray()
     else:
         centers[0] = X[center_id]
+    indices[0] = center_id
 
     # Initialize list of closest distances and calculate current potential
     closest_dist_sq = euclidean_distances(
@@ -139,8 +139,9 @@ def _k_init(X, n_clusters, x_squared_norms, random_state, n_local_trials=None):
             centers[c] = X[best_candidate].toarray()
         else:
             centers[c] = X[best_candidate]
+        indices[c] = best_candidate
 
-    return centers
+    return centers, indices
 
 
 ###############################################################################
@@ -936,8 +937,9 @@ def _init_centroids(self, X, x_squared_norms, init, random_state,
         n_samples = X.shape[0]
 
         if isinstance(init, str) and init == 'k-means++':
-            centers = _k_init(X, n_clusters, random_state=random_state,
-                              x_squared_norms=x_squared_norms)
+            centers, _ = _kmeans_plusplus(X, n_clusters,
+                                          random_state=random_state,
+                                          x_squared_norms=x_squared_norms)
         elif isinstance(init, str) and init == 'random':
             seeds = random_state.permutation(n_samples)[:n_clusters]
             centers = X[seeds]
@@ -1925,3 +1927,97 @@ def _more_tags(self):
                 'zero sample_weight is not equivalent to removing samples',
             }
         }
+
+
+def kmeans_plusplus(X, n_clusters, *, x_squared_norms=None,
+                    random_state=None, n_local_trials=None):
+    """Init n_clusters seeds according to k-means++
+
+    .. versionadded:: 0.24
+
+    Parameters
+    ----------
+    X : {array-like, sparse matrix} of shape (n_samples, n_features)
+        The data to pick seeds from.
+
+    n_clusters : int
+        The number of centroids to initialize.
+
+    x_squared_norms : array-like of shape (n_samples,), default=None
+        Squared Euclidean norm of each data point.
+
+    random_state : int or RandomState instance, default=None
+        Determines random number generation for centroid initialization. Pass
+        an int for reproducible output across multiple function calls.
+        See :term:`Glossary <random_state>`.
+
+    n_local_trials : int, default=None
+        The number of seeding trials for each center (except the first),
+        of which the one reducing inertia the most is greedily chosen.
+        Set to None to make the number of trials depend logarithmically
+        on the number of seeds (2+log(k)).
+
+    Returns
+    -------
+    centers : ndarray of shape (n_clusters, n_features)
+        The initial centers for k-means.
+
+    indices : ndarray of shape (n_clusters,)
+        The index location of the chosen centers in the data array X. For a
+        given index and center, X[index] = center.
+
+    Notes
+    -----
+    Selects initial cluster centers for k-means clustering in a smart way
+    to speed up convergence. See: Arthur, D. and Vassilvitskii, S.
+    "k-means++: the advantages of careful seeding". ACM-SIAM symposium
+    on Discrete algorithms. 2007
+
+    Examples
+    --------
+
+    >>> from sklearn.cluster import kmeans_plusplus
+    >>> import numpy as np
+    >>> X = np.array([[1, 2], [1, 4], [1, 0],
+    ...               [10, 2], [10, 4], [10, 0]])
+    >>> centers, indices = kmeans_plusplus(X, n_clusters=2, random_state=0)
+    >>> centers
+    array([[10,  4],
+           [ 1,  0]])
+    >>> indices
+    array([4, 2])
+    """
+
+    # Check data
+    check_array(X, accept_sparse='csr',
+                dtype=[np.float64, np.float32])
+
+    if X.shape[0] < n_clusters:
+        raise ValueError(f"n_samples={X.shape[0]} should be >= "
+                         f"n_clusters={n_clusters}.")
+
+    # Check parameters
+    if x_squared_norms is None:
+        x_squared_norms = row_norms(X, squared=True)
+    else:
+        x_squared_norms = check_array(x_squared_norms,
+                                      dtype=X.dtype,
+                                      ensure_2d=False)
+
+    if x_squared_norms.shape[0] != X.shape[0]:
+        raise ValueError(
+            f"The length of x_squared_norms {x_squared_norms.shape[0]} should "
+            f"be equal to the length of n_samples {X.shape[0]}.")
+
+    if n_local_trials is not None and n_local_trials < 1:
+        raise ValueError(
+            f"n_local_trials is set to {n_local_trials} but should be an "
+            f"integer value greater than zero.")
+
+    random_state = check_random_state(random_state)
+
+    # Call private k-means++
+    centers, indices = _kmeans_plusplus(X, n_clusters, x_squared_norms,
+                                        random_state, n_local_trials)
+
+    return centers, indices
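
To make the new function's documented contracts concrete, a hedged sketch (not part of the commit): X[indices] equals centers, and a precomputed x_squared_norms gives the same result as the default None, which is what the new test_kmeans_plusplus_norms below also verifies:

    # Sketch of the documented contracts of kmeans_plusplus; the toy X
    # reuses the docstring example, cast to float64 as an assumption.
    import numpy as np
    from sklearn.cluster import kmeans_plusplus
    from sklearn.utils.extmath import row_norms

    X = np.array([[1, 2], [1, 4], [1, 0],
                  [10, 2], [10, 4], [10, 0]], dtype=np.float64)

    centers, indices = kmeans_plusplus(X, n_clusters=2, random_state=0)
    assert np.allclose(X[indices], centers)   # X[index] = center

    norms = row_norms(X, squared=True)        # optional precomputation
    centers2, _ = kmeans_plusplus(X, n_clusters=2, random_state=0,
                                  x_squared_norms=norms)
    assert np.allclose(centers, centers2)     # same seeds either way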

sklearn/cluster/tests/test_k_means.py

58 additions, 1 deletion

@@ -20,7 +20,7 @@
 from sklearn.metrics import pairwise_distances
 from sklearn.metrics import pairwise_distances_argmin
 from sklearn.metrics.cluster import v_measure_score
-from sklearn.cluster import KMeans, k_means
+from sklearn.cluster import KMeans, k_means, kmeans_plusplus
 from sklearn.cluster import MiniBatchKMeans
 from sklearn.cluster._kmeans import _labels_inertia
 from sklearn.cluster._kmeans import _mini_batch_step
@@ -1030,3 +1030,60 @@ def test_minibatch_kmeans_wrong_params(param, match):
     # are passed for the MiniBatchKMeans specific parameters
     with pytest.raises(ValueError, match=match):
         MiniBatchKMeans(**param).fit(X)
+
+
+@pytest.mark.parametrize("param, match", [
+    ({"n_local_trials": 0},
+     r"n_local_trials is set to 0 but should be an "
+     r"integer value greater than zero"),
+    ({"x_squared_norms": X[:2]},
+     r"The length of x_squared_norms .* should "
+     r"be equal to the length of n_samples")]
+)
+def test_kmeans_plusplus_wrong_params(param, match):
+    with pytest.raises(ValueError, match=match):
+        kmeans_plusplus(X, n_clusters, **param)
+
+
+@pytest.mark.parametrize("data", [X, X_csr])
+@pytest.mark.parametrize("dtype", [np.float64, np.float32])
+def test_kmeans_plusplus_output(data, dtype):
+    # Check for the correct number of seeds and all positive values
+    data = data.astype(dtype)
+    centers, indices = kmeans_plusplus(data, n_clusters)
+
+    # Check there are the correct number of indices and that all indices are
+    # positive and within the number of samples
+    assert indices.shape[0] == n_clusters
+    assert (indices >= 0).all()
+    assert (indices <= data.shape[0]).all()
+
+    # Check for the correct number of seeds and that they are bound by the data
+    assert centers.shape[0] == n_clusters
+    assert (centers.max(axis=0) <= data.max(axis=0)).all()
+    assert (centers.min(axis=0) >= data.min(axis=0)).all()
+
+    # Check that indices correspond to reported centers
+    # Use X for comparison rather than data; test still works against centers
+    # calculated with sparse data.
+    assert_allclose(X[indices].astype(dtype), centers)
+
+
+@pytest.mark.parametrize("x_squared_norms", [row_norms(X, squared=True), None])
+def test_kmeans_plusplus_norms(x_squared_norms):
+    # Check that defining x_squared_norms returns the same as default=None.
+    centers, indices = kmeans_plusplus(X, n_clusters,
+                                       x_squared_norms=x_squared_norms)
+
+    assert_allclose(X[indices], centers)
+
+
+def test_kmeans_plusplus_dataorder():
+    # Check that memory layout does not affect result
+    centers_c, _ = kmeans_plusplus(X, n_clusters, random_state=0)
+
+    X_fortran = np.asfortranarray(X)
+
+    centers_fortran, _ = kmeans_plusplus(X_fortran, n_clusters, random_state=0)
+
+    assert_allclose(centers_c, centers_fortran)
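
These tests lean on module-level fixtures (X, X_csr, n_clusters, row_norms, assert_allclose) defined near the top of test_k_means.py and not shown in this hunk. A hedged sketch of plausible stand-ins for running the new tests in isolation; the exact values in the real file differ:

    # Assumed stand-ins for the module-level test fixtures; the real
    # definitions live at the top of test_k_means.py.
    import numpy as np
    import scipy.sparse as sp
    from numpy.testing import assert_allclose
    from sklearn.datasets import make_blobs
    from sklearn.utils.extmath import row_norms

    n_clusters = 3                     # assumed cluster count
    X, _ = make_blobs(n_samples=100, centers=n_clusters, random_state=42)
    X_csr = sp.csr_matrix(X)           # same data in sparse CSR form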
