Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Fix elasticnet cv sample weight #29308

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
8d4b501
Update _alpha_grid to take sample_weight
s-banach Apr 4, 2022
2f494db
Add a simple test for alpha_max with sample_weight
s-banach Apr 4, 2022
fa2c821
Update test
s-banach Apr 4, 2022
75e6584
Clarify _alpha_grid.
s-banach Apr 4, 2022
8b6cfc0
Clarify notation
s-banach Apr 5, 2022
2ba4c57
Use Xy if it is provided.
s-banach Jul 2, 2022
5d1f5e7
Update test, check alpha_max is not too large
s-banach Jul 2, 2022
dce169c
Fix test that alpha_max is not too large.
s-banach Jul 2, 2022
380c21f
Test alpha_max without sample_weight.
s-banach Jul 6, 2022
c187cf7
fix elasticnetcv sample weighting adapted from previous commit by s-b…
snath-xoc Jun 19, 2024
40d8b30
Update _preprocess_data inputs in _coordinate_descent.py
snath-xoc Jun 20, 2024
85062c0
added tests for repeated vs weighted on cyclic ElasticNetCV and modif…
snath-xoc Jun 20, 2024
c649d36
Merge branch 'main' into fix_elasticnet_cv_sample_weight
ogrisel Jun 27, 2024
41fcb5f
added to changelog and changed seeding in tests
snath-xoc Jun 27, 2024
335137d
[all random seeds] test_enet_cv_sample_weight
snath-xoc Jun 27, 2024
36cc847
Merge branch 'main' into fix_elasticnet_cv_sample_weight
ogrisel Jun 28, 2024
fec4f74
Revert unrelated changes
ogrisel Jun 28, 2024
c41a8ee
merged test into test_enet_cv_sample_weight_correctness
snath-xoc Jun 28, 2024
ac9f090
changed sample weight to be explicitly set as integers in sklearn/lin…
snath-xoc Jun 29, 2024
bf62b35
Merge branch 'main' into fix_elasticnet_cv_sample_weight
snath-xoc Jul 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions 6 doc/whats_new/v1.6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ Changelog
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
where 123456 is the *pull request* number, not the issue number.

:mod:`sklearn.linear_model`
..........................
- |Fix| :func:`_coordinate_descent._alpha_grid` adapted to account for sample weights.
:pr:`23045` by :user:`John Hopfensperger <s-banach>` and :pr:`29308` by :user:`Shruti Nath <snath-xoc>`.


:mod:`sklearn.base`
...................

