Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 306826f

Browse filesBrowse files
maikiaglemaitreNicolasHugagramfortogrisel
authored
MRG Deprecates 'normalize' in LinearRegression (_base.py) (#17743)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: Nicolas Hug <contact@nicolas-hug.com> Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 8ea176a commit 306826f
Copy full SHA for 306826f

File tree

4 files changed

+260
-6
lines changed
Filter options

4 files changed

+260
-6
lines changed

‎doc/whats_new/v1.0.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.0.rst
+12Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,18 @@ Changelog
9494
Use ``var_`` instead.
9595
:pr:`18842` by :user:`Hong Shao Yang <hongshaoyang>`.
9696

97+
- |API|: The parameter ``normalize`` of :class:`linear_model.LinearRegression`
98+
is deprecated and will be removed in 1.2.
99+
Motivation for this deprecation: the ``normalize`` parameter did not take any
100+
effect if ``fit_intercept`` was set to False and therefore was deemed
101+
confusing.
102+
The behavior of the deprecated LinearRegression(normalize=True) can be
103+
reproduced with :class:`~sklearn.pipeline.Pipeline` with
104+
:class:`~sklearn.preprocessing.StandardScaler` as follows:
105+
make_pipeline(StandardScaler(with_mean=False), LinearRegression()).
106+
:pr:`17743` by :user:`Maria Telenczuk <maikia>` and
107+
:user:`Alexandre Gramfort <agramfort>`.
108+
97109
Code and Documentation Contributors
98110
-----------------------------------
99111

‎sklearn/linear_model/_base.py

Copy file name to clipboardExpand all lines: sklearn/linear_model/_base.py
+105-5Lines changed: 105 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# Lars Buitinck
1212
# Maryan Morel <maryan.morel@polytechnique.edu>
1313
# Giorgio Patrini <giorgio.patrini@anu.edu.au>
14+
# Maria Telenczuk <https://github.com/maikia>
1415
# License: BSD 3 clause
1516

1617
from abc import ABCMeta, abstractmethod
@@ -49,6 +50,94 @@
4950
# intercept oscillation.
5051

5152

53+
# FIXME in 1.2: parameter 'normalize' should be removed from linear models
54+
# in cases where now normalize=False. The default value of 'normalize' should
55+
# be changed to False in linear models where now normalize=True
56+
def _deprecate_normalize(normalize, default, estimator_name):
    """Resolve the ``normalize`` value to use and warn about its deprecation.

    ``normalize`` is being deprecated from linear models; a pipeline with a
    StandardScaler is recommended instead. The message shown to the user
    depends on the estimator's default ``normalize`` value (it varies between
    linear models) and on the value passed by the user.

    Parameters
    ----------
    normalize : bool or 'deprecated'
        normalize value passed by the user. ``'deprecated'`` means that the
        user left the parameter to its default.

    default : bool
        default normalize value used by the estimator

    estimator_name : str
        name of the linear estimator which calls this function.
        The name is used for writing the deprecation warnings

    Returns
    -------
    normalize : bool
        normalize value which should further be used by the estimator at this
        stage of the deprecation process

    Raises
    ------
    ValueError
        If `normalize` is not one of True, False or 'deprecated'.

    Warns
    -----
    FutureWarning
        - if `normalize` is left unset while `default` is True;
        - if `normalize` is explicitly set (True or False) while `default`
          is False.

    Notes
    -----
    This function should be updated in 1.2 depending on the value of
    `normalize`:
    - True, warning: `normalize` was deprecated in 1.2 and will be removed in
      1.4. Suggest to use pipeline instead.
    - False, `normalize` was deprecated in 1.2 and it will be removed in 1.4.
      Leave normalize to its default value.
    - `deprecated` - this should only be possible with default == False as from
      1.2 `normalize` in all the linear models should be either removed or the
      default should be set to False.
    This function should be completely removed in 1.4.
    """
    if normalize not in [True, False, 'deprecated']:
        raise ValueError("Leave 'normalize' to its default value or set it "
                         "to True or False")

    is_unset = normalize == 'deprecated'
    _normalize = default if is_unset else normalize

    # Shared tail of the messages that suggest switching to a pipeline.
    # Plain (non f-) strings: the braces must be emitted literally.
    fit_params_hint = (
        "If you wish to use additional parameters in "
        "the fit() you can include them as follows:\n"
        "kwargs = {model.steps[-1][0] + "
        "'__<your_param_name>': <your_param_value>}\n"
        "model.fit(X, y, **kwargs)"
    )

    if default and is_unset:
        # The estimator currently normalizes by default: announce the
        # upcoming change of default.
        message = (
            "The default of 'normalize' will be set to False in version 1.2 "
            "and deprecated in version 1.4.\nPass normalize=False and use "
            "Pipeline with a StandardScaler in a preprocessing stage if you "
            "wish to reproduce the previous behavior:\n"
            "model = make_pipeline(StandardScaler(with_mean=False), "
            f"{estimator_name}(normalize=False))\n"
        )
        warnings.warn(message + fit_params_hint, FutureWarning)
    elif not is_unset and normalize and not default:
        # The user explicitly asked for normalization on an estimator that
        # does not normalize by default.
        message = (
            "'normalize' was deprecated in version 1.0 and will be "
            "removed in 1.2.\nIf you wish to scale the data, use "
            "Pipeline with a StandardScaler in a preprocessing stage. "
            "To reproduce the previous behavior:\n"
            "model = make_pipeline(StandardScaler(with_mean=False), "
            f"{estimator_name}())\n"
        )
        warnings.warn(message + fit_params_hint, FutureWarning)
    elif not normalize and not default:
        # The user explicitly passed normalize=False although it already is
        # the default ('deprecated' is truthy and never reaches this branch).
        warnings.warn(
            "'normalize' was deprecated in version 1.0 and will be "
            "removed in 1.2. Don't set 'normalize' parameter "
            "and leave it to its default value", FutureWarning
        )

    return _normalize
139+
140+
52141
def make_dataset(X, y, sample_weight, random_state=None):
53142
"""Create ``Dataset`` abstraction for sparse and dense inputs.
54143
@@ -407,6 +496,10 @@ class LinearRegression(MultiOutputMixin, RegressorMixin, LinearModel):
407496
:class:`~sklearn.preprocessing.StandardScaler` before calling ``fit``
408497
on an estimator with ``normalize=False``.
409498
499+
.. deprecated:: 1.0
500+
`normalize` was deprecated in version 1.0 and will be
501+
removed in 1.2.
502+
410503
copy_X : bool, default=True
411504
If True, X will be copied; else, it may be overwritten.
412505
@@ -476,8 +569,8 @@ class LinearRegression(MultiOutputMixin, RegressorMixin, LinearModel):
476569
array([16.])
477570
"""
478571
@_deprecate_positional_args
479-
def __init__(self, *, fit_intercept=True, normalize=False, copy_X=True,
480-
n_jobs=None, positive=False):
572+
def __init__(self, *, fit_intercept=True, normalize='deprecated',
573+
copy_X=True, n_jobs=None, positive=False):
481574
self.fit_intercept = fit_intercept
482575
self.normalize = normalize
483576
self.copy_X = copy_X
@@ -507,6 +600,11 @@ def fit(self, X, y, sample_weight=None):
507600
self : returns an instance of self.
508601
"""
509602

603+
_normalize = _deprecate_normalize(
604+
self.normalize, default=False,
605+
estimator_name=self.__class__.__name__
606+
)
607+
510608
n_jobs_ = self.n_jobs
511609

512610
accept_sparse = False if self.positive else ['csr', 'csc', 'coo']
@@ -519,7 +617,7 @@ def fit(self, X, y, sample_weight=None):
519617
dtype=X.dtype)
520618

521619
X, y, X_offset, y_offset, X_scale = self._preprocess_data(
522-
X, y, fit_intercept=self.fit_intercept, normalize=self.normalize,
620+
X, y, fit_intercept=self.fit_intercept, normalize=_normalize,
523621
copy=self.copy_X, sample_weight=sample_weight,
524622
return_mean=True)
525623

@@ -651,10 +749,12 @@ def _pre_fit(X, y, Xy, precompute, normalize, fit_intercept, copy,
651749
check_input=check_input, sample_weight=sample_weight)
652750
if sample_weight is not None:
653751
X, y = _rescale_data(X, y, sample_weight=sample_weight)
752+
753+
# FIXME: 'normalize' to be removed in 1.2
654754
if hasattr(precompute, '__array__'):
655755
if (fit_intercept and not np.allclose(X_offset, np.zeros(n_features))
656-
or normalize and not np.allclose(X_scale,
657-
np.ones(n_features))):
756+
or normalize and not np.allclose(X_scale, np.ones(n_features)
757+
)):
658758
warnings.warn(
659759
"Gram matrix was provided but X was centered to fit "
660760
"intercept, or X was normalized : recomputing Gram matrix.",

‎sklearn/linear_model/tests/test_base.py

Copy file name to clipboardExpand all lines: sklearn/linear_model/tests/test_base.py
+86-1Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from sklearn.utils.fixes import parse_version
1818

1919
from sklearn.linear_model import LinearRegression
20+
from sklearn.linear_model._base import _deprecate_normalize
2021
from sklearn.linear_model._base import _preprocess_data
2122
from sklearn.linear_model._base import _rescale_data
2223
from sklearn.linear_model._base import make_dataset
@@ -106,6 +107,7 @@ def test_raises_value_error_if_positive_and_sparse():
106107
with pytest.raises(TypeError, match=error_msg):
107108
reg.fit(X, y)
108109

110+
109111
def test_raises_value_error_if_sample_weights_greater_than_1d():
110112
# Sample weights must be either scalar or 1D
111113

@@ -149,6 +151,59 @@ def test_fit_intercept():
149151
lr3_without_intercept.coef_.ndim)
150152

151153

154+
def test_error_on_wrong_normalize():
    # Any value other than True, False or 'deprecated' must be rejected
    # with an explicit ValueError.
    normalize = 'wrong'
    default = True
    error_msg = "Leave 'normalize' to its default"
    with pytest.raises(ValueError, match=error_msg):
        _deprecate_normalize(normalize, default, 'estimator')
161+
162+
163+
@pytest.mark.parametrize('normalize', [True, False, 'deprecated'])
@pytest.mark.parametrize('default', [True, False])
# FIXME update test in 1.2 for new versions
def test_deprecate_normalize(normalize, default):
    # Exercise every (normalize, default) combination of the deprecation
    # helper and check both the returned value and the emitted warning.
    left_unset = normalize == 'deprecated'

    # The helper falls back to the estimator default only when the user did
    # not set the parameter.
    output = default if left_unset else normalize

    if default and left_unset:
        # warning to pass False and use StandardScaler
        expected = FutureWarning
        warning_msg = ['False', '1.2', 'StandardScaler(']
    elif not default and not left_unset:
        # explicit value on an estimator whose default is False
        expected = FutureWarning
        hint = 'StandardScaler(' if normalize else 'default value'
        warning_msg = ['1.2', hint]
    else:
        # no warning
        expected = None
        warning_msg = []

    with pytest.warns(expected) as record:
        _normalize = _deprecate_normalize(normalize, default, 'estimator')
    assert _normalize == output

    n_warnings = 0 if expected is None else 1
    assert len(record) == n_warnings
    if n_warnings:
        message = str(record[0].message)
        assert all([fragment in message for fragment in warning_msg])
205+
206+
152207
def test_linear_regression_sparse(random_state=0):
153208
# Test that linear regression also works with sparse data
154209
random_state = check_random_state(random_state)
@@ -165,6 +220,35 @@ def test_linear_regression_sparse(random_state=0):
165220
assert_array_almost_equal(ols.predict(X) - y.ravel(), 0)
166221

167222

223+
@pytest.mark.parametrize(
    'normalize, n_warnings, warning',
    [(True, 1, FutureWarning),
     (False, 1, FutureWarning),
     ("deprecated", 0, None)]
)
# FIXME remove test in 1.4
def test_linear_regression_normalize_deprecation(
    normalize, n_warnings, warning
):
    # Fitting LinearRegression must raise a FutureWarning whenever
    # `normalize` was explicitly set, and stay silent otherwise.
    rng = check_random_state(0)
    shape = (200, 2)  # (n_samples, n_features)
    X = rng.randn(*shape)
    X[X < 0.1] = 0.0
    y = rng.rand(shape[0])

    reg = LinearRegression(normalize=normalize)
    with pytest.warns(warning) as caught:
        reg.fit(X, y)

    assert len(caught) == n_warnings
    if n_warnings:
        first_message = str(caught[0].message)
        assert "'normalize' was deprecated" in first_message
248+
249+
250+
# FIXME: 'normalize' to be removed in 1.2 in LinearRegression
251+
@pytest.mark.filterwarnings("ignore:'normalize' was deprecated")
168252
@pytest.mark.parametrize('normalize', [True, False])
169253
@pytest.mark.parametrize('fit_intercept', [True, False])
170254
def test_linear_regression_sparse_equal_dense(normalize, fit_intercept):
@@ -303,8 +387,9 @@ def test_linear_regression_pd_sparse_dataframe_warning():
303387
df[str(col)] = arr
304388

305389
msg = "pandas.DataFrame with sparse columns found."
390+
391+
reg = LinearRegression()
306392
with pytest.warns(UserWarning, match=msg):
307-
reg = LinearRegression()
308393
reg.fit(df.iloc[:, 0:2], df.iloc[:, 3])
309394

310395
# does not warn when the whole dataframe is sparse

‎sklearn/linear_model/tests/test_coordinate_descent.py

Copy file name to clipboardExpand all lines: sklearn/linear_model/tests/test_coordinate_descent.py
+57Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from sklearn.utils._testing import assert_warns_message
2727
from sklearn.utils._testing import ignore_warnings
2828
from sklearn.utils._testing import assert_array_equal
29+
from sklearn.utils._testing import _convert_container
2930
from sklearn.utils._testing import TempMemmap
3031
from sklearn.utils.fixes import parse_version
3132

@@ -301,6 +302,8 @@ def test_lasso_cv_positive_constraint():
301302
assert min(clf_constrained.coef_) >= 0
302303

303304

305+
# FIXME: 'normalize' to be removed in 1.2
306+
@pytest.mark.filterwarnings("ignore:'normalize' was deprecated")
304307
@pytest.mark.parametrize(
305308
"LinearModel, params",
306309
[(Lasso, {"tol": 1e-16, "alpha": 0.1}),
@@ -384,6 +387,60 @@ def test_model_pipeline_same_as_normalize_true(LinearModel, params):
384387
assert_allclose(y_pred_normalize, y_pred_standardize)
385388

386389

390+
# FIXME: 'normalize' to be removed in 1.2
391+
@pytest.mark.filterwarnings("ignore:'normalize' was deprecated")
@pytest.mark.parametrize(
    "estimator, is_sparse, with_mean",
    [(LinearRegression, True, False),
     (LinearRegression, False, True),
     (LinearRegression, False, False)]
)
def test_linear_model_sample_weights_normalize_in_pipeline(
        estimator, is_sparse, with_mean
):
    # Check that fitting the estimator with normalize=True and sample_weight
    # gives the same model as fitting it with normalize=False inside a
    # pipeline with a StandardScaler, passing the same sample_weight.
    rng = np.random.RandomState(0)
    X, y = make_regression(n_samples=20, n_features=5, noise=1e-2,
                           random_state=rng)
    # make sure the data is not centered to make the problem more
    # difficult
    X += 10
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5,
                                                        random_state=rng)
    if is_sparse:
        X_train = sparse.csr_matrix(X_train)
        # Convert the held-out split itself, so that predictions are
        # compared on data the models were not fitted on.
        X_test = _convert_container(X_test, 'sparse')

    sample_weight = rng.rand(X_train.shape[0])

    # linear estimator with explicit sample_weight
    reg_with_normalize = estimator(normalize=True)
    reg_with_normalize.fit(X_train, y_train, sample_weight=sample_weight)

    # linear estimator in a pipeline
    reg_with_scaler = make_pipeline(
        StandardScaler(with_mean=with_mean),
        estimator(normalize=False)
    )
    # fit parameters are routed to the last step as '<step name>__<param>'
    kwargs = {reg_with_scaler.steps[-1][0] + '__sample_weight':
              sample_weight}
    reg_with_scaler.fit(X_train, y_train, **kwargs)

    y_pred_norm = reg_with_normalize.predict(X_test)
    y_pred_pip = reg_with_scaler.predict(X_test)

    # the pipeline's coefficients live in the scaled feature space
    assert_allclose(
        reg_with_normalize.coef_ * reg_with_scaler[0].scale_,
        reg_with_scaler[1].coef_
    )
    assert_allclose(y_pred_norm, y_pred_pip)
440+
441+
442+
# FIXME: 'normalize' to be removed in 1.2
443+
@pytest.mark.filterwarnings("ignore:'normalize' was deprecated")
387444
@pytest.mark.parametrize(
388445
"LinearModel, params",
389446
[(Lasso, {"tol": 1e-16, "alpha": 0.1}),

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.