Commit 94876ae

snath-xoc, s-banach, ogrisel (Shruti Nath authored and committed)

Fix sample weight support in ElasticNetCV (scikit-learn#29442)

Co-authored-by: Mr. Snrub <45150804+s-banach@users.noreply.github.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Shruti Nath <shrutinath@Shrutis-Laptop.local>
1 parent: 4434782

File tree

3 files changed: 123 additions, 74 deletions

‎doc/whats_new/v1.6.rst

5 additions, 0 deletions
@@ -261,6 +261,11 @@ Changelog
   has no effect. `copy_X` will be removed in 1.8.
   :pr:`29105` by :user:`Adam Li <adam2392>`.
 
+- |Fix| :class:`linear_model.LassoCV` and :class:`linear_model.ElasticNetCV` now
+  take sample weights into account to define the search grid for the internally tuned
+  `alpha` hyper-parameter. :pr:`29442` by :user:`John Hopfensperger <s-banach>` and
+  :user:`Shruti Nath <snath-xoc>`.
+
 :mod:`sklearn.manifold`
 .......................
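Illustration (not part of the commit): a minimal sketch of the user-visible effect described by this changelog entry, using made-up data and weights. With this fix, the alpha grid that ElasticNetCV builds internally (exposed as `alphas_`) is derived from the weighted data, so passing `sample_weight` changes the grid itself rather than only the per-fold scores.

    import numpy as np
    from sklearn.linear_model import ElasticNetCV

    rng = np.random.RandomState(0)
    X = rng.rand(30, 5)
    y = X @ rng.rand(5) + 0.01 * rng.rand(30)
    sw = rng.randint(1, 5, size=30).astype(float)

    # The automatically generated alpha grid now depends on sample_weight, so
    # the largest grid value (alpha_max) generally differs between the two fits.
    unweighted = ElasticNetCV(cv=3).fit(X, y)
    weighted = ElasticNetCV(cv=3).fit(X, y, sample_weight=sw)
    print(unweighted.alphas_.max(), weighted.alphas_.max())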

‎sklearn/linear_model/_coordinate_descent.py

34 additions, 34 deletions
@@ -99,6 +99,7 @@ def _alpha_grid(
     eps=1e-3,
     n_alphas=100,
     copy_X=True,
+    sample_weight=None,
 ):
     """Compute the grid of alpha values for elastic net parameter search
 
@@ -133,6 +134,8 @@ def _alpha_grid(
 
     copy_X : bool, default=True
         If ``True``, X will be copied; else, it may be overwritten.
+
+    sample_weight : ndarray of shape (n_samples,), default=None
     """
     if l1_ratio == 0:
         raise ValueError(
@@ -141,43 +144,39 @@ def _alpha_grid(
             "your estimator with the appropriate `alphas=` "
             "argument."
         )
-    n_samples = len(y)
-
-    sparse_center = False
-    if Xy is None:
-        X_sparse = sparse.issparse(X)
-        sparse_center = X_sparse and fit_intercept
-        X = check_array(
-            X, accept_sparse="csc", copy=(copy_X and fit_intercept and not X_sparse)
+    if Xy is not None:
+        Xyw = Xy
+    else:
+        X, y, X_offset, _, _ = _preprocess_data(
+            X,
+            y,
+            fit_intercept=fit_intercept,
+            copy=copy_X,
+            sample_weight=sample_weight,
+            check_input=False,
         )
-        if not X_sparse:
-            # X can be touched inplace thanks to the above line
-            X, y, _, _, _ = _preprocess_data(
-                X, y, fit_intercept=fit_intercept, copy=False
-            )
-        Xy = safe_sparse_dot(X.T, y, dense_output=True)
-
-        if sparse_center:
-            # Workaround to find alpha_max for sparse matrices.
-            # since we should not destroy the sparsity of such matrices.
-            _, _, X_offset, _, X_scale = _preprocess_data(
-                X, y, fit_intercept=fit_intercept
-            )
-            mean_dot = X_offset * np.sum(y)
-
-    if Xy.ndim == 1:
-        Xy = Xy[:, np.newaxis]
-
-    if sparse_center:
-        if fit_intercept:
-            Xy -= mean_dot[:, np.newaxis]
+        if sample_weight is not None:
+            if y.ndim > 1:
+                yw = y * sample_weight.reshape(-1, 1)
+            else:
+                yw = y * sample_weight
+        else:
+            yw = y
+        if sparse.issparse(X):
+            Xyw = safe_sparse_dot(X.T, yw, dense_output=True) - np.sum(yw) * X_offset
+        else:
+            Xyw = np.dot(X.T, yw)
 
-    alpha_max = np.sqrt(np.sum(Xy**2, axis=1)).max() / (n_samples * l1_ratio)
+    if Xyw.ndim == 1:
+        Xyw = Xyw[:, np.newaxis]
+    if sample_weight is not None:
+        n_samples = sample_weight.sum()
+    else:
+        n_samples = X.shape[0]
+    alpha_max = np.sqrt(np.sum(Xyw**2, axis=1)).max() / (n_samples * l1_ratio)
 
-    if alpha_max <= np.finfo(float).resolution:
-        alphas = np.empty(n_alphas)
-        alphas.fill(np.finfo(float).resolution)
-        return alphas
+    if alpha_max <= np.finfo(np.float64).resolution:
+        return np.full(n_alphas, np.finfo(np.float64).resolution)
 
     return np.geomspace(alpha_max, alpha_max * eps, num=n_alphas)
 
@@ -1704,6 +1703,7 @@ def fit(self, X, y, sample_weight=None, **params):
                     eps=self.eps,
                     n_alphas=self.n_alphas,
                     copy_X=self.copy_X,
+                    sample_weight=sample_weight,
                 )
                 for l1_ratio in l1_ratios
             ]
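Reading aid (not part of the diff): the heart of the new `_alpha_grid` logic is that alpha_max, the smallest penalty at which all coefficients are exactly zero, is now computed from the weighted product X^T (w * y) and normalized by the sum of the weights instead of the raw row count. Below is a minimal NumPy sketch for a single-output y with fit_intercept=False; the helper name is illustrative and not part of scikit-learn.

    import numpy as np

    def weighted_alpha_max(X, y, sample_weight=None, l1_ratio=1.0):
        # Smallest alpha for which the (weighted) elastic-net solution is all zeros.
        if sample_weight is None:
            sample_weight = np.ones(X.shape[0])
        # Weighted correlation of each feature with the target: X^T (w * y).
        Xyw = X.T @ (y * sample_weight)
        # Normalize by the sum of weights, mirroring the
        # n_samples = sample_weight.sum() branch in the diff above.
        return np.max(np.abs(Xyw)) / (sample_weight.sum() * l1_ratio)

With unit weights this reduces to the previous behaviour, max |X^T y| / (n_samples * l1_ratio).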

‎sklearn/linear_model/tests/test_coordinate_descent.py

84 additions, 40 deletions
@@ -48,6 +48,7 @@
     assert_almost_equal,
     assert_array_almost_equal,
     assert_array_equal,
+    assert_array_less,
     ignore_warnings,
 )
 from sklearn.utils.fixes import COO_CONTAINERS, CSC_CONTAINERS, CSR_CONTAINERS
@@ -1304,55 +1305,78 @@ def test_enet_sample_weight_consistency(
 
 @pytest.mark.parametrize("fit_intercept", [True, False])
 @pytest.mark.parametrize("sparse_container", [None] + CSC_CONTAINERS)
-def test_enet_cv_sample_weight_correctness(fit_intercept, sparse_container):
-    """Test that ElasticNetCV with sample weights gives correct results."""
-    rng = np.random.RandomState(42)
-    n_splits, n_samples, n_features = 3, 10, 5
-    X = rng.rand(n_splits * n_samples, n_features)
+def test_enet_cv_sample_weight_correctness(
+    fit_intercept, sparse_container, global_random_seed
+):
+    """Test that ElasticNetCV with sample weights gives correct results.
+
+    We fit the same model twice, once with weighted training data, once with repeated
+    data points in the training data and check that both models converge to the
+    same solution.
+
+    Since this model uses an internal cross-validation scheme to tune the alpha
+    regularization parameter, we make sure that the repetitions only occur within
+    a specific CV group. Data points belonging to other CV groups stay
+    unit-weighted / "unrepeated".
+    """
+    rng = np.random.RandomState(global_random_seed)
+    n_splits, n_samples_per_cv, n_features = 3, 10, 5
+    X_with_weights = rng.rand(n_splits * n_samples_per_cv, n_features)
     beta = rng.rand(n_features)
     beta[0:2] = 0
-    y = X @ beta + rng.rand(n_splits * n_samples)
-    sw = np.ones_like(y)
+    y_with_weights = X_with_weights @ beta + rng.rand(n_splits * n_samples_per_cv)
+
     if sparse_container is not None:
-        X = sparse_container(X)
+        X_with_weights = sparse_container(X_with_weights)
     params = dict(tol=1e-6)
 
-    # Set alphas, otherwise the two cv models might use different ones.
-    if fit_intercept:
-        alphas = np.linspace(0.001, 0.01, num=91)
-    else:
-        alphas = np.linspace(0.01, 0.1, num=91)
-
-    # We weight the first fold 2 times more.
-    sw[:n_samples] = 2
-    groups_sw = np.r_[
-        np.full(n_samples, 0), np.full(n_samples, 1), np.full(n_samples, 2)
-    ]
-    splits_sw = list(LeaveOneGroupOut().split(X, groups=groups_sw))
-    reg_sw = ElasticNetCV(
-        alphas=alphas, cv=splits_sw, fit_intercept=fit_intercept, **params
+    # Assign random integer weights only to the first cross-validation group.
+    # The samples in the other cross-validation groups are left with unit
+    # weights.
+
+    sw = np.ones_like(y_with_weights)
+    sw[:n_samples_per_cv] = rng.randint(0, 5, size=n_samples_per_cv)
+    groups_with_weights = np.concatenate(
+        [
+            np.full(n_samples_per_cv, 0),
+            np.full(n_samples_per_cv, 1),
+            np.full(n_samples_per_cv, 2),
+        ]
+    )
+    splits_with_weights = list(
+        LeaveOneGroupOut().split(X_with_weights, groups=groups_with_weights)
+    )
+    reg_with_weights = ElasticNetCV(
+        cv=splits_with_weights, fit_intercept=fit_intercept, **params
     )
-    reg_sw.fit(X, y, sample_weight=sw)
 
-    # We repeat the first fold 2 times and provide splits ourselves
+    reg_with_weights.fit(X_with_weights, y_with_weights, sample_weight=sw)
+
     if sparse_container is not None:
-        X = X.toarray()
-    X = np.r_[X[:n_samples], X]
+        X_with_weights = X_with_weights.toarray()
+    X_with_repetitions = np.repeat(X_with_weights, sw.astype(int), axis=0)
     if sparse_container is not None:
-        X = sparse_container(X)
-    y = np.r_[y[:n_samples], y]
-    groups = np.r_[
-        np.full(2 * n_samples, 0), np.full(n_samples, 1), np.full(n_samples, 2)
-    ]
-    splits = list(LeaveOneGroupOut().split(X, groups=groups))
-    reg = ElasticNetCV(alphas=alphas, cv=splits, fit_intercept=fit_intercept, **params)
-    reg.fit(X, y)
+        X_with_repetitions = sparse_container(X_with_repetitions)
+
+    y_with_repetitions = np.repeat(y_with_weights, sw.astype(int), axis=0)
+    groups_with_repetitions = np.repeat(groups_with_weights, sw.astype(int), axis=0)
+
+    splits_with_repetitions = list(
+        LeaveOneGroupOut().split(X_with_repetitions, groups=groups_with_repetitions)
+    )
+    reg_with_repetitions = ElasticNetCV(
+        cv=splits_with_repetitions, fit_intercept=fit_intercept, **params
+    )
+    reg_with_repetitions.fit(X_with_repetitions, y_with_repetitions)
 
-    # ensure that we chose meaningful alphas, i.e. not boundaries
-    assert alphas[0] < reg.alpha_ < alphas[-1]
-    assert reg_sw.alpha_ == reg.alpha_
-    assert_allclose(reg_sw.coef_, reg.coef_)
-    assert reg_sw.intercept_ == pytest.approx(reg.intercept_)
+    # Check that the alpha selection process is the same:
+    assert_allclose(reg_with_weights.mse_path_, reg_with_repetitions.mse_path_)
+    assert_allclose(reg_with_weights.alphas_, reg_with_repetitions.alphas_)
+    assert reg_with_weights.alpha_ == pytest.approx(reg_with_repetitions.alpha_)
+
+    # Check that the final model coefficients are the same:
+    assert_allclose(reg_with_weights.coef_, reg_with_repetitions.coef_, atol=1e-10)
+    assert reg_with_weights.intercept_ == pytest.approx(reg_with_repetitions.intercept_)
 
 
 @pytest.mark.parametrize("sample_weight", [False, True])
@@ -1444,9 +1468,29 @@ def test_enet_cv_sample_weight_consistency(
     assert_allclose(reg.intercept_, intercept)
 
 
+@pytest.mark.parametrize("X_is_sparse", [False, True])
+@pytest.mark.parametrize("fit_intercept", [False, True])
+@pytest.mark.parametrize("sample_weight", [np.array([10, 1, 10, 1]), None])
+def test_enet_alpha_max_sample_weight(X_is_sparse, fit_intercept, sample_weight):
+    X = np.array([[3.0, 1.0], [2.0, 5.0], [5.0, 3.0], [1.0, 4.0]])
+    beta = np.array([1, 1])
+    y = X @ beta
+    if X_is_sparse:
+        X = sparse.csc_matrix(X)
+    # Test alpha_max makes coefs zero.
+    reg = ElasticNetCV(n_alphas=1, cv=2, eps=1, fit_intercept=fit_intercept)
+    reg.fit(X, y, sample_weight=sample_weight)
+    assert_allclose(reg.coef_, 0, atol=1e-5)
+    alpha_max = reg.alpha_
+    # Test smaller alpha makes coefs nonzero.
+    reg = ElasticNet(alpha=0.99 * alpha_max, fit_intercept=fit_intercept)
+    reg.fit(X, y, sample_weight=sample_weight)
+    assert_array_less(1e-3, np.max(np.abs(reg.coef_)))
+
+
 @pytest.mark.parametrize("estimator", [ElasticNetCV, LassoCV])
 def test_linear_models_cv_fit_with_loky(estimator):
-    # LinearModelsCV.fit performs inplace operations on fancy-indexed memmapped
+    # LinearModelsCV.fit performs operations on fancy-indexed memmapped
     # data when using the loky backend, causing an error due to unexpected
     # behavior of fancy indexing of read-only memmaps (cf. numpy#14132).
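Background (not part of the diff): the rewritten correctness test relies on the standard equivalence between integer sample weights and repeating rows, with the repetitions confined to a single CV group so the internal folds remain comparable. A standalone sketch of the underlying idea, using made-up data, a plain ElasticNet instead of ElasticNetCV, and a tight tolerance so both fits converge to essentially the same solution:

    import numpy as np
    from sklearn.linear_model import ElasticNet

    rng = np.random.RandomState(0)
    X = rng.rand(12, 4)
    y = X @ rng.rand(4)
    sw = rng.randint(1, 4, size=12)

    # Fitting with integer weights should match fitting on data in which each
    # row i is physically repeated sw[i] times.
    weighted = ElasticNet(alpha=0.01, tol=1e-10, max_iter=10_000)
    weighted.fit(X, y, sample_weight=sw)
    repeated = ElasticNet(alpha=0.01, tol=1e-10, max_iter=10_000)
    repeated.fit(np.repeat(X, sw, axis=0), np.repeat(y, sw))
    np.testing.assert_allclose(weighted.coef_, repeated.coef_, rtol=1e-5, atol=1e-8)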