Expand Down
71 changes: 34 additions & 37 deletions 71 sklearn/linear_model/_coordinate_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def _alpha_grid(
eps=1e-3,
n_alphas=100,
copy_X=True,
sample_weight=None,
):
"""Compute the grid of alpha values for elastic net parameter search

Expand Down Expand Up @@ -132,6 +133,8 @@ def _alpha_grid(

copy_X : bool, default=True
If ``True``, X will be copied; else, it may be overwritten.

sample_weight : ndarray of shape (n_samples,)
"""
if l1_ratio == 0:
raise ValueError(
Expand All @@ -140,43 +143,40 @@ def _alpha_grid(
"your estimator with the appropriate `alphas=` "
"argument."
)
n_samples = len(y)

sparse_center = False
if Xy is None:
X_sparse = sparse.issparse(X)
sparse_center = X_sparse and fit_intercept
X = check_array(
X, accept_sparse="csc", copy=(copy_X and fit_intercept and not X_sparse)
if Xy is not None:
Xyw = Xy
else:
X, y, X_offset, _, _ = _preprocess_data(
X,
y,
fit_intercept=fit_intercept,
copy=copy_X,
sample_weight=sample_weight,
check_input=False,
)
if not X_sparse:
# X can be touched inplace thanks to the above line
X, y, _, _, _ = _preprocess_data(
X, y, fit_intercept=fit_intercept, copy=False
)
Xy = safe_sparse_dot(X.T, y, dense_output=True)

if sparse_center:
# Workaround to find alpha_max for sparse matrices.
# since we should not destroy the sparsity of such matrices.
_, _, X_offset, _, X_scale = _preprocess_data(
X, y, fit_intercept=fit_intercept
)
mean_dot = X_offset * np.sum(y)

if Xy.ndim == 1:
Xy = Xy[:, np.newaxis]
if sample_weight is not None:
if y.ndim > 1:
yw = y * np.broadcast_to(sample_weight.reshape(-1, 1), y.shape)

if sparse_center:
if fit_intercept:
Xy -= mean_dot[:, np.newaxis]
else:
yw = y * sample_weight
else:
yw = y
if sparse.issparse(X):
Xyw = safe_sparse_dot(X.T, yw, dense_output=True) - np.sum(yw) * X_offset
else:
Xyw = np.dot(X.T, yw)

alpha_max = np.sqrt(np.sum(Xy**2, axis=1)).max() / (n_samples * l1_ratio)
if Xyw.ndim == 1:
Xyw = Xyw[:, np.newaxis]
if sample_weight is not None:
n_samples = sample_weight.sum()
else:
n_samples = X.shape[0]
alpha_max = np.sqrt(np.sum(Xyw**2, axis=1)).max() / (n_samples * l1_ratio)

if alpha_max <= np.finfo(float).resolution:
alphas = np.empty(n_alphas)
alphas.fill(np.finfo(float).resolution)
return alphas
if alpha_max <= np.finfo(np.float64).resolution:
return np.full(n_alphas, np.finfo(np.float64).resolution)

return np.geomspace(alpha_max, alpha_max * eps, num=n_alphas)

Expand Down Expand Up @@ -979,7 +979,6 @@ def fit(self, X, y, sample_weight=None, check_input=True):
accept_sparse="csc",
order="F",
dtype=[np.float64, np.float32],
force_writeable=True,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the removal of this line and the other force_writeable=True lines below were not intentional (maybe when resolving conflicts with main)?

I think this is what causes the test failures on the CI.

accept_large_sparse=False,
copy=X_copied,
multi_output=True,
Expand Down Expand Up @@ -1608,7 +1607,6 @@ def fit(self, X, y, sample_weight=None, **params):
check_X_params = dict(
accept_sparse="csc",
dtype=[np.float64, np.float32],
force_writeable=True,
copy=False,
accept_large_sparse=False,
)
Expand All @@ -1634,7 +1632,6 @@ def fit(self, X, y, sample_weight=None, **params):
accept_sparse="csc",
dtype=[np.float64, np.float32],
order="F",
force_writeable=True,
copy=copy_X,
)
X, y = self._validate_data(
Expand Down Expand Up @@ -1702,6 +1699,7 @@ def fit(self, X, y, sample_weight=None, **params):
eps=self.eps,
n_alphas=self.n_alphas,
copy_X=self.copy_X,
sample_weight=sample_weight,
)
for l1_ratio in l1_ratios
]
Expand Down Expand Up @@ -2511,7 +2509,6 @@ def fit(self, X, y):
check_X_params = dict(
dtype=[np.float64, np.float32],
order="F",
force_writeable=True,
copy=self.copy_X and self.fit_intercept,
)
check_y_params = dict(ensure_2d=False, order="F")
Expand Down
48 changes: 32 additions & 16 deletions 48 sklearn/linear_model/tests/test_coordinate_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
assert_almost_equal,
assert_array_almost_equal,
assert_array_equal,
assert_array_less,
ignore_warnings,
)
from sklearn.utils.fixes import COO_CONTAINERS, CSC_CONTAINERS, CSR_CONTAINERS
Expand Down Expand Up @@ -1319,39 +1320,34 @@ def test_enet_cv_sample_weight_correctness(fit_intercept, sparse_container):
X = sparse_container(X)
params = dict(tol=1e-6)

# Set alphas, otherwise the two cv models might use different ones.
if fit_intercept:
alphas = np.linspace(0.001, 0.01, num=91)
else:
alphas = np.linspace(0.01, 0.1, num=91)

# We weight the first fold 2 times more.
sw[:n_samples] = 2
# We weight the first fold n times more.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# We weight the first fold n times more.
# We re-weight the first cross-validation group with random integer weights.
# The samples in the other groups are left with unit weights.

