Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 57d3668

Browse files
authored
MNT Avoid catastrophic cancellation in mean_variance_axis (#19766)
1 parent 54ff7b7 commit 57d3668
Copy full SHA for 57d3668

File tree

3 files changed

+71
-22
lines changed
Filter options

3 files changed

+71
-22
lines changed

‎doc/whats_new/v1.0.rst

Copy file name to clipboard · Expand all lines: doc/whats_new/v1.0.rst
+9 −2 · Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,12 +288,19 @@ Changelog
288288
:user:`Clifford Akai-Nettey<cliffordEmmanuel>`.
289289

290290
:mod:`sklearn.calibration`
291-
............................
291+
..........................
292292

293293
- |Fix| The predict and predict_proba methods of
294-
:class:`calibration.CalibratedClassifierCV can now properly be used on
294+
:class:`calibration.CalibratedClassifierCV` can now properly be used on
295295
prefitted pipelines. :pr:`19641` by :user:`Alek Lefebvre <AlekLefebvre>`
296296

297+
:mod:`sklearn.utils`
298+
....................
299+
300+
- |Fix| Fixed a bug in :func:`utils.sparsefuncs.mean_variance_axis` where the
301+
precision of the computed variance was very poor when the real variance is
302+
exactly zero. :pr:`19766` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
303+
297304
Code and Documentation Contributors
298305
-----------------------------------
299306

‎sklearn/utils/sparsefuncs_fast.pyx

Copy file name to clipboard · Expand all lines: sklearn/utils/sparsefuncs_fast.pyx
+42 −20 · Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -124,23 +124,32 @@ def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
124124
variances = np.zeros_like(means, dtype=dtype)
125125

126126
cdef:
127-
np.ndarray[floating, ndim=1] sum_weights = \
128-
np.full(fill_value=np.sum(weights), shape=n_features, dtype=dtype)
129-
np.ndarray[floating, ndim=1] sum_weights_nan = \
130-
np.zeros(shape=n_features, dtype=dtype)
131-
np.ndarray[floating, ndim=1] sum_weights_nz = \
132-
np.zeros(shape=n_features, dtype=dtype)
127+
np.ndarray[floating, ndim=1] sum_weights = np.full(
128+
fill_value=np.sum(weights), shape=n_features, dtype=dtype)
129+
np.ndarray[floating, ndim=1] sum_weights_nz = np.zeros(
130+
shape=n_features, dtype=dtype)
131+
132+
np.ndarray[np.uint64_t, ndim=1] counts = np.full(
133+
fill_value=weights.shape[0], shape=n_features, dtype=np.uint64)
134+
np.ndarray[np.uint64_t, ndim=1] counts_nz = np.zeros(
135+
shape=n_features, dtype=np.uint64)
133136

134137
for row_ind in range(len(X_indptr) - 1):
135138
for i in range(X_indptr[row_ind], X_indptr[row_ind + 1]):
136139
col_ind = X_indices[i]
137140
if not isnan(X_data[i]):
138141
means[col_ind] += (X_data[i] * weights[row_ind])
142+
# sum of weights where X[:, col_ind] is non-zero
143+
sum_weights_nz[col_ind] += weights[row_ind]
144+
# number of non-zero elements of X[:, col_ind]
145+
counts_nz[col_ind] += 1
139146
else:
140-
sum_weights_nan[col_ind] += weights[row_ind]
147+
# sum of weights where X[:, col_ind] is not nan
148+
sum_weights[col_ind] -= weights[row_ind]
149+
# number of non nan elements of X[:, col_ind]
150+
counts[col_ind] -= 1
141151

142152
for i in range(n_features):
143-
sum_weights[i] -= sum_weights_nan[i]
144153
means[i] /= sum_weights[i]
145154

146155
for row_ind in range(len(X_indptr) - 1):
@@ -149,10 +158,12 @@ def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
149158
if not isnan(X_data[i]):
150159
diff = X_data[i] - means[col_ind]
151160
variances[col_ind] += diff * diff * weights[row_ind]
152-
sum_weights_nz[col_ind] += weights[row_ind]
153161

154162
for i in range(n_features):
155-
variances[i] += (sum_weights[i] - sum_weights_nz[i]) * means[i]**2
163+
if counts[i] != counts_nz[i]:
164+
# only compute it when it's guaranteed to be non-zero to avoid
165+
# catastrophic cancellation.
166+
variances[i] += (sum_weights[i] - sum_weights_nz[i]) * means[i]**2
156167
variances[i] /= sum_weights[i]
157168

158169
return means, variances, sum_weights
@@ -228,23 +239,32 @@ def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
228239
variances = np.zeros_like(means, dtype=dtype)
229240

230241
cdef:
231-
np.ndarray[floating, ndim=1] sum_weights = \
232-
np.full(fill_value=np.sum(weights), shape=n_features, dtype=dtype)
233-
np.ndarray[floating, ndim=1] sum_weights_nan = \
234-
np.zeros(shape=n_features, dtype=dtype)
235-
np.ndarray[floating, ndim=1] sum_weights_nz = \
236-
np.zeros(shape=n_features, dtype=dtype)
242+
np.ndarray[floating, ndim=1] sum_weights = np.full(
243+
fill_value=np.sum(weights), shape=n_features, dtype=dtype)
244+
np.ndarray[floating, ndim=1] sum_weights_nz = np.zeros(
245+
shape=n_features, dtype=dtype)
246+
247+
np.ndarray[np.uint64_t, ndim=1] counts = np.full(
248+
fill_value=weights.shape[0], shape=n_features, dtype=np.uint64)
249+
np.ndarray[np.uint64_t, ndim=1] counts_nz = np.zeros(
250+
shape=n_features, dtype=np.uint64)
237251

238252
for col_ind in range(n_features):
239253
for i in range(X_indptr[col_ind], X_indptr[col_ind + 1]):
240254
row_ind = X_indices[i]
241255
if not isnan(X_data[i]):
242256
means[col_ind] += (X_data[i] * weights[row_ind])
257+
# sum of weights where X[:, col_ind] is non-zero
258+
sum_weights_nz[col_ind] += weights[row_ind]
259+
# number of non-zero elements of X[:, col_ind]
260+
counts_nz[col_ind] += 1
243261
else:
244-
sum_weights_nan[col_ind] += weights[row_ind]
262+
# sum of weights where X[:, col_ind] is not nan
263+
sum_weights[col_ind] -= weights[row_ind]
264+
# number of non nan elements of X[:, col_ind]
265+
counts[col_ind] -= 1
245266

246267
for i in range(n_features):
247-
sum_weights[i] -= sum_weights_nan[i]
248268
means[i] /= sum_weights[i]
249269

250270
for col_ind in range(n_features):
@@ -253,10 +273,12 @@ def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
253273
if not isnan(X_data[i]):
254274
diff = X_data[i] - means[col_ind]
255275
variances[col_ind] += diff * diff * weights[row_ind]
256-
sum_weights_nz[col_ind] += weights[row_ind]
257276

258277
for i in range(n_features):
259-
variances[i] += (sum_weights[i] - sum_weights_nz[i]) * means[i]**2
278+
if counts[i] != counts_nz[i]:
279+
# only compute it when it's guaranteed to be non-zero to avoid
280+
# catastrophic cancellation.
281+
variances[i] += (sum_weights[i] - sum_weights_nz[i]) * means[i]**2
260282
variances[i] /= sum_weights[i]
261283

262284
return means, variances, sum_weights

‎sklearn/utils/tests/test_sparsefuncs.py

Copy file name to clipboard · Expand all lines: sklearn/utils/tests/test_sparsefuncs.py
+20 −0 · Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,26 @@ def test_mean_variance_axis0():
5353
assert_array_almost_equal(X_vars, np.var(X_test, axis=0))
5454

5555

56+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
57+
@pytest.mark.parametrize("sparse_constructor", [sp.csr_matrix, sp.csc_matrix])
58+
def test_mean_variance_axis0_precision(dtype, sparse_constructor):
59+
# Check that there's no big loss of precision when the real variance is
60+
# exactly 0. (#19766)
61+
rng = np.random.RandomState(0)
62+
X = np.full(fill_value=100., shape=(1000, 1), dtype=dtype)
63+
# Add some missing records which should be ignored:
64+
missing_indices = rng.choice(np.arange(X.shape[0]), 10, replace=False)
65+
X[missing_indices, 0] = np.nan
66+
X = sparse_constructor(X)
67+
68+
# Random positive weights:
69+
sample_weight = rng.rand(X.shape[0]).astype(dtype)
70+
71+
_, var = mean_variance_axis(X, weights=sample_weight, axis=0)
72+
73+
assert var < np.finfo(dtype).eps
74+
75+
5676
def test_mean_variance_axis1():
5777
X, _ = make_classification(5, 4, random_state=0)
5878
# Sparsify the array a little bit

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.