Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 58f80da

Browse filesBrowse files
MAINT LogisticRegression informative error msg when penaly=elasticnet and l1_ratio is None (#25925)
Co-authored-by: jeremiedbb <jeremiedbb@yahoo.fr>
1 parent 4751545 commit 58f80da
Copy full SHA for 58f80da

File tree

Expand file treeCollapse file tree

2 files changed

+12
-0
lines changed
Filter options
Expand file treeCollapse file tree

2 files changed

+12
-0
lines changed

‎sklearn/linear_model/_logistic.py

Copy file name to clipboardExpand all lines: sklearn/linear_model/_logistic.py
+3Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,6 +1168,9 @@ def fit(self, X, y, sample_weight=None):
11681168
"(penalty={})".format(self.penalty)
11691169
)
11701170

1171+
if self.penalty == "elasticnet" and self.l1_ratio is None:
1172+
raise ValueError("l1_ratio must be specified when penalty is elasticnet.")
1173+
11711174
# TODO(1.4): Remove "none" option
11721175
if self.penalty == "none":
11731176
warnings.warn(

‎sklearn/linear_model/tests/test_logistic.py

Copy file name to clipboardExpand all lines: sklearn/linear_model/tests/test_logistic.py
+9Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,15 @@ def test_check_solver_option(LR):
227227
lr.fit(X, y)
228228

229229

230+
@pytest.mark.parametrize("LR", [LogisticRegression, LogisticRegressionCV])
231+
def test_elasticnet_l1_ratio_err_helpful(LR):
232+
# Check that an informative error message is raised when penalty="elasticnet"
233+
# but l1_ratio is not specified.
234+
model = LR(penalty="elasticnet", solver="saga")
235+
with pytest.raises(ValueError, match=r".*l1_ratio.*"):
236+
model.fit(np.array([[1, 2], [3, 4]]), np.array([0, 1]))
237+
238+
230239
@pytest.mark.parametrize("solver", ["lbfgs", "newton-cg", "sag", "saga"])
231240
def test_multinomial_binary(solver):
232241
# Test multinomial LR on a binary problem.

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.