sw[:n_samples] = rng.randint(0, 5, size=sw[:n_samples].shape[0])
groups_sw = np.r_[
np.full(n_samples, 0), np.full(n_samples, 1), np.full(n_samples, 2)
]
splits_sw = list(LeaveOneGroupOut().split(X, groups=groups_sw))
reg_sw = ElasticNetCV(
alphas=alphas, cv=splits_sw, fit_intercept=fit_intercept, **params
)
reg_sw = ElasticNetCV(cv=splits_sw, fit_intercept=fit_intercept, **params)
reg_sw.fit(X, y, sample_weight=sw)

# We repeat the first fold 2 times and provide splits ourselves
if sparse_container is not None:
X = X.toarray()
X = np.r_[X[:n_samples], X]
X_rep = np.repeat(X, sw.astype(int), axis=0)
# Need to know the number of repetitions made in total
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
##Need to know number of repitions made in total
# Inspect the total number of random repetitions so as to adjust the size of
# the first cross-validation group accordingly.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I think that computing the number of repetitions is not needed, see the other suggestions below.

n_reps = X_rep.shape[0] - X.shape[0]
X = X_rep
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would rather not rename change the X variable to keep the code easier to follow.

Maybe you could instead name the variables X_with_weights, y_with_weights, groups_with_weights on one hand and X_with_repetitions, y_with_repetitions and groups_with_repetitions on the other hand.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And similarly for the names of the 2 cross-validation splitters (if you adapt the code to use metadata routing) or the results of their splits if you prefer to precompute them instead of leveraging metadata routing.

if sparse_container is not None:
X = sparse_container(X)
y = np.r_[y[:n_samples], y]
y = np.repeat(y, sw.astype(int), axis=0)
groups = np.r_[
np.full(2 * n_samples, 0), np.full(n_samples, 1), np.full(n_samples, 2)
np.full(n_reps + n_samples, 0), np.full(n_samples, 1), np.full(n_samples, 2)
]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of using n_reps, you could use:

groups_with_repetitions = np.repeat(groups_with_weights, sw.astype(int), axis=0)

as is done for X and y.

splits = list(LeaveOneGroupOut().split(X, groups=groups))
reg = ElasticNetCV(alphas=alphas, cv=splits, fit_intercept=fit_intercept, **params)
reg = ElasticNetCV(cv=splits, fit_intercept=fit_intercept, **params)
reg.fit(X, y)

# ensure that we chose meaningful alphas, i.e. not boundaries
assert alphas[0] < reg.alpha_ < alphas[-1]
assert_allclose(reg_sw.alphas_, reg.alphas_)
assert reg_sw.alpha_ == reg.alpha_
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please also compare the values of the mse_path_ attributes. prior to comparing the coef_ values.

assert_allclose(reg_sw.coef_, reg.coef_)
assert reg_sw.intercept_ == pytest.approx(reg.intercept_)
Expand Down Expand Up @@ -1392,6 +1388,26 @@ def test_enet_cv_grid_search(sample_weight):
assert reg.alpha_ == pytest.approx(gs.best_params_["alpha"])


@pytest.mark.parametrize("sparseX", [False, True])
@pytest.mark.parametrize("fit_intercept", [False, True])
@pytest.mark.parametrize("sample_weight", [np.array([10, 1, 10, 1]), None])
def test_enet_alpha_max_sample_weight(sparseX, fit_intercept, sample_weight):
X = np.array([[3.0, 1.0], [2.0, 5.0], [5.0, 3.0], [1.0, 4.0]])
beta = np.array([1, 1])
y = X @ beta
if sparseX:
X = sparse.csc_matrix(X)
# Test alpha_max makes coefs zero.
reg = ElasticNetCV(n_alphas=1, cv=2, eps=1, fit_intercept=fit_intercept)
reg.fit(X, y, sample_weight=sample_weight)
assert_almost_equal(reg.coef_, 0)
alpha_max = reg.alpha_
# Test smaller alpha makes coefs nonzero.
reg = ElasticNet(alpha=0.99 * alpha_max, fit_intercept=fit_intercept)
reg.fit(X, y, sample_weight=sample_weight)
assert_array_less(1e-3, np.max(np.abs(reg.coef_)))


@pytest.mark.parametrize("fit_intercept", [True, False])
@pytest.mark.parametrize("l1_ratio", [0, 0.5, 1])
@pytest.mark.parametrize("precompute", [False, True])
Expand Down
Loading
Morty Proxy This is a proxified and sanitized view of the page, visit original site.