Commit 684b7d1

FIX detect near constant feature in StandardScaler and linear models (#19788)

Authored by jeremiedbb and ogrisel
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 872052b · commit 684b7d1

6 files changed: +166 −57 lines

doc/whats_new/v1.0.rst (+2 −1)

@@ -334,7 +334,8 @@ Changelog
   very large values. This problem happens in particular when using a scaler on
   sparse data with a constant column with sample weights, in which case
   centering is typically disabled. :pr:`19527` by :user:`Oliver Grisel
-  <ogrisel>` and :user:`Maria Telenczuk <maikia>`.
+  <ogrisel>` and :user:`Maria Telenczuk <maikia>` and :pr:`19788` by
+  :user:`Jérémie du Boisberranger <jeremiedbb>`.

 - |Fix| :meth:`preprocessing.StandardScaler.inverse_transform` now
   correctly handles integer dtypes. :pr:`19356` by :user:`makoeppel`.

sklearn/linear_model/_base.py (+3 −3)

@@ -28,6 +28,7 @@
 from ..base import (BaseEstimator, ClassifierMixin, RegressorMixin,
                     MultiOutputMixin)
+from ..preprocessing._data import _is_constant_feature
 from ..utils import check_array
 from ..utils.validation import FLOAT_DTYPES
 from ..utils.validation import _deprecate_positional_args
@@ -39,7 +40,6 @@
 from ..utils._seq_dataset import ArrayDataset32, CSRDataset32
 from ..utils._seq_dataset import ArrayDataset64, CSRDataset64
 from ..utils.validation import check_is_fitted, _check_sample_weight
-
 from ..utils.fixes import delayed

 # TODO: bayesian_ridge_regression and bayesian_regression_ard
@@ -271,8 +271,8 @@ def _preprocess_data(X, y, fit_intercept, normalize=False, copy=True,
             X_var = X_var.astype(X.dtype, copy=False)
             # Detect constant features on the computed variance, before taking
             # the np.sqrt. Otherwise constant features cannot be detected with
-            # sample_weights.
-            constant_mask = X_var < 10 * np.finfo(X.dtype).eps
+            # sample weights.
+            constant_mask = _is_constant_feature(X_var, X_offset, X.shape[0])
             X_var *= X.shape[0]
             X_scale = np.sqrt(X_var, out=X_var)
             X_scale[constant_mask] = 1.
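The practical effect of the `_preprocess_data` change above: instead of comparing the raw variance to a fixed `10 * eps` threshold, a feature is declared constant when its variance falls within the round-off error bound of the variance computation itself, and its scale is then pinned to 1 rather than dividing by noise. A minimal NumPy sketch of that scale computation (standalone and illustrative only; `_scale_sketch` and its arguments are hypothetical names, and the bound is the one defined by the `_is_constant_feature` helper added in sklearn/preprocessing/_data.py below):

    import numpy as np

    def _scale_sketch(X_var, X_offset, n_samples):
        # Hypothetical standalone sketch of the updated scale computation in
        # _preprocess_data: flag near-constant features before the sqrt.
        eps = np.finfo(np.float64).eps
        upper_bound = n_samples * eps * X_var + (n_samples * X_offset * eps) ** 2
        constant_mask = X_var <= upper_bound
        X_scale = np.sqrt(X_var * n_samples)
        X_scale[constant_mask] = 1.      # never divide by a round-off artefact
        return X_scale

    # Feature 0: huge mean, variance at round-off level -> treated as constant.
    # Feature 1: ordinary feature -> scaled by sqrt(n * var) as before.
    print(_scale_sketch(np.array([1e-14, 4.0]), np.array([1e8, 0.0]), n_samples=100))
    # expected: approximately [ 1. 20.]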

sklearn/preprocessing/_data.py (+18 −1)

@@ -57,6 +57,22 @@
 ]


+def _is_constant_feature(var, mean, n_samples):
+    """Detect if a feature is indistinguishable from a constant feature.
+
+    The detection is based on its computed variance and on the theoretical
+    error bounds of the '2 pass algorithm' for variance computation.
+
+    See "Algorithms for computing the sample variance: analysis and
+    recommendations", by Chan, Golub, and LeVeque.
+    """
+    # In scikit-learn, variance is always computed using float64 accumulators.
+    eps = np.finfo(np.float64).eps
+
+    upper_bound = n_samples * eps * var + (n_samples * mean * eps)**2
+    return var <= upper_bound
+
+
 def _handle_zeros_in_scale(scale, copy=True, constant_mask=None):
     """Set scales of near constant features to 1.

@@ -863,7 +879,8 @@ def partial_fit(self, X, y=None, sample_weight=None):
         if self.with_std:
             # Extract the list of near constant features on the raw variances,
             # before taking the square root.
-            constant_mask = self.var_ < 10 * np.finfo(X.dtype).eps
+            constant_mask = _is_constant_feature(
+                self.var_, self.mean_, self.n_samples_seen_)
             self.scale_ = _handle_zeros_in_scale(
                 np.sqrt(self.var_), copy=False, constant_mask=constant_mask)
         else:
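To make the bound in `_is_constant_feature` concrete: with `n_samples` samples, float64 machine epsilon `eps`, and a per-feature `var` and `mean`, a feature is treated as constant whenever `var <= n_samples * eps * var + (n_samples * mean * eps)**2`, i.e. whenever the computed variance could be nothing but accumulated round-off from the two-pass algorithm. A small self-contained demo (the helper is re-stated here from the diff above so the snippet runs on its own; the example arrays are made up):

    import numpy as np

    def _is_constant_feature(var, mean, n_samples):
        # Same bound as in the diff above: variance within the theoretical
        # error of the 2-pass algorithm (Chan, Golub & LeVeque).
        eps = np.finfo(np.float64).eps
        upper_bound = n_samples * eps * var + (n_samples * mean * eps) ** 2
        return var <= upper_bound

    var = np.array([0.0, 1e-12, 1e-12, 4.0])
    mean = np.array([3.0, 1e8, 0.0, 5.0])
    # The same tiny variance (1e-12) is round-off noise next to a mean of 1e8
    # but a genuine signal next to a mean of 0.
    print(_is_constant_feature(var, mean, n_samples=100))
    # expected: [ True  True False False]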

sklearn/preprocessing/tests/test_data.py (+56 −7)

@@ -224,13 +224,6 @@ def test_standard_scaler_dtype(add_sample_weight, sparse_constructor):
 @pytest.mark.parametrize("constant", [0, 1., 100.])
 def test_standard_scaler_constant_features(
         scaler, add_sample_weight, sparse_constructor, dtype, constant):
-    if (isinstance(scaler, StandardScaler)
-            and constant > 1
-            and sparse_constructor is not np.asarray
-            and add_sample_weight):
-        # https://github.com/scikit-learn/scikit-learn/issues/19546
-        pytest.xfail("Computation of weighted variance is numerically unstable"
-                     " for sparse data. See: #19546.")

     if isinstance(scaler, RobustScaler) and add_sample_weight:
         pytest.skip(f"{scaler.__class__.__name__} does not yet support"
@@ -269,6 +262,62 @@ def test_standard_scaler_constant_features(
     assert_allclose(X_scaled_2, X_scaled_2)


+@pytest.mark.parametrize("n_samples", [10, 100, 10_000])
+@pytest.mark.parametrize("average", [1e-10, 1, 1e10])
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize("array_constructor",
+                         [np.asarray, sparse.csc_matrix, sparse.csr_matrix])
+def test_standard_scaler_near_constant_features(n_samples, array_constructor,
+                                                average, dtype):
+    # Check that when the variance is too small (var << mean**2) the feature
+    # is considered constant and not scaled.
+
+    scale_min, scale_max = -30, 19
+    scales = np.array([10**i for i in range(scale_min, scale_max + 1)],
+                      dtype=dtype)
+
+    n_features = scales.shape[0]
+    X = np.empty((n_samples, n_features), dtype=dtype)
+    # Make a dataset of known var = scales**2 and mean = average
+    X[:n_samples//2, :] = average + scales
+    X[n_samples//2:, :] = average - scales
+    X_array = array_constructor(X)
+
+    scaler = StandardScaler(with_mean=False).fit(X_array)
+
+    # StandardScaler uses float64 accumulators even if the data has a float32
+    # dtype.
+    eps = np.finfo(np.float64).eps
+
+    # if var < bound = N.eps.var + N².eps².mean², the feature is considered
+    # constant and the scale_ attribute is set to 1.
+    bounds = n_samples * eps * scales**2 + n_samples**2 * eps**2 * average**2
+    within_bounds = scales**2 <= bounds
+
+    # Check that scale_min is small enough to have some scales below the
+    # bound and therefore detected as constant:
+    assert np.any(within_bounds)
+
+    # Check that such features are actually treated as constant by the scaler:
+    assert all(scaler.var_[within_bounds] <= bounds[within_bounds])
+    assert_allclose(scaler.scale_[within_bounds], 1.)
+
+    # Depending the on the dtype of X, some features might not actually be
+    # representable as non constant for small scales (even if above the
+    # precision bound of the float64 variance estimate). Such feature should
+    # be correctly detected as constants with 0 variance by StandardScaler.
+    representable_diff = X[0, :] - X[-1, :] != 0
+    assert_allclose(scaler.var_[np.logical_not(representable_diff)], 0)
+    assert_allclose(scaler.scale_[np.logical_not(representable_diff)], 1)
+
+    # The other features are scaled and scale_ is equal to sqrt(var_) assuming
+    # that scales are large enough for average + scale and average - scale to
+    # be distinct in X (depending on X's dtype).
+    common_mask = np.logical_and(scales**2 > bounds, representable_diff)
+    assert_allclose(scaler.scale_[common_mask],
+                    np.sqrt(scaler.var_)[common_mask])
+
+
 def test_scale_1d():
     # 1-d inputs
     X_list = [1., 3., 5., 0.]
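A usage-level illustration of what the new test checks, reduced to one feature (the 1e10 average and 1e-3 spread are values chosen here for illustration): the spread is real and representable in float64, but its variance (about 1e-6) sits below the round-off bound for a mean of 1e10, so after this commit StandardScaler should leave `scale_` at 1 instead of dividing the column by roughly 1e-3.

    import numpy as np
    from sklearn.preprocessing import StandardScaler

    n = 1_000
    average, spread = 1e10, 1e-3
    X = np.full((n, 1), average)
    X[:n // 2, 0] += spread          # known mean = average, var = spread**2
    X[n // 2:, 0] -= spread

    scaler = StandardScaler(with_mean=False).fit(X)
    # var_ ~ 1e-6 is below n*eps*var + (n*mean*eps)**2 ~ 5e-6, so the feature
    # is flagged as near constant and scale_ should stay at 1.
    print(scaler.var_, scaler.scale_)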

sklearn/utils/extmath.py (+35 −7)

@@ -18,6 +18,7 @@

 from . import check_random_state
 from ._logistic_sigmoid import _log_logistic_sigmoid
+from .fixes import np_version, parse_version
 from .sparsefuncs_fast import csr_row_norms
 from .validation import check_array
 from .validation import _deprecate_positional_args
@@ -767,10 +768,17 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count,
     # updated = the aggregated stats
     last_sum = last_mean * last_sample_count
     if sample_weight is not None:
-        new_sum = _safe_accumulator_op(np.nansum, X * sample_weight[:, None],
-                                       axis=0)
-        new_sample_count = np.sum(sample_weight[:, None] * (~np.isnan(X)),
-                                  axis=0)
+        if np_version >= parse_version("1.16.6"):
+            # equivalent to np.nansum(X * sample_weight, axis=0)
+            # safer because np.float64(X*W) != np.float64(X)*np.float64(W)
+            # dtype arg of np.matmul only exists since version 1.16
+            new_sum = _safe_accumulator_op(
+                np.matmul, sample_weight, np.where(np.isnan(X), 0, X))
+        else:
+            new_sum = _safe_accumulator_op(
+                np.nansum, X * sample_weight[:, None], axis=0)
+        new_sample_count = _safe_accumulator_op(
+            np.sum, sample_weight[:, None] * (~np.isnan(X)), axis=0)
     else:
         new_sum = _safe_accumulator_op(np.nansum, X, axis=0)
         new_sample_count = np.sum(~np.isnan(X), axis=0)
@@ -784,10 +792,30 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count,
     else:
         T = new_sum / new_sample_count
         if sample_weight is not None:
-            new_unnormalized_variance = np.nansum(sample_weight[:, None] *
-                                                  (X - T)**2, axis=0)
+            if np_version >= parse_version("1.16.6"):
+                # equivalent to np.nansum((X-T)**2 * sample_weight, axis=0)
+                # safer because np.float64(X*W) != np.float64(X)*np.float64(W)
+                # dtype arg of np.matmul only exists since version 1.16
+                new_unnormalized_variance = _safe_accumulator_op(
+                    np.matmul, sample_weight,
+                    np.where(np.isnan(X), 0, (X - T)**2))
+                correction = _safe_accumulator_op(
+                    np.matmul, sample_weight, np.where(np.isnan(X), 0, X - T))
+            else:
+                new_unnormalized_variance = _safe_accumulator_op(
+                    np.nansum, (X - T)**2 * sample_weight[:, None], axis=0)
+                correction = _safe_accumulator_op(
+                    np.nansum, (X - T) * sample_weight[:, None], axis=0)
         else:
-            new_unnormalized_variance = np.nansum((X - T)**2, axis=0)
+            new_unnormalized_variance = _safe_accumulator_op(
+                np.nansum, (X - T)**2, axis=0)
+            correction = _safe_accumulator_op(np.nansum, X - T, axis=0)
+
+        # correction term of the corrected 2 pass algorithm.
+        # See "Algorithms for computing the sample variance: analysis
+        # and recommendations", by Chan, Golub, and LeVeque.
+        new_unnormalized_variance -= correction**2 / new_sample_count
+
         last_unnormalized_variance = last_variance * last_sample_count

         with np.errstate(divide='ignore', invalid='ignore'):
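The heart of the `_incremental_mean_and_var` change is twofold: route the weighted sums through `np.matmul` so accumulation happens in float64 even for float32 inputs, and subtract the Chan/Golub/LeVeque correction term so the variance of features with large means is not dominated by cancellation error. A condensed NumPy sketch of just the single-batch weighted path (NaN handling, the legacy-NumPy fallback and the merge with previous batches are omitted; the function name is made up):

    import numpy as np

    def weighted_variance_sketch(X, sample_weight):
        # Corrected 2-pass weighted variance with float64 accumulation,
        # mirroring the new branch in _incremental_mean_and_var.
        X64 = np.asarray(X, dtype=np.float64)
        w = np.asarray(sample_weight, dtype=np.float64)
        sum_w = w.sum()
        mean = np.matmul(w, X64) / sum_w            # float64 accumulation
        diff = X64 - mean
        unnormalized_var = np.matmul(w, diff ** 2)
        correction = np.matmul(w, diff)             # ~0 up to round-off
        # Subtracting correction**2 / sum_w removes most of the round-off
        # error left by the first pass (Chan, Golub & LeVeque).
        return (unnormalized_var - correction ** 2 / sum_w) / sum_w

    rng = np.random.RandomState(0)
    X = 1e8 + rng.normal(size=(10_000, 3))          # large offset, unit variance
    w = rng.uniform(0.5, 2.0, size=10_000)
    print(weighted_variance_sketch(X, w))           # each entry close to 1.0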

sklearn/utils/sparsefuncs_fast.pyx (+52 −38)

@@ -57,6 +57,8 @@ def _csr_row_norms(np.ndarray[floating, ndim=1, mode="c"] X_data,
 def csr_mean_variance_axis0(X, weights=None, return_sum_weights=False):
     """Compute mean and variance along axis 0 on a CSR matrix

+    Uses a np.float64 accumulator.
+
     Parameters
     ----------
     X : CSR sparse matrix, shape (n_samples, n_features)
@@ -109,25 +111,18 @@ def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
         np.npy_intp i
         unsigned long long row_ind
         integral col_ind
-        floating diff
+        np.float64_t diff
         # means[j] contains the mean of feature j
-        np.ndarray[floating, ndim=1] means
+        np.ndarray[np.float64_t, ndim=1] means = np.zeros(n_features)
         # variances[j] contains the variance of feature j
-        np.ndarray[floating, ndim=1] variances
-
-    if floating is float:
-        dtype = np.float32
-    else:
-        dtype = np.float64
+        np.ndarray[np.float64_t, ndim=1] variances = np.zeros(n_features)

-    means = np.zeros(n_features, dtype=dtype)
-    variances = np.zeros_like(means, dtype=dtype)
-
-    cdef:
-        np.ndarray[floating, ndim=1] sum_weights = np.full(
-            fill_value=np.sum(weights), shape=n_features, dtype=dtype)
-        np.ndarray[floating, ndim=1] sum_weights_nz = np.zeros(
-            shape=n_features, dtype=dtype)
+        np.ndarray[np.float64_t, ndim=1] sum_weights = np.full(
+            fill_value=np.sum(weights, dtype=np.float64), shape=n_features)
+        np.ndarray[np.float64_t, ndim=1] sum_weights_nz = np.zeros(
+            shape=n_features)
+        np.ndarray[np.float64_t, ndim=1] correction = np.zeros(
+            shape=n_features)

         np.ndarray[np.uint64_t, ndim=1] counts = np.full(
             fill_value=weights.shape[0], shape=n_features, dtype=np.uint64)
@@ -138,7 +133,7 @@ def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
         for i in range(X_indptr[row_ind], X_indptr[row_ind + 1]):
             col_ind = X_indices[i]
             if not isnan(X_data[i]):
-                means[col_ind] += (X_data[i] * weights[row_ind])
+                means[col_ind] += <np.float64_t>(X_data[i]) * weights[row_ind]
                 # sum of weights where X[:, col_ind] is non-zero
                 sum_weights_nz[col_ind] += weights[row_ind]
                 # number of non-zero elements of X[:, col_ind]
@@ -157,21 +152,35 @@ def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
             col_ind = X_indices[i]
             if not isnan(X_data[i]):
                 diff = X_data[i] - means[col_ind]
+                # correction term of the corrected 2 pass algorithm.
+                # See "Algorithms for computing the sample variance: analysis
+                # and recommendations", by Chan, Golub, and LeVeque.
+                correction[col_ind] += diff * weights[row_ind]
                 variances[col_ind] += diff * diff * weights[row_ind]

     for i in range(n_features):
+        if counts[i] != counts_nz[i]:
+            correction[i] -= (sum_weights[i] - sum_weights_nz[i]) * means[i]
+        correction[i] = correction[i]**2 / sum_weights[i]
         if counts[i] != counts_nz[i]:
             # only compute it when it's guaranteed to be non-zero to avoid
             # catastrophic cancellation.
             variances[i] += (sum_weights[i] - sum_weights_nz[i]) * means[i]**2
-        variances[i] /= sum_weights[i]
+        variances[i] = (variances[i] - correction[i]) / sum_weights[i]

-    return means, variances, sum_weights
+    if floating is float:
+        return (np.array(means, dtype=np.float32),
+                np.array(variances, dtype=np.float32),
+                np.array(sum_weights, dtype=np.float32))
+    else:
+        return means, variances, sum_weights


 def csc_mean_variance_axis0(X, weights=None, return_sum_weights=False):
     """Compute mean and variance along axis 0 on a CSC matrix

+    Uses a np.float64 accumulator.
+
     Parameters
     ----------
     X : CSC sparse matrix, shape (n_samples, n_features)
@@ -224,25 +233,18 @@ def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
         np.npy_intp i
         unsigned long long col_ind
         integral row_ind
-        floating diff
+        np.float64_t diff
         # means[j] contains the mean of feature j
-        np.ndarray[floating, ndim=1] means
+        np.ndarray[np.float64_t, ndim=1] means = np.zeros(n_features)
         # variances[j] contains the variance of feature j
-        np.ndarray[floating, ndim=1] variances
-
-    if floating is float:
-        dtype = np.float32
-    else:
-        dtype = np.float64
+        np.ndarray[np.float64_t, ndim=1] variances = np.zeros(n_features)

-    means = np.zeros(n_features, dtype=dtype)
-    variances = np.zeros_like(means, dtype=dtype)
-
-    cdef:
-        np.ndarray[floating, ndim=1] sum_weights = np.full(
-            fill_value=np.sum(weights), shape=n_features, dtype=dtype)
-        np.ndarray[floating, ndim=1] sum_weights_nz = np.zeros(
-            shape=n_features, dtype=dtype)
+        np.ndarray[np.float64_t, ndim=1] sum_weights = np.full(
+            fill_value=np.sum(weights, dtype=np.float64), shape=n_features)
+        np.ndarray[np.float64_t, ndim=1] sum_weights_nz = np.zeros(
+            shape=n_features)
+        np.ndarray[np.float64_t, ndim=1] correction = np.zeros(
+            shape=n_features)

         np.ndarray[np.uint64_t, ndim=1] counts = np.full(
             fill_value=weights.shape[0], shape=n_features, dtype=np.uint64)
@@ -253,7 +255,7 @@ def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
         for i in range(X_indptr[col_ind], X_indptr[col_ind + 1]):
             row_ind = X_indices[i]
             if not isnan(X_data[i]):
-                means[col_ind] += (X_data[i] * weights[row_ind])
+                means[col_ind] += <np.float64_t>(X_data[i]) * weights[row_ind]
                 # sum of weights where X[:, col_ind] is non-zero
                 sum_weights_nz[col_ind] += weights[row_ind]
                 # number of non-zero elements of X[:, col_ind]
@@ -272,16 +274,28 @@ def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
             row_ind = X_indices[i]
             if not isnan(X_data[i]):
                 diff = X_data[i] - means[col_ind]
+                # correction term of the corrected 2 pass algorithm.
+                # See "Algorithms for computing the sample variance: analysis
+                # and recommendations", by Chan, Golub, and LeVeque.
+                correction[col_ind] += diff * weights[row_ind]
                 variances[col_ind] += diff * diff * weights[row_ind]

     for i in range(n_features):
+        if counts[i] != counts_nz[i]:
+            correction[i] -= (sum_weights[i] - sum_weights_nz[i]) * means[i]
+        correction[i] = correction[i]**2 / sum_weights[i]
         if counts[i] != counts_nz[i]:
             # only compute it when it's guaranteed to be non-zero to avoid
             # catastrophic cancellation.
             variances[i] += (sum_weights[i] - sum_weights_nz[i]) * means[i]**2
-        variances[i] /= sum_weights[i]
+        variances[i] = (variances[i] - correction[i]) / sum_weights[i]

-    return means, variances, sum_weights
+    if floating is float:
+        return (np.array(means, dtype=np.float32),
+                np.array(variances, dtype=np.float32),
+                np.array(sum_weights, dtype=np.float32))
+    else:
+        return means, variances, sum_weights


 def incr_mean_variance_axis0(X, last_mean, last_var, last_n, weights=None):
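For the sparse kernels, the same correction term has to account for the implicit zeros that the loops over `X_data` never visit: each missing entry contributes `(0 - mean)` with the remaining weight, which is why the new code adds `-(sum_weights[i] - sum_weights_nz[i]) * means[i]` to `correction[i]` before squaring. A plain NumPy/SciPy sketch of that logic for one CSC matrix (a standalone approximation of `_csc_mean_variance_axis0`, not the Cython code itself; NaN handling and the float32 round-trip are left out, and the function name is made up):

    import numpy as np
    from scipy import sparse

    def csc_weighted_mean_var_sketch(X_csc, weights):
        n_features = X_csc.shape[1]
        w = np.asarray(weights, dtype=np.float64)
        sum_w = w.sum()
        means = np.zeros(n_features)
        variances = np.zeros(n_features)
        for j in range(n_features):
            start, end = X_csc.indptr[j], X_csc.indptr[j + 1]
            rows = X_csc.indices[start:end]            # rows of stored entries
            data = X_csc.data[start:end].astype(np.float64)
            w_nz = w[rows]
            means[j] = (data * w_nz).sum() / sum_w
            diff = data - means[j]
            correction = (diff * w_nz).sum()
            var = (diff ** 2 * w_nz).sum()
            # Implicit zeros: each contributes (0 - mean) with the leftover weight.
            correction -= (sum_w - w_nz.sum()) * means[j]
            var += (sum_w - w_nz.sum()) * means[j] ** 2
            # Corrected 2-pass estimate (Chan, Golub & LeVeque).
            variances[j] = (var - correction ** 2 / sum_w) / sum_w
        return means, variances

    X = sparse.random(200, 5, density=0.3, format="csc", random_state=0)
    w = np.random.RandomState(0).uniform(0.5, 2.0, size=200)
    print(csc_weighted_mean_var_sketch(X, w))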
