Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 092caed

Browse files
authored
PERF revert openmp use in csr_row_norms (scikit-learn#26275)
1 parent 523c135 commit 092caed
Copy full SHA for 092caed

File tree

Expand file tree / Collapse file tree

2 files changed

+24
-12
lines changed
Filter options
Expand file tree / Collapse file tree

2 files changed

+24
-12
lines changed

‎sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp

Copy file name to clipboard | Expand all lines: sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp
+19 -3 | Lines changed: 19 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ from libcpp.vector cimport vector
55

66
from ...utils._cython_blas cimport _dot
77
from ...utils._openmp_helpers cimport omp_get_thread_num
8-
from ...utils._typedefs cimport intp_t, float32_t, float64_t
8+
from ...utils._typedefs cimport intp_t, float32_t, float64_t, int32_t
99

1010
import numpy as np
1111

@@ -14,7 +14,6 @@ from numbers import Integral
1414
from sklearn import get_config
1515
from sklearn.utils import check_scalar
1616
from ...utils._openmp_helpers import _openmp_effective_n_threads
17-
from ...utils.sparsefuncs_fast import _sqeuclidean_row_norms_sparse
1817

1918
#####################
2019

@@ -84,6 +83,23 @@ cdef float64_t[::1] _sqeuclidean_row_norms32_dense(
8483
return squared_row_norms
8584

8685

86+
cdef float64_t[::1] _sqeuclidean_row_norms64_sparse(
87+
const float64_t[:] X_data,
88+
const int32_t[:] X_indptr,
89+
intp_t num_threads,
90+
):
91+
cdef:
92+
intp_t n = X_indptr.shape[0] - 1
93+
int32_t X_i_ptr, idx = 0
94+
float64_t[::1] squared_row_norms = np.zeros(n, dtype=np.float64)
95+
96+
for idx in prange(n, schedule='static', nogil=True, num_threads=num_threads):
97+
for X_i_ptr in range(X_indptr[idx], X_indptr[idx+1]):
98+
squared_row_norms[idx] += X_data[X_i_ptr] * X_data[X_i_ptr]
99+
100+
return squared_row_norms
101+
102+
87103
{{for name_suffix in ["64", "32"]}}
88104

89105
from ._datasets_pair cimport DatasetsPair{{name_suffix}}
@@ -98,7 +114,7 @@ cpdef float64_t[::1] _sqeuclidean_row_norms{{name_suffix}}(
98114
# by moving squared row norms computations in MiddleTermComputer.
99115
X_data = np.asarray(X.data, dtype=np.float64)
100116
X_indptr = np.asarray(X.indptr, dtype=np.int32)
101-
return _sqeuclidean_row_norms_sparse(X_data, X_indptr, num_threads)
117+
return _sqeuclidean_row_norms64_sparse(X_data, X_indptr, num_threads)
102118
else:
103119
return _sqeuclidean_row_norms{{name_suffix}}_dense(X, num_threads)
104120

‎sklearn/utils/sparsefuncs_fast.pyx

Copy file name to clipboard | Expand all lines: sklearn/utils/sparsefuncs_fast.pyx
+5 -9 | Lines changed: 5 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -11,9 +11,6 @@ from libc.math cimport fabs, sqrt, isnan
1111
cimport numpy as cnp
1212
import numpy as np
1313
from cython cimport floating
14-
from cython.parallel cimport prange
15-
16-
from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
1714

1815
cnp.import_array()
1916

@@ -28,14 +25,12 @@ def csr_row_norms(X):
2825
"""Squared L2 norm of each row in CSR matrix X."""
2926
if X.dtype not in [np.float32, np.float64]:
3027
X = X.astype(np.float64)
31-
n_threads = _openmp_effective_n_threads()
32-
return _sqeuclidean_row_norms_sparse(X.data, X.indptr, n_threads)
28+
return _sqeuclidean_row_norms_sparse(X.data, X.indptr)
3329

3430

3531
def _sqeuclidean_row_norms_sparse(
3632
const floating[::1] X_data,
3733
const integral[::1] X_indptr,
38-
int n_threads,
3934
):
4035
cdef:
4136
integral n_samples = X_indptr.shape[0] - 1
@@ -45,9 +40,10 @@ def _sqeuclidean_row_norms_sparse(
4540

4641
cdef floating[::1] squared_row_norms = np.zeros(n_samples, dtype=dtype)
4742

48-
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
49-
for j in range(X_indptr[i], X_indptr[i + 1]):
50-
squared_row_norms[i] += X_data[j] * X_data[j]
43+
with nogil:
44+
for i in range(n_samples):
45+
for j in range(X_indptr[i], X_indptr[i + 1]):
46+
squared_row_norms[i] += X_data[j] * X_data[j]
5147

5248
return np.asarray(squared_row_norms)
5349

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.