diff --git a/sklearn/linear_model/_coordinate_descent.py b/sklearn/linear_model/_coordinate_descent.py index 2dbb83c82fbaa..3305a368b5245 100644 --- a/sklearn/linear_model/_coordinate_descent.py +++ b/sklearn/linear_model/_coordinate_descent.py @@ -910,7 +910,7 @@ def __init__( warm_start=False, positive=False, random_state=None, - selection="cyclic", + selection="cyclic" ): self.alpha = alpha self.l1_ratio = l1_ratio @@ -2052,7 +2052,9 @@ def __init__( positive=False, random_state=None, selection="cyclic", + refit=True, ): + self.refit = refit super().__init__( eps=eps, n_alphas=n_alphas, @@ -2082,6 +2084,7 @@ def __sklearn_tags__(self): return tags def fit(self, X, y, sample_weight=None, **params): + params.pop('refit', None) """Fit Lasso model with coordinate descent. Fit is on grid of alphas and best alpha estimated by cross-validation. @@ -2119,8 +2122,13 @@ def fit(self, X, y, sample_weight=None, **params): self : object Returns an instance of fitted model. """ + return super().fit(X, y, sample_weight=sample_weight, **params) + if self.refit: + self._fit(X, y) + + return self class ElasticNetCV(RegressorMixin, LinearModelCV): """Elastic Net model with iterative fitting along a regularization path.