@@ -124,23 +124,32 @@ def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
     variances = np.zeros_like(means, dtype=dtype)
 
     cdef:
-        np.ndarray[floating, ndim=1] sum_weights = \
-            np.full(fill_value=np.sum(weights), shape=n_features, dtype=dtype)
-        np.ndarray[floating, ndim=1] sum_weights_nan = \
-            np.zeros(shape=n_features, dtype=dtype)
-        np.ndarray[floating, ndim=1] sum_weights_nz = \
-            np.zeros(shape=n_features, dtype=dtype)
+        np.ndarray[floating, ndim=1] sum_weights = np.full(
+            fill_value=np.sum(weights), shape=n_features, dtype=dtype)
+        np.ndarray[floating, ndim=1] sum_weights_nz = np.zeros(
+            shape=n_features, dtype=dtype)
+
+        np.ndarray[np.uint64_t, ndim=1] counts = np.full(
+            fill_value=weights.shape[0], shape=n_features, dtype=np.uint64)
+        np.ndarray[np.uint64_t, ndim=1] counts_nz = np.zeros(
+            shape=n_features, dtype=np.uint64)
 
     for row_ind in range(len(X_indptr) - 1):
         for i in range(X_indptr[row_ind], X_indptr[row_ind + 1]):
             col_ind = X_indices[i]
             if not isnan(X_data[i]):
                 means[col_ind] += (X_data[i] * weights[row_ind])
+                # sum of weights where X[:, col_ind] is non-zero
+                sum_weights_nz[col_ind] += weights[row_ind]
+                # number of non-zero elements of X[:, col_ind]
+                counts_nz[col_ind] += 1
             else:
-                sum_weights_nan[col_ind] += weights[row_ind]
+                # sum of weights where X[:, col_ind] is not nan
+                sum_weights[col_ind] -= weights[row_ind]
+                # number of non nan elements of X[:, col_ind]
+                counts[col_ind] -= 1
 
     for i in range(n_features):
-        sum_weights[i] -= sum_weights_nan[i]
         means[i] /= sum_weights[i]
 
     for row_ind in range(len(X_indptr) - 1):
@@ -149,10 +158,12 @@ def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
             if not isnan(X_data[i]):
                 diff = X_data[i] - means[col_ind]
                 variances[col_ind] += diff * diff * weights[row_ind]
-                sum_weights_nz[col_ind] += weights[row_ind]
 
     for i in range(n_features):
-        variances[i] += (sum_weights[i] - sum_weights_nz[i]) * means[i]**2
+        if counts[i] != counts_nz[i]:
+            # only compute it when it's guaranteed to be non-zero to avoid
+            # catastrophic cancellation.
+            variances[i] += (sum_weights[i] - sum_weights_nz[i]) * means[i]**2
         variances[i] /= sum_weights[i]
 
     return means, variances, sum_weights
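
The correction term `(sum_weights[i] - sum_weights_nz[i]) * means[i]**2` accounts for the implicit zeros of a sparse column: every zero that is not stored in the CSR structure contributes `weight * (0 - mean)**2 = weight * mean**2` to the weighted variance. The new `counts[i] != counts_nz[i]` guard only applies that term when the column actually has implicit zeros, so a fully dense column is not perturbed by catastrophic cancellation in `sum_weights - sum_weights_nz`. A minimal NumPy sketch of the same bookkeeping on one dense column, separate from the patch (all names here are illustrative, not part of the Cython kernel):

```python
import numpy as np

col = np.array([0.0, 3.0, 0.0, 5.0])    # one column of X; zeros would be implicit in CSR
weights = np.array([1.0, 2.0, 1.0, 2.0])

sum_weights = weights.sum()             # no NaNs in this toy column
nz = col != 0
sum_weights_nz = weights[nz].sum()

mean = (col * weights).sum() / sum_weights

# variance contribution of the explicitly stored (non-zero) entries
var = (weights[nz] * (col[nz] - mean) ** 2).sum()
# each implicit zero contributes weight * (0 - mean)**2 = weight * mean**2
var += (sum_weights - sum_weights_nz) * mean ** 2
var /= sum_weights

# matches the dense weighted variance
assert np.isclose(var, np.average((col - mean) ** 2, weights=weights))
```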
@@ -228,23 +239,32 @@ def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
     variances = np.zeros_like(means, dtype=dtype)
 
     cdef:
-        np.ndarray[floating, ndim=1] sum_weights = \
-            np.full(fill_value=np.sum(weights), shape=n_features, dtype=dtype)
-        np.ndarray[floating, ndim=1] sum_weights_nan = \
-            np.zeros(shape=n_features, dtype=dtype)
-        np.ndarray[floating, ndim=1] sum_weights_nz = \
-            np.zeros(shape=n_features, dtype=dtype)
+        np.ndarray[floating, ndim=1] sum_weights = np.full(
+            fill_value=np.sum(weights), shape=n_features, dtype=dtype)
+        np.ndarray[floating, ndim=1] sum_weights_nz = np.zeros(
+            shape=n_features, dtype=dtype)
+
+        np.ndarray[np.uint64_t, ndim=1] counts = np.full(
+            fill_value=weights.shape[0], shape=n_features, dtype=np.uint64)
+        np.ndarray[np.uint64_t, ndim=1] counts_nz = np.zeros(
+            shape=n_features, dtype=np.uint64)
 
     for col_ind in range(n_features):
         for i in range(X_indptr[col_ind], X_indptr[col_ind + 1]):
             row_ind = X_indices[i]
             if not isnan(X_data[i]):
                 means[col_ind] += (X_data[i] * weights[row_ind])
+                # sum of weights where X[:, col_ind] is non-zero
+                sum_weights_nz[col_ind] += weights[row_ind]
+                # number of non-zero elements of X[:, col_ind]
+                counts_nz[col_ind] += 1
             else:
-                sum_weights_nan[col_ind] += weights[row_ind]
+                # sum of weights where X[:, col_ind] is not nan
+                sum_weights[col_ind] -= weights[row_ind]
+                # number of non nan elements of X[:, col_ind]
+                counts[col_ind] -= 1
 
     for i in range(n_features):
-        sum_weights[i] -= sum_weights_nan[i]
         means[i] /= sum_weights[i]
 
     for col_ind in range(n_features):
@@ -253,10 +273,12 @@ def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
             if not isnan(X_data[i]):
                 diff = X_data[i] - means[col_ind]
                 variances[col_ind] += diff * diff * weights[row_ind]
-                sum_weights_nz[col_ind] += weights[row_ind]
 
     for i in range(n_features):
-        variances[i] += (sum_weights[i] - sum_weights_nz[i]) * means[i]**2
+        if counts[i] != counts_nz[i]:
+            # only compute it when it's guaranteed to be non-zero to avoid
+            # catastrophic cancellation.
+            variances[i] += (sum_weights[i] - sum_weights_nz[i]) * means[i]**2
         variances[i] /= sum_weights[i]
 
     return means, variances, sum_weights
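
For context, a hedged usage sketch of the public wrapper around these kernels, `sklearn.utils.sparsefuncs.mean_variance_axis`. It assumes an installed scikit-learn version in which that helper accepts `weights` and `return_sum_weights` (the code path patched here), and checks the result against a dense weighted computation:

```python
import numpy as np
import scipy.sparse as sp
from sklearn.utils.sparsefuncs import mean_variance_axis

rng = np.random.RandomState(0)
X = sp.random(20, 5, density=0.3, format="csr", random_state=rng)
w = rng.uniform(0.5, 2.0, size=20)

means, variances, sum_weights = mean_variance_axis(
    X, axis=0, weights=w, return_sum_weights=True)

# dense reference with the same weights (no NaNs in this toy matrix)
Xd = X.toarray()
ref_means = np.average(Xd, axis=0, weights=w)
ref_vars = np.average((Xd - ref_means) ** 2, axis=0, weights=w)

assert np.allclose(means, ref_means)
assert np.allclose(variances, ref_vars)
assert np.allclose(sum_weights, w.sum())
```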