Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 8c508c4

Browse files
Authored by: ogrisel, Copilot, and thomasjpfan
Fix: do not recommend increasing max_iter in ConvergenceWarning when not appropriate (scikit-learn#31316)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent c1e6494 commit 8c508c4
Copy full SHA for 8c508c4

File tree

6 files changed

+95
-17
lines changed
Filter options

6 files changed

+95
-17
lines changed
+5 — Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,5 @@
1+
- Change the `ConvergenceWarning` message of estimators that rely on the
2+
`"lbfgs"` optimizer internally to be more informative and to avoid
3+
suggesting to increase the maximum number of iterations when it is not
4+
user-settable or when the convergence problem happens before reaching it.
5+
By :user:`Olivier Grisel <ogrisel>`.

‎sklearn/linear_model/_glm/_newton_solver.py

Copy file name to clipboardExpand all lines: sklearn/linear_model/_glm/_newton_solver.py
+3 −2 — Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -178,21 +178,22 @@ def fallback_lbfgs_solve(self, X, y, sample_weight):
178178
- self.coef
179179
- self.converged
180180
"""
181+
max_iter = self.max_iter - self.iteration
181182
opt_res = scipy.optimize.minimize(
182183
self.linear_loss.loss_gradient,
183184
self.coef,
184185
method="L-BFGS-B",
185186
jac=True,
186187
options={
187-
"maxiter": self.max_iter - self.iteration,
188+
"maxiter": max_iter,
188189
"maxls": 50, # default is 20
189190
"iprint": self.verbose - 1,
190191
"gtol": self.tol,
191192
"ftol": 64 * np.finfo(np.float64).eps,
192193
},
193194
args=(X, y, sample_weight, self.l2_reg_strength, self.n_threads),
194195
)
195-
self.iteration += _check_optimize_result("lbfgs", opt_res)
196+
self.iteration += _check_optimize_result("lbfgs", opt_res, max_iter=max_iter)
196197
self.coef = opt_res.x
197198
self.converged = opt_res.status == 0
198199

‎sklearn/linear_model/_glm/glm.py

Copy file name to clipboardExpand all lines: sklearn/linear_model/_glm/glm.py
+3 −1 — Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -282,7 +282,9 @@ def fit(self, X, y, sample_weight=None):
282282
},
283283
args=(X, y, sample_weight, l2_reg_strength, n_threads),
284284
)
285-
self.n_iter_ = _check_optimize_result("lbfgs", opt_res)
285+
self.n_iter_ = _check_optimize_result(
286+
"lbfgs", opt_res, max_iter=self.max_iter
287+
)
286288
coef = opt_res.x
287289
elif self.solver == "newton-cholesky":
288290
sol = NewtonCholeskySolver(

‎sklearn/linear_model/tests/test_logistic.py

Copy file name to clipboardExpand all lines: sklearn/linear_model/tests/test_logistic.py
+1 −1 — Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -444,7 +444,7 @@ def test_logistic_regression_path_convergence_fail():
444444

445445
assert len(record) == 1
446446
warn_msg = record[0].message.args[0]
447-
assert "lbfgs failed to converge" in warn_msg
447+
assert "lbfgs failed to converge after 1 iteration(s)" in warn_msg
448448
assert "Increase the number of iterations" in warn_msg
449449
assert "scale the data" in warn_msg
450450
assert "linear_model.html#logistic-regression" in warn_msg

‎sklearn/utils/optimize.py

Copy file name to clipboardExpand all lines: sklearn/utils/optimize.py
+24 −12 — Lines changed: 24 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -352,25 +352,37 @@ def _check_optimize_result(solver, result, max_iter=None, extra_warning_msg=None
352352
"""
353353
# handle both scipy and scikit-learn solver names
354354
if solver == "lbfgs":
355-
if result.status != 0:
356-
result_message = result.message
355+
if max_iter is not None:
356+
# In scipy <= 1.0.0, nit may exceed maxiter for lbfgs.
357+
# See https://github.com/scipy/scipy/issues/7854
358+
n_iter_i = min(result.nit, max_iter)
359+
else:
360+
n_iter_i = result.nit
357361

362+
if result.status != 0:
358363
warning_msg = (
359-
"{} failed to converge (status={}):\n{}.\n\n"
360-
"Increase the number of iterations (max_iter) "
361-
"or scale the data as shown in:\n"
364+
f"{solver} failed to converge after {n_iter_i} iteration(s) "
365+
f"(status={result.status}):\n"
366+
f"{result.message}\n"
367+
)
368+
# Append a recommendation to increase iterations only when the
369+
# number of iterations reaches the maximum allowed (max_iter),
370+
# as this suggests the optimization may have been prematurely
371+
# terminated due to the iteration limit.
372+
if max_iter is not None and n_iter_i == max_iter:
373+
warning_msg += (
374+
f"\nIncrease the number of iterations to improve the "
375+
f"convergence (max_iter={max_iter})."
376+
)
377+
warning_msg += (
378+
"\nYou might also want to scale the data as shown in:\n"
362379
" https://scikit-learn.org/stable/modules/"
363380
"preprocessing.html"
364-
).format(solver, result.status, result_message)
381+
)
365382
if extra_warning_msg is not None:
366383
warning_msg += "\n" + extra_warning_msg
367384
warnings.warn(warning_msg, ConvergenceWarning, stacklevel=2)
368-
if max_iter is not None:
369-
# In scipy <= 1.0.0, nit may exceed maxiter for lbfgs.
370-
# See https://github.com/scipy/scipy/issues/7854
371-
n_iter_i = min(result.nit, max_iter)
372-
else:
373-
n_iter_i = result.nit
385+
374386
else:
375387
raise NotImplementedError
376388

