Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit c367698

Browse filesBrowse files
committed
FIX _parameter_constraints in SGD classes
1 parent 306a5fa commit c367698
Copy full SHA for c367698

File tree

Expand file treeCollapse file tree

1 file changed

+19
-10
lines changed
Filter options
Expand file treeCollapse file tree

1 file changed

+19
-10
lines changed

‎sklearn/linear_model/_stochastic_gradient.py

Copy file name to clipboardExpand all lines: sklearn/linear_model/_stochastic_gradient.py
+19-10Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,21 +80,12 @@ class BaseSGD(SparseCoefMixin, BaseEstimator, metaclass=ABCMeta):
8080
"""Base class for SGD classification and regression."""
8181

8282
_parameter_constraints: dict = {
83-
"penalty": [StrOptions({"l2", "l1", "elasticnet"}), None],
84-
"alpha": [Interval(Real, 0, None, closed="left")],
85-
"C": [Interval(Real, 0, None, closed="right")],
86-
"l1_ratio": [Interval(Real, 0, 1, closed="both")],
8783
"fit_intercept": ["boolean"],
8884
"max_iter": [Interval(Integral, 1, None, closed="left")],
8985
"tol": [Interval(Real, 0, None, closed="left"), None],
9086
"shuffle": ["boolean"],
91-
"random_state": ["random_state"],
9287
"verbose": ["verbose"],
93-
"eta0": [Interval(Real, 0, None, closed="left")],
94-
"power_t": [Interval(Real, None, None, closed="neither")],
95-
"early_stopping": ["boolean"],
96-
"validation_fraction": [Interval(Real, 0, 1, closed="neither")],
97-
"n_iter_no_change": [Interval(Integral, 1, None, closed="left")],
88+
"random_state": ["random_state"],
9889
"warm_start": ["boolean"],
9990
"average": [Interval(Integral, 0, None, closed="left"), "boolean"],
10091
}
@@ -523,6 +514,9 @@ class BaseSGDClassifier(LinearClassifierMixin, BaseSGD, metaclass=ABCMeta):
523514
_parameter_constraints: dict = {
524515
**BaseSGD._parameter_constraints,
525516
"loss": [StrOptions(set(loss_functions))],
517+
"early_stopping": ["boolean"],
518+
"validation_fraction": [Interval(Real, 0, 1, closed="neither")],
519+
"n_iter_no_change": [Interval(Integral, 1, None, closed="left")],
526520
"n_jobs": [Integral, None],
527521
"class_weight": [StrOptions({"balanced"}), dict, None],
528522
}
@@ -1214,11 +1208,16 @@ class SGDClassifier(BaseSGDClassifier):
12141208

12151209
_parameter_constraints: dict = {
12161210
**BaseSGDClassifier._parameter_constraints,
1211+
"penalty": [StrOptions({"l2", "l1", "elasticnet"}), None],
1212+
"alpha": [Interval(Real, 0, None, closed="left")],
1213+
"l1_ratio": [Interval(Real, 0, 1, closed="both")],
1214+
"power_t": [Interval(Real, None, None, closed="neither")],
12171215
"epsilon": [Interval(Real, 0, None, closed="left")],
12181216
"learning_rate": [
12191217
StrOptions({"constant", "optimal", "invscaling", "adaptive"}),
12201218
Hidden(StrOptions({"pa1", "pa2"})),
12211219
],
1220+
"eta0": [Interval(Real, 0, None, closed="left")],
12221221
}
12231222

12241223
def __init__(
@@ -1405,6 +1404,9 @@ class BaseSGDRegressor(RegressorMixin, BaseSGD):
14051404
_parameter_constraints: dict = {
14061405
**BaseSGD._parameter_constraints,
14071406
"loss": [StrOptions(set(loss_functions))],
1407+
"early_stopping": ["boolean"],
1408+
"validation_fraction": [Interval(Real, 0, 1, closed="neither")],
1409+
"n_iter_no_change": [Interval(Integral, 1, None, closed="left")],
14081410
}
14091411

14101412
@abstractmethod
@@ -2014,11 +2016,16 @@ class SGDRegressor(BaseSGDRegressor):
20142016

20152017
_parameter_constraints: dict = {
20162018
**BaseSGDRegressor._parameter_constraints,
2019+
"penalty": [StrOptions({"l2", "l1", "elasticnet"}), None],
2020+
"alpha": [Interval(Real, 0, None, closed="left")],
2021+
"l1_ratio": [Interval(Real, 0, 1, closed="both")],
2022+
"power_t": [Interval(Real, None, None, closed="neither")],
20172023
"learning_rate": [
20182024
StrOptions({"constant", "optimal", "invscaling", "adaptive"}),
20192025
Hidden(StrOptions({"pa1", "pa2"})),
20202026
],
20212027
"epsilon": [Interval(Real, 0, None, closed="left")],
2028+
"eta0": [Interval(Real, 0, None, closed="left")],
20222029
}
20232030

20242031
def __init__(
@@ -2228,6 +2235,8 @@ class SGDOneClassSVM(BaseSGD, OutlierMixin):
22282235
StrOptions({"constant", "optimal", "invscaling", "adaptive"}),
22292236
Hidden(StrOptions({"pa1", "pa2"})),
22302237
],
2238+
"eta0": [Interval(Real, 0, None, closed="left")],
2239+
"power_t": [Interval(Real, None, None, closed="neither")],
22312240
}
22322241

22332242
def __init__(

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.