Commit f0c80e8

MNT clean-up deprecations for 1.7: multi_class in LogisticRegression (scikit-learn#31241)
1 parent b985df0 commit f0c80e8

9 files changed: +112 −77 lines changed

‎doc/modules/model_evaluation.rst

+3 −3
@@ -1632,7 +1632,7 @@ Therefore, the `y_score` parameter is of size (n_samples,).
 >>> from sklearn.linear_model import LogisticRegression
 >>> from sklearn.metrics import roc_auc_score
 >>> X, y = load_breast_cancer(return_X_y=True)
->>> clf = LogisticRegression(solver="liblinear").fit(X, y)
+>>> clf = LogisticRegression().fit(X, y)
 >>> clf.classes_
 array([0, 1])

@@ -1728,11 +1728,11 @@ class with the greater label for each output.
 >>> from sklearn.datasets import make_multilabel_classification
 >>> from sklearn.multioutput import MultiOutputClassifier
 >>> X, y = make_multilabel_classification(random_state=0)
->>> inner_clf = LogisticRegression(solver="liblinear", random_state=0)
+>>> inner_clf = LogisticRegression(random_state=0)
 >>> clf = MultiOutputClassifier(inner_clf).fit(X, y)
 >>> y_score = np.transpose([y_pred[:, 1] for y_pred in clf.predict_proba(X)])
 >>> roc_auc_score(y, y_score, average=None)
-array([0.82..., 0.86..., 0.94..., 0.85... , 0.94...])
+array([0.82..., 0.85..., 0.93..., 0.86..., 0.94...])

 And the decision values do not require such processing.

‎(new file, path not shown) +7 −0

@@ -0,0 +1,7 @@
+- Using the `"liblinear"` solver for multiclass classification with a one-versus-rest
+  scheme in :class:`linear_model.LogisticRegression` and
+  :class:`linear_model.LogisticRegressionCV` is deprecated and will raise an error in
+  version 1.8. Either use a solver which supports the multinomial loss or wrap the
+  estimator in a :class:`sklearn.multiclass.OneVsRestClassifier` to keep applying a
+  one-versus-rest scheme.
+  By :user:`Jérémie du Boisberranger <jeremiedbb>`.
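
For orientation, here is a minimal sketch of the two migration paths suggested in this entry, on a synthetic three-class problem; the dataset and estimator settings below are illustrative and are not part of the commit:

 >>> from sklearn.datasets import make_classification
 >>> from sklearn.linear_model import LogisticRegression
 >>> from sklearn.multiclass import OneVsRestClassifier
 >>> X, y = make_classification(n_classes=3, n_informative=6, random_state=0)
 >>> # Option 1: switch to a solver that handles the multinomial loss directly
 >>> clf = LogisticRegression(solver="lbfgs").fit(X, y)
 >>> # Option 2: keep liblinear but make the one-versus-rest scheme explicit,
 >>> # so each inner estimator only ever sees a binary problem
 >>> ovr = OneVsRestClassifier(LogisticRegression(solver="liblinear")).fit(X, y)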

‎sklearn/ensemble/tests/test_voting.py

+5 −5
@@ -114,7 +114,7 @@ def test_notfitted():

 def test_majority_label_iris(global_random_seed):
     """Check classification by majority label on dataset iris."""
-    clf1 = LogisticRegression(solver="liblinear", random_state=global_random_seed)
+    clf1 = LogisticRegression(random_state=global_random_seed)
     clf2 = RandomForestClassifier(n_estimators=10, random_state=global_random_seed)
     clf3 = GaussianNB()
     eclf = VotingClassifier(

@@ -127,12 +127,12 @@ def test_majority_label_iris(global_random_seed):

 def test_tie_situation():
     """Check voting classifier selects smaller class label in tie situation."""
-    clf1 = LogisticRegression(random_state=123, solver="liblinear")
+    clf1 = LogisticRegression(random_state=123)
     clf2 = RandomForestClassifier(random_state=123)
     eclf = VotingClassifier(estimators=[("lr", clf1), ("rf", clf2)], voting="hard")
-    assert clf1.fit(X, y).predict(X)[73] == 2
-    assert clf2.fit(X, y).predict(X)[73] == 1
-    assert eclf.fit(X, y).predict(X)[73] == 1
+    assert clf1.fit(X, y).predict(X)[52] == 2
+    assert clf2.fit(X, y).predict(X)[52] == 1
+    assert eclf.fit(X, y).predict(X)[52] == 1


 def test_weights_iris(global_random_seed):

‎sklearn/linear_model/_logistic.py

+22 −4
@@ -501,6 +501,15 @@ def _logistic_regression_path(
             w0 = sol.solve(X=X, y=target, sample_weight=sample_weight)
             n_iter_i = sol.iteration
         elif solver == "liblinear":
+            if len(classes) > 2:
+                warnings.warn(
+                    "Using the 'liblinear' solver for multiclass classification is "
+                    "deprecated. An error will be raised in 1.8. Either use another "
+                    "solver which supports the multinomial loss or wrap the estimator "
+                    "in a OneVsRestClassifier to keep applying a one-versus-rest "
+                    "scheme.",
+                    FutureWarning,
+                )
             (
                 coef_,
                 intercept_,

@@ -931,7 +940,7 @@ class LogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstimator):
     'lbfgs'           'l2', None                     yes
     'liblinear'       'l1', 'l2'                     no
     'newton-cg'       'l2', None                     yes
-    'newton-cholesky' 'l2', None                     no
+    'newton-cholesky' 'l2', None                     yes
     'sag'             'l2', None                     yes
     'saga'            'elasticnet', 'l1', 'l2', None yes
     ================= ============================== ======================

@@ -1238,7 +1247,7 @@ def fit(self, X, y, sample_weight=None):
         check_classification_targets(y)
         self.classes_ = np.unique(y)

-        # TODO(1.7) remove multi_class
+        # TODO(1.8) remove multi_class
         multi_class = self.multi_class
         if self.multi_class == "multinomial" and len(self.classes_) == 2:
             warnings.warn(

@@ -1274,6 +1283,15 @@
         multi_class = _check_multi_class(multi_class, solver, len(self.classes_))

         if solver == "liblinear":
+            if len(self.classes_) > 2:
+                warnings.warn(
+                    "Using the 'liblinear' solver for multiclass classification is "
+                    "deprecated. An error will be raised in 1.8. Either use another "
+                    "solver which supports the multinomial loss or wrap the estimator "
+                    "in a OneVsRestClassifier to keep applying a one-versus-rest "
+                    "scheme.",
+                    FutureWarning,
+                )
             if effective_n_jobs(self.n_jobs) != 1:
                 warnings.warn(
                     "'n_jobs' > 1 does not have any effect when"

@@ -1568,7 +1586,7 @@ class LogisticRegressionCV(LogisticRegression, LinearClassifierMixin, BaseEstima
     'lbfgs'           'l2'                           yes
     'liblinear'       'l1', 'l2'                     no
     'newton-cg'       'l2'                           yes
-    'newton-cholesky' 'l2',                          no
+    'newton-cholesky' 'l2',                          yes
     'sag'             'l2',                          yes
     'saga'            'elasticnet', 'l1', 'l2'       yes
     ================= ============================== ======================

@@ -1900,7 +1918,7 @@ def fit(self, X, y, sample_weight=None, **params):
         classes = self.classes_ = label_encoder.classes_
         encoded_labels = label_encoder.transform(label_encoder.classes_)

-        # TODO(1.7) remove multi_class
+        # TODO(1.8) remove multi_class
         multi_class = self.multi_class
         if self.multi_class == "multinomial" and len(self.classes_) == 2:
             warnings.warn(

‎sklearn/linear_model/tests/test_logistic.py

+42 −19
@@ -129,8 +129,7 @@ def __call__(self, model, X, y, sample_weight=None):

 @skip_if_no_parallel
 def test_lr_liblinear_warning():
-    n_samples, n_features = iris.data.shape
-    target = iris.target_names[iris.target]
+    X, y = make_classification(random_state=0)

     lr = LogisticRegression(solver="liblinear", n_jobs=2)
     warning_message = (

@@ -139,7 +138,7 @@ def test_lr_liblinear_warning():
         " = 2."
     )
     with pytest.warns(UserWarning, match=warning_message):
-        lr.fit(iris.data, target)
+        lr.fit(X, y)


 @pytest.mark.parametrize("csr_container", CSR_CONTAINERS)

@@ -148,8 +147,11 @@ def test_predict_3_classes(csr_container):
     check_predictions(LogisticRegression(C=10), csr_container(X), Y2)


-# TODO(1.7): remove filterwarnings after the deprecation of multi_class
+# TODO(1.8): remove filterwarnings after the deprecation of multi_class
 @pytest.mark.filterwarnings("ignore:.*'multi_class' was deprecated.*:FutureWarning")
+@pytest.mark.filterwarnings(
+    "ignore:.*'liblinear' solver for multiclass classification is deprecated.*"
+)
 @pytest.mark.parametrize(
     "clf",
     [

@@ -197,7 +199,7 @@ def test_predict_iris(clf):
     assert np.mean(pred == target) > 0.95


-# TODO(1.7): remove filterwarnings after the deprecation of multi_class
+# TODO(1.8): remove filterwarnings after the deprecation of multi_class
 @pytest.mark.filterwarnings("ignore:.*'multi_class' was deprecated.*:FutureWarning")
 @pytest.mark.parametrize("LR", [LogisticRegression, LogisticRegressionCV])
 def test_check_solver_option(LR):

@@ -249,7 +251,7 @@ def test_elasticnet_l1_ratio_err_helpful(LR):
         model.fit(np.array([[1, 2], [3, 4]]), np.array([0, 1]))


-# TODO(1.7): remove whole test with deprecation of multi_class
+# TODO(1.8): remove whole test with deprecation of multi_class
 @pytest.mark.filterwarnings("ignore:.*'multi_class' was deprecated.*:FutureWarning")
 @pytest.mark.parametrize("solver", ["lbfgs", "newton-cg", "sag", "saga"])
 def test_multinomial_binary(solver):

@@ -274,7 +276,7 @@ def test_multinomial_binary(solver):
     assert np.mean(pred == target) > 0.9


-# TODO(1.7): remove filterwarnings after the deprecation of multi_class
+# TODO(1.8): remove filterwarnings after the deprecation of multi_class
 # Maybe even remove this whole test as correctness of multinomial loss is tested
 # elsewhere.
 @pytest.mark.filterwarnings("ignore:.*'multi_class' was deprecated.*:FutureWarning")

@@ -614,7 +616,7 @@ def test_logistic_cv_sparse(csr_container):
     assert clfs.C_ == clf.C_


-# TODO(1.7): remove filterwarnings after the deprecation of multi_class
+# TODO(1.8): remove filterwarnings after the deprecation of multi_class
 # Best remove this whole test.
 @pytest.mark.filterwarnings("ignore:.*'multi_class' was deprecated.*:FutureWarning")
 def test_ovr_multinomial_iris():

@@ -700,7 +702,7 @@ def test_logistic_regression_solvers():
     )


-# TODO(1.7): remove filterwarnings after the deprecation of multi_class
+# TODO(1.8): remove filterwarnings after the deprecation of multi_class
 @pytest.mark.filterwarnings("ignore:.*'multi_class' was deprecated.*:FutureWarning")
 @pytest.mark.parametrize("fit_intercept", [False, True])
 def test_logistic_regression_solvers_multiclass(fit_intercept):

@@ -1301,7 +1303,7 @@ def test_logreg_predict_proba_multinomial():
     assert clf_wrong_loss > clf_multi_loss


-# TODO(1.7): remove filterwarnings after the deprecation of multi_class
+# TODO(1.8): remove filterwarnings after the deprecation of multi_class
 @pytest.mark.filterwarnings("ignore:.*'multi_class' was deprecated.*:FutureWarning")
 @pytest.mark.parametrize("max_iter", np.arange(1, 5))
 @pytest.mark.parametrize("multi_class", ["ovr", "multinomial"])

@@ -1345,8 +1347,11 @@ def test_max_iter(max_iter, multi_class, solver, message):
     assert lr.n_iter_[0] == max_iter


-# TODO(1.7): remove filterwarnings after the deprecation of multi_class
+# TODO(1.8): remove filterwarnings after the deprecation of multi_class
 @pytest.mark.filterwarnings("ignore:.*'multi_class' was deprecated.*:FutureWarning")
+@pytest.mark.filterwarnings(
+    "ignore:.*'liblinear' solver for multiclass classification is deprecated.*"
+)
 @pytest.mark.parametrize("solver", SOLVERS)
 def test_n_iter(solver):
     # Test that self.n_iter_ has the correct format.

@@ -1478,7 +1483,7 @@ def test_saga_vs_liblinear(csr_container):
     assert_array_almost_equal(saga.coef_, liblinear.coef_, 3)


-# TODO(1.7): remove filterwarnings after the deprecation of multi_class
+# TODO(1.8): remove filterwarnings after the deprecation of multi_class
 @pytest.mark.filterwarnings("ignore:.*'multi_class' was deprecated.*:FutureWarning")
 @pytest.mark.parametrize("multi_class", ["ovr", "multinomial"])
 @pytest.mark.parametrize(

@@ -1738,7 +1743,7 @@ def test_LogisticRegressionCV_GridSearchCV_elastic_net(n_classes):
     assert gs.best_params_["C"] == lrcv.C_[0]


-# TODO(1.7): remove filterwarnings after the deprecation of multi_class
+# TODO(1.8): remove filterwarnings after the deprecation of multi_class
 # Maybe remove whole test after removal of the deprecated multi_class.
 @pytest.mark.filterwarnings("ignore:.*'multi_class' was deprecated.*:FutureWarning")
 def test_LogisticRegressionCV_GridSearchCV_elastic_net_ovr():

@@ -1786,7 +1791,7 @@ def test_LogisticRegressionCV_GridSearchCV_elastic_net_ovr():
     assert (lrcv.predict(X_test) == gs.predict(X_test)).mean() >= 0.8


-# TODO(1.7): remove filterwarnings after the deprecation of multi_class
+# TODO(1.8): remove filterwarnings after the deprecation of multi_class
 @pytest.mark.filterwarnings("ignore:.*'multi_class' was deprecated.*:FutureWarning")
 @pytest.mark.parametrize("penalty", ("l2", "elasticnet"))
 @pytest.mark.parametrize("multi_class", ("ovr", "multinomial", "auto"))

@@ -1825,7 +1830,7 @@ def test_LogisticRegressionCV_no_refit(penalty, multi_class):
     assert lrcv.coef_.shape == (n_classes, n_features)


-# TODO(1.7): remove filterwarnings after the deprecation of multi_class
+# TODO(1.8): remove filterwarnings after the deprecation of multi_class
 # Remove multi_class an change first element of the expected n_iter_.shape from
 # n_classes to 1 (according to the docstring).
 @pytest.mark.filterwarnings("ignore:.*'multi_class' was deprecated.*:FutureWarning")

@@ -1955,8 +1960,11 @@ def test_logistic_regression_path_coefs_multinomial():
     assert_array_almost_equal(coefs[1], coefs[2], decimal=1)


-# TODO(1.7): remove filterwarnings after the deprecation of multi_class
+# TODO(1.8): remove filterwarnings after the deprecation of multi_class
 @pytest.mark.filterwarnings("ignore:.*'multi_class' was deprecated.*:FutureWarning")
+@pytest.mark.filterwarnings(
+    "ignore:.*'liblinear' solver for multiclass classification is deprecated.*"
+)
 @pytest.mark.parametrize(
     "est",
     [

@@ -2126,7 +2134,7 @@ def test_scores_attribute_layout_elasticnet():
     assert avg_scores_lrcv[i, j] == pytest.approx(avg_score_lr)


-# TODO(1.7): remove filterwarnings after the deprecation of multi_class
+# TODO(1.8): remove filterwarnings after the deprecation of multi_class
 @pytest.mark.filterwarnings("ignore:.*'multi_class' was deprecated.*:FutureWarning")
 @pytest.mark.parametrize("solver", ["lbfgs", "newton-cg", "newton-cholesky"])
 @pytest.mark.parametrize("fit_intercept", [False, True])

@@ -2171,7 +2179,7 @@ def test_multinomial_identifiability_on_iris(solver, fit_intercept):
     assert clf.intercept_.sum(axis=0) == pytest.approx(0, abs=1e-11)


-# TODO(1.7): remove filterwarnings after the deprecation of multi_class
+# TODO(1.8): remove filterwarnings after the deprecation of multi_class
 @pytest.mark.filterwarnings("ignore:.*'multi_class' was deprecated.*:FutureWarning")
 @pytest.mark.parametrize("multi_class", ["ovr", "multinomial", "auto"])
 @pytest.mark.parametrize("class_weight", [{0: 1.0, 1: 10.0, 2: 1.0}, "balanced"])

@@ -2349,7 +2357,7 @@ def test_passing_params_without_enabling_metadata_routing():
         lr_cv.score(X, y, **params)


-# TODO(1.7): remove
+# TODO(1.8): remove
 def test_multi_class_deprecated():
     """Check `multi_class` parameter deprecated."""
     X, y = make_classification(n_classes=3, n_samples=50, n_informative=6)

@@ -2414,3 +2422,18 @@ def test_newton_cholesky_fallback_to_lbfgs(global_random_seed):
     n_iter_nc_limited = lr_nc_limited.n_iter_[0]

     assert n_iter_nc_limited == lr_nc_limited.max_iter - 1
+
+
+# TODO(1.8): check for an error instead
+@pytest.mark.parametrize("Estimator", [LogisticRegression, LogisticRegressionCV])
+def test_liblinear_multiclass_warning(Estimator):
+    """Check that liblinear warns on multiclass problems."""
+    msg = (
+        "Using the 'liblinear' solver for multiclass classification is "
+        "deprecated. An error will be raised in 1.8. Either use another "
+        "solver which supports the multinomial loss or wrap the estimator "
+        "in a OneVsRestClassifier to keep applying a one-versus-rest "
+        "scheme."
+    )
+    with pytest.warns(FutureWarning, match=msg):
+        Estimator(solver="liblinear").fit(iris.data, iris.target)
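
At runtime, outside of pytest, the same deprecation surfaces as a FutureWarning at fit time. A small illustrative snippet, assuming a scikit-learn build that includes this change (the warning-capturing pattern here is for demonstration only):

 >>> import warnings
 >>> from sklearn.datasets import load_iris
 >>> from sklearn.linear_model import LogisticRegression
 >>> X, y = load_iris(return_X_y=True)
 >>> with warnings.catch_warnings(record=True) as caught:
 ...     warnings.simplefilter("always")
 ...     _ = LogisticRegression(solver="liblinear").fit(X, y)  # 3 classes -> warns
 >>> any(issubclass(w.category, FutureWarning) for w in caught)
 True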

‎sklearn/metrics/_ranking.py

+3 −3
@@ -622,7 +622,7 @@ class scores must correspond to the order of ``labels``,
     >>> from sklearn.linear_model import LogisticRegression
     >>> from sklearn.metrics import roc_auc_score
     >>> X, y = load_breast_cancer(return_X_y=True)
-    >>> clf = LogisticRegression(solver="liblinear", random_state=0).fit(X, y)
+    >>> clf = LogisticRegression(solver="newton-cholesky", random_state=0).fit(X, y)
     >>> roc_auc_score(y, clf.predict_proba(X)[:, 1])
     0.99...
     >>> roc_auc_score(y, clf.decision_function(X))

@@ -632,7 +632,7 @@ class scores must correspond to the order of ``labels``,

     >>> from sklearn.datasets import load_iris
     >>> X, y = load_iris(return_X_y=True)
-    >>> clf = LogisticRegression(solver="liblinear").fit(X, y)
+    >>> clf = LogisticRegression(solver="newton-cholesky").fit(X, y)
     >>> roc_auc_score(y, clf.predict_proba(X), multi_class='ovr')
     0.99...

@@ -649,7 +649,7 @@ class scores must correspond to the order of ``labels``,
     >>> # extract the positive columns for each output
     >>> y_score = np.transpose([score[:, 1] for score in y_score])
     >>> roc_auc_score(y, y_score, average=None)
-    array([0.82..., 0.86..., 0.94..., 0.85... , 0.94...])
+    array([0.82..., 0.85..., 0.93..., 0.86..., 0.94...])
     >>> from sklearn.linear_model import RidgeClassifierCV
     >>> clf = RidgeClassifierCV().fit(X, y)
     >>> roc_auc_score(y, clf.decision_function(X), average=None)

0 commit comments
