Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit de67a44

Browse filesBrowse files
ArturoAmorQjeremiedbbVincent-Maladiere
authored
ENH Let csr_row_norms support multi-thread (#25598)
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com> Co-authored-by: Vincent M <maladiere.vincent@yahoo.fr>
1 parent ae4a1b1 commit de67a44
Copy full SHA for de67a44

File tree

1 file changed

+12
-9
lines changed
Filter options

1 file changed

+12
-9
lines changed

‎sklearn/utils/sparsefuncs_fast.pyx

Copy file name to clipboardExpand all lines: sklearn/utils/sparsefuncs_fast.pyx
+12-9Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@ from libc.math cimport fabs, sqrt
1212
cimport numpy as cnp
1313
import numpy as np
1414
from cython cimport floating
15+
from cython.parallel cimport prange
1516
from numpy.math cimport isnan
1617

18+
from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
19+
1720
cnp.import_array()
1821

1922
ctypedef fused integral:
@@ -27,13 +30,14 @@ def csr_row_norms(X):
2730
"""Squared L2 norm of each row in CSR matrix X."""
2831
if X.dtype not in [np.float32, np.float64]:
2932
X = X.astype(np.float64)
30-
return _csr_row_norms(X.data, X.indices, X.indptr)
33+
n_threads = _openmp_effective_n_threads()
34+
return _sqeuclidean_row_norms_sparse(X.data, X.indptr, n_threads)
3135

3236

33-
def _csr_row_norms(
37+
def _sqeuclidean_row_norms_sparse(
3438
const floating[::1] X_data,
35-
const integral[::1] X_indices,
3639
const integral[::1] X_indptr,
40+
int n_threads,
3741
):
3842
cdef:
3943
integral n_samples = X_indptr.shape[0] - 1
@@ -42,14 +46,13 @@ def _csr_row_norms(
4246

4347
dtype = np.float32 if floating is float else np.float64
4448

45-
cdef floating[::1] norms = np.zeros(n_samples, dtype=dtype)
49+
cdef floating[::1] squared_row_norms = np.zeros(n_samples, dtype=dtype)
4650

47-
with nogil:
48-
for i in range(n_samples):
49-
for j in range(X_indptr[i], X_indptr[i + 1]):
50-
norms[i] += X_data[j] * X_data[j]
51+
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
52+
for j in range(X_indptr[i], X_indptr[i + 1]):
53+
squared_row_norms[i] += X_data[j] * X_data[j]
5154

52-
return np.asarray(norms)
55+
return np.asarray(squared_row_norms)
5356

5457

5558
def csr_mean_variance_axis0(X, weights=None, return_sum_weights=False):

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.