‎sklearn/utils/tests/test_optimize.py

Copy file name to clipboardExpand all lines: sklearn/utils/tests/test_optimize.py
+59 −1 — Lines changed: 59 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,13 @@
1+
import warnings
2+
13
import numpy as np
24
import pytest
35
from scipy.optimize import fmin_ncg
46

57
from sklearn.exceptions import ConvergenceWarning
8+
from sklearn.utils._bunch import Bunch
69
from sklearn.utils._testing import assert_allclose
7-
from sklearn.utils.optimize import _newton_cg
10+
from sklearn.utils.optimize import _check_optimize_result, _newton_cg
811

912

1013
def test_newton_cg(global_random_seed):
@@ -160,3 +163,58 @@ def test_newton_cg_verbosity(capsys, verbose):
160163
]
161164
for m in msg:
162165
assert m in captured.out
166+
167+
168+
def test_check_optimize():
169+
# Mock some lbfgs output using a Bunch instance:
170+
result = Bunch()
171+
172+
# First case: no warnings
173+
result.nit = 1
174+
result.status = 0
175+
result.message = "OK"
176+
177+
with warnings.catch_warnings():
178+
warnings.simplefilter("error")
179+
_check_optimize_result("lbfgs", result)
180+
181+
# Second case: warning about implicit `max_iter`: do not recommend the user
182+
# to increase `max_iter` this is not a user settable parameter.
183+
result.status = 1
184+
result.message = "STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT"
185+
with pytest.warns(ConvergenceWarning) as record:
186+
_check_optimize_result("lbfgs", result)
187+
188+
assert len(record) == 1
189+
warn_msg = record[0].message.args[0]
190+
assert "lbfgs failed to converge after 1 iteration(s)" in warn_msg
191+
assert result.message in warn_msg
192+
assert "Increase the number of iterations" not in warn_msg
193+
assert "scale the data" in warn_msg
194+
195+
# Third case: warning about explicit `max_iter`: recommend user to increase
196+
# `max_iter`.
197+
with pytest.warns(ConvergenceWarning) as record:
198+
_check_optimize_result("lbfgs", result, max_iter=1)
199+
200+
assert len(record) == 1
201+
warn_msg = record[0].message.args[0]
202+
assert "lbfgs failed to converge after 1 iteration(s)" in warn_msg
203+
assert result.message in warn_msg
204+
assert "Increase the number of iterations" in warn_msg
205+
assert "scale the data" in warn_msg
206+
207+
# Fourth case: other convergence problem before reaching `max_iter`: do not
208+
# recommend increasing `max_iter`.
209+
result.nit = 2
210+
result.status = 2
211+
result.message = "ABNORMAL"
212+
with pytest.warns(ConvergenceWarning) as record:
213+
_check_optimize_result("lbfgs", result, max_iter=10)
214+
215+
assert len(record) == 1
216+
warn_msg = record[0].message.args[0]
217+
assert "lbfgs failed to converge after 2 iteration(s)" in warn_msg
218+
assert result.message in warn_msg
219+
assert "Increase the number of iterations" not in warn_msg
220+
assert "scale the data" in warn_msg

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.