Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 8c4589b

Browse filesBrowse files
authored
ENH Scalable MiniBatchKMeans plus cln / fixes / refactoring (#17622)
1 parent 138da7e commit 8c4589b
Copy full SHA for 8c4589b

File tree

9 files changed

+736
-579
lines changed
Filter options

9 files changed

+736
-579
lines changed

‎doc/whats_new/v1.0.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.0.rst
+20-2Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,30 @@ Changelog
9595
in multicore settings. :pr:`19052` by
9696
:user:`Yusuke Nagasaka <YusukeNagasaka>`.
9797

98-
- |API| :class:`cluster.Birch` attributes, `fit_` and `partial_fit_`, are
99-
deprecated and will be removed in 1.2. :pr:`19297` by `Thomas Fan`_.
98+
- |Efficiency| :class:`cluster.MiniBatchKMeans` is now faster in multicore
99+
settings. :pr:`17622` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
100+
101+
- |Fix| Fixed a bug in :class:`cluster.MiniBatchKMeans` where the sample
102+
weights were partially ignored when the input is sparse. :pr:`17622` by
103+
:user:`Jérémie du Boisberranger <jeremiedbb>`.
100104

105+
- |Fix| Improved convergence detection based on center change in
106+
:class:`cluster.MiniBatchKMeans` which was almost never achievable.
107+
:pr:`17622` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
108+
101109
- |FIX| :class:`cluster.AgglomerativeClustering` now supports readonly
102110
memory-mapped datasets. :pr:`19883` by `Julien Jerphanion <jjerphan>`.
103111

112+
- |API| :class:`cluster.Birch` attributes, `fit_` and `partial_fit_`, are
113+
deprecated and will be removed in 1.2. :pr:`19297` by `Thomas Fan`_.
114+
115+
- |API| the default value for the `batch_size` parameter of
116+
:class:`MiniBatchKMeans` was changed from 100 to 1024 due to efficiency
117+
reasons. The `n_iter_` attribute of :class:`MiniBatchKMeans` now reports the
118+
number of started epochs and the `n_steps_` attribute reports the number of
119+
mini batches processed. :pr:`17622`
120+
by :user:`Jérémie du Boisberranger <jeremiedbb>`.
121+
104122
:mod:`sklearn.compose`
105123
......................
106124

‎sklearn/cluster/_k_means_fast.pyx renamed to ‎sklearn/cluster/_k_means_common.pyx

Copy file name to clipboardExpand all lines: sklearn/cluster/_k_means_common.pyx
+9-109Lines changed: 9 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,15 @@
1414

1515
import numpy as np
1616
cimport numpy as np
17-
cimport cython
1817
from cython cimport floating
18+
from cython.parallel cimport prange
1919
from libc.math cimport sqrt
2020

2121
from ..utils.extmath import row_norms
2222

2323

2424
np.import_array()
2525

26-
ctypedef np.float64_t DOUBLE
27-
ctypedef np.int32_t INT
28-
2926

3027
# Number of samples per data chunk defined as a global constant.
3128
CHUNK_SIZE = 256
@@ -103,7 +100,8 @@ cpdef floating _inertia_dense(
103100
np.ndarray[floating, ndim=2, mode='c'] X, # IN
104101
floating[::1] sample_weight, # IN
105102
floating[:, ::1] centers, # IN
106-
int[::1] labels): # IN
103+
int[::1] labels, # IN
104+
int n_threads):
107105
"""Compute inertia for dense input data
108106
109107
Sum of squared distance between each sample and its assigned center.
@@ -116,7 +114,8 @@ cpdef floating _inertia_dense(
116114
floating sq_dist = 0.0
117115
floating inertia = 0.0
118116

119-
for i in range(n_samples):
117+
for i in prange(n_samples, nogil=True, num_threads=n_threads,
118+
schedule='static'):
120119
j = labels[i]
121120
sq_dist = _euclidean_dense_dense(&X[i, 0], &centers[j, 0],
122121
n_features, True)
@@ -129,7 +128,8 @@ cpdef floating _inertia_sparse(
129128
X, # IN
130129
floating[::1] sample_weight, # IN
131130
floating[:, ::1] centers, # IN
132-
int[::1] labels): # IN
131+
int[::1] labels, # IN
132+
int n_threads):
133133
"""Compute inertia for sparse input data
134134
135135
Sum of squared distance between each sample and its assigned center.
@@ -148,7 +148,8 @@ cpdef floating _inertia_sparse(
148148

149149
floating[::1] centers_squared_norms = row_norms(centers, squared=True)
150150

151-
for i in range(n_samples):
151+
for i in prange(n_samples, nogil=True, num_threads=n_threads,
152+
schedule='static'):
152153
j = labels[i]
153154
sq_dist = _euclidean_sparse_dense(
154155
X_data[X_indptr[i]: X_indptr[i + 1]],
@@ -286,104 +287,3 @@ cdef void _center_shift(
286287
for j in range(n_clusters):
287288
center_shift[j] = _euclidean_dense_dense(
288289
&centers_new[j, 0], &centers_old[j, 0], n_features, False)
289-
290-
291-
def _mini_batch_update_csr(X, np.ndarray[floating, ndim=1] sample_weight,
292-
np.ndarray[floating, ndim=1] x_squared_norms,
293-
np.ndarray[floating, ndim=2] centers,
294-
np.ndarray[floating, ndim=1] weight_sums,
295-
np.ndarray[INT, ndim=1] nearest_center,
296-
np.ndarray[floating, ndim=1] old_center,
297-
int compute_squared_diff):
298-
"""Incremental update of the centers for sparse MiniBatchKMeans.
299-
300-
Parameters
301-
----------
302-
303-
X : CSR matrix, dtype float
304-
The complete (pre allocated) training set as a CSR matrix.
305-
306-
centers : array, shape (n_clusters, n_features)
307-
The cluster centers
308-
309-
counts : array, shape (n_clusters,)
310-
The vector in which we keep track of the numbers of elements in a
311-
cluster
312-
313-
Returns
314-
-------
315-
inertia : float
316-
The inertia of the batch prior to centers update, i.e. the sum
317-
of squared distances to the closest center for each sample. This
318-
is the objective function being minimized by the k-means algorithm.
319-
320-
squared_diff : float
321-
The sum of squared update (squared norm of the centers position
322-
change). If compute_squared_diff is 0, this computation is skipped and
323-
0.0 is returned instead.
324-
325-
Both squared diff and inertia are commonly used to monitor the convergence
326-
of the algorithm.
327-
"""
328-
cdef:
329-
np.ndarray[floating, ndim=1] X_data = X.data
330-
np.ndarray[int, ndim=1] X_indices = X.indices
331-
np.ndarray[int, ndim=1] X_indptr = X.indptr
332-
unsigned int n_samples = X.shape[0]
333-
unsigned int n_clusters = centers.shape[0]
334-
unsigned int n_features = centers.shape[1]
335-
336-
unsigned int sample_idx, center_idx, feature_idx
337-
unsigned int k
338-
DOUBLE old_weight_sum, new_weight_sum
339-
DOUBLE center_diff
340-
DOUBLE squared_diff = 0.0
341-
342-
# move centers to the mean of both old and newly assigned samples
343-
for center_idx in range(n_clusters):
344-
old_weight_sum = weight_sums[center_idx]
345-
new_weight_sum = old_weight_sum
346-
347-
# count the number of samples assigned to this center
348-
for sample_idx in range(n_samples):
349-
if nearest_center[sample_idx] == center_idx:
350-
new_weight_sum += sample_weight[sample_idx]
351-
352-
if new_weight_sum == old_weight_sum:
353-
# no new sample: leave this center as it stands
354-
continue
355-
356-
# rescale the old center to reflect it previous accumulated weight
357-
# with regards to the new data that will be incrementally contributed
358-
if compute_squared_diff:
359-
old_center[:] = centers[center_idx]
360-
centers[center_idx] *= old_weight_sum
361-
362-
# iterate of over samples assigned to this cluster to move the center
363-
# location by inplace summation
364-
for sample_idx in range(n_samples):
365-
if nearest_center[sample_idx] != center_idx:
366-
continue
367-
368-
# inplace sum with new samples that are members of this cluster
369-
# and update of the incremental squared difference update of the
370-
# center position
371-
for k in range(X_indptr[sample_idx], X_indptr[sample_idx + 1]):
372-
centers[center_idx, X_indices[k]] += X_data[k]
373-
374-
# inplace rescale center with updated count
375-
if new_weight_sum > old_weight_sum:
376-
# update the count statistics for this center
377-
weight_sums[center_idx] = new_weight_sum
378-
379-
# re-scale the updated center with the total new counts
380-
centers[center_idx] /= new_weight_sum
381-
382-
# update the incremental computation of the squared total
383-
# centers position change
384-
if compute_squared_diff:
385-
for feature_idx in range(n_features):
386-
squared_diff += (old_center[feature_idx]
387-
- centers[center_idx, feature_idx]) ** 2
388-
389-
return squared_diff

‎sklearn/cluster/_k_means_elkan.pyx

Copy file name to clipboardExpand all lines: sklearn/cluster/_k_means_elkan.pyx
+7-7Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ from libc.stdlib cimport calloc, free
1818
from libc.string cimport memset, memcpy
1919

2020
from ..utils.extmath import row_norms
21-
from ._k_means_fast import CHUNK_SIZE
22-
from ._k_means_fast cimport _relocate_empty_clusters_dense
23-
from ._k_means_fast cimport _relocate_empty_clusters_sparse
24-
from ._k_means_fast cimport _euclidean_dense_dense
25-
from ._k_means_fast cimport _euclidean_sparse_dense
26-
from ._k_means_fast cimport _average_centers
27-
from ._k_means_fast cimport _center_shift
21+
from ._k_means_common import CHUNK_SIZE
22+
from ._k_means_common cimport _relocate_empty_clusters_dense
23+
from ._k_means_common cimport _relocate_empty_clusters_sparse
24+
from ._k_means_common cimport _euclidean_dense_dense
25+
from ._k_means_common cimport _euclidean_sparse_dense
26+
from ._k_means_common cimport _average_centers
27+
from ._k_means_common cimport _center_shift
2828

2929

3030
np.import_array()

‎sklearn/cluster/_k_means_lloyd.pyx

Copy file name to clipboardExpand all lines: sklearn/cluster/_k_means_lloyd.pyx
+5-5Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,16 @@ cimport numpy as np
1111
from cython cimport floating
1212
from cython.parallel import prange, parallel
1313
from libc.stdlib cimport malloc, calloc, free
14-
from libc.string cimport memset, memcpy
14+
from libc.string cimport memset
1515
from libc.float cimport DBL_MAX, FLT_MAX
1616

1717
from ..utils.extmath import row_norms
1818
from ..utils._cython_blas cimport _gemm
1919
from ..utils._cython_blas cimport RowMajor, Trans, NoTrans
20-
from ._k_means_fast import CHUNK_SIZE
21-
from ._k_means_fast cimport _relocate_empty_clusters_dense
22-
from ._k_means_fast cimport _relocate_empty_clusters_sparse
23-
from ._k_means_fast cimport _average_centers, _center_shift
20+
from ._k_means_common import CHUNK_SIZE
21+
from ._k_means_common cimport _relocate_empty_clusters_dense
22+
from ._k_means_common cimport _relocate_empty_clusters_sparse
23+
from ._k_means_common cimport _average_centers, _center_shift
2424

2525

2626
np.import_array()

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.