Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 5a89f4d

Browse filesBrowse files
committed
Refactored get_params and set_params into a GetSetParamsMixin and
changed BaseEstimator and gaussian_process.kernels.Kernel to inherit from it. Also added InvalidParameterError to pass the parameter name and print the proper error message in set_params depending on the class.
1 parent c0601d7 commit 5a89f4d
Copy full SHA for 5a89f4d

File tree

3 files changed

+87
-71
lines changed
Filter options

3 files changed

+87
-71
lines changed

‎sklearn/base.py

Copy file name to clipboardExpand all lines: sklearn/base.py
+65-31Lines changed: 65 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import numpy as np
1414

1515
from . import __version__
16-
from .exceptions import SignatureError
16+
from .exceptions import InvalidParameterError, SignatureError
1717
from .utils import _IS_32BIT, get_param_names_from_constructor
1818

1919
_DEFAULT_TAGS = {
@@ -141,18 +141,9 @@ def _update_if_consistent(dict1, dict2):
141141
return dict1
142142

143143

144-
class BaseEstimator:
145-
"""Base class for all estimators in scikit-learn
146-
147-
Notes
148-
-----
149-
All estimators should specify all the parameters that can be set
150-
at the class level in their ``__init__`` as explicit keyword
151-
arguments (no ``*args`` or ``**kwargs``).
152-
"""
153-
144+
class GetSetParamsMixin:
154145
def _get_params_from(self, param_names, deep=True):
155-
"""Get parameters for this estimator from the given names"""
146+
"""Get parameters for this object from the given names"""
156147
out = dict()
157148
for key in param_names:
158149
value = getattr(self, key, None)
@@ -163,35 +154,26 @@ def _get_params_from(self, param_names, deep=True):
163154
return out
164155

165156
def get_params(self, deep=True):
166-
"""Get parameters for this estimator.
157+
"""Get parameters for this object.
167158
168159
Parameters
169160
----------
170161
deep : boolean, optional
171-
If True, will return the parameters for this estimator and
172-
contained subobjects that are estimators.
162+
If True, will return the parameters for this object and
163+
contained subobjects that implement ``get_params``.
173164
174165
Returns
175166
-------
176167
params : mapping of string to any
177168
Parameter names mapped to their values.
178169
"""
179-
try:
180-
param_names = get_param_names_from_constructor(self.__class__)
181-
except SignatureError as e:
182-
raise SignatureError("scikit-learn estimators should always "
183-
"specify their parameters in the signature"
184-
" of their __init__ (no varargs)."
185-
" %s with constructor %s doesn't "
186-
" follow this convention."
187-
% (self.__class__, e.signature))
188-
170+
param_names = get_param_names_from_constructor(self.__class__)
189171
return self._get_params_from(param_names, deep=deep)
190172

191173
def set_params(self, **params):
192-
"""Set the parameters of this estimator.
174+
"""Set the parameters of this object.
193175
194-
The method works on simple estimators as well as on nested objects
176+
The method works on simple objects as well as on nested objects
195177
(such as pipelines). The latter have parameters of the form
196178
``<component>__<parameter>`` so that it's possible to update each
197179
component of a nested object.
@@ -209,10 +191,7 @@ def set_params(self, **params):
209191
for key, value in params.items():
210192
key, delim, sub_key = key.partition('__')
211193
if key not in valid_params:
212-
raise ValueError('Invalid parameter %s for estimator %s. '
213-
'Check the list of available parameters '
214-
'with `estimator.get_params().keys()`.' %
215-
(key, self))
194+
raise InvalidParameterError(key)
216195

217196
if delim:
218197
nested_params[key][sub_key] = value
@@ -225,6 +204,61 @@ def set_params(self, **params):
225204

226205
return self
227206

207+
208+
class BaseEstimator(GetSetParamsMixin):
209+
"""Base class for all estimators in scikit-learn
210+
211+
Notes
212+
-----
213+
All estimators should specify all the parameters that can be set
214+
at the class level in their ``__init__`` as explicit keyword
215+
arguments (no ``*args`` or ``**kwargs``).
216+
"""
217+
218+
def get_params(self, deep=True):
219+
"""Get parameters for this estimator.
220+
221+
Parameters
222+
----------
223+
deep : boolean, optional
224+
If True, will return the parameters for this estimator and
225+
contained subobjects that are estimators.
226+
227+
Returns
228+
-------
229+
params : mapping of string to any
230+
Parameter names mapped to their values.
231+
"""
232+
try:
233+
return super().get_params(deep=deep)
234+
except SignatureError as e:
235+
raise SignatureError("scikit-learn estimators should always "
236+
"specify their parameters in the signature"
237+
" of their __init__ (no varargs)."
238+
" %s with constructor %s doesn't "
239+
" follow this convention."
240+
% (self.__class__, e.signature))
241+
242+
def set_params(self, **params):
243+
"""Set the parameters of this estimator.
244+
245+
The method works on simple estimators as well as on nested objects
246+
(such as pipelines). The latter have parameters of the form
247+
``<component>__<parameter>`` so that it's possible to update each
248+
component of a nested object.
249+
250+
Returns
251+
-------
252+
self
253+
"""
254+
try:
255+
return super().set_params(**params)
256+
except InvalidParameterError as e:
257+
raise InvalidParameterError('Invalid parameter %s for estimator %s. '
258+
'Check the list of available parameters '
259+
'with `estimator.get_params().keys()`.' %
260+
(e.param_name, self))
261+
228262
def __repr__(self, N_CHAR_MAX=700):
229263
# N_CHAR_MAX is the (approximate) maximum number of non-blank
230264
# characters to render. We pass it as an optional parameter to ease

‎sklearn/exceptions.py

Copy file name to clipboardExpand all lines: sklearn/exceptions.py
+9Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
__all__ = ['NotFittedError',
7+
'InvalidParameterError',
78
'SignatureError',
89
'ChangedBehaviorWarning',
910
'ConvergenceWarning',
@@ -38,6 +39,14 @@ class NotFittedError(ValueError, AttributeError):
3839
"""
3940

4041

42+
class InvalidParameterError(ValueError):
43+
""" TODO
44+
"""
45+
46+
def __init__(self, param_name):
47+
self.param_name = param_name
48+
49+
4150
class SignatureError(RuntimeError):
4251
""" TODO
4352
"""

‎sklearn/gaussian_process/kernels.py

Copy file name to clipboardExpand all lines: sklearn/gaussian_process/kernels.py
+13-40Lines changed: 13 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,9 @@
2727
from scipy.special import kv, gamma
2828
from scipy.spatial.distance import pdist, cdist, squareform
2929

30-
from ..base import clone
31-
from ..exceptions import SignatureError
30+
from ..base import clone, GetSetParamsMixin
31+
from ..exceptions import InvalidParameterError, SignatureError
3232
from ..metrics.pairwise import pairwise_kernels
33-
from ..utils import get_param_names_from_constructor
3433

3534

3635
def _check_length_scale(X, length_scale):
@@ -116,7 +115,7 @@ def __eq__(self, other):
116115
self.fixed == other.fixed)
117116

118117

119-
class Kernel(metaclass=ABCMeta):
118+
class Kernel(GetSetParamsMixin, metaclass=ABCMeta):
120119
"""Base class for all kernels.
121120
122121
.. versionadded:: 0.18
@@ -127,31 +126,23 @@ def get_params(self, deep=True):
127126
128127
Parameters
129128
----------
130-
deep : boolean, optional
131-
If True, will return the parameters for this estimator and
132-
contained subobjects that are estimators.
129+
deep : Ignored
133130
134131
Returns
135132
-------
136133
params : mapping of string to any
137134
Parameter names mapped to their values.
138135
"""
139-
params = dict()
140-
141-
# introspect the constructor arguments to find the model parameters
142-
# to represent
143136
try:
144-
param_names = get_param_names_from_constructor(self.__class__)
137+
# for Kernels, get_params is always shallow
138+
return super().get_params(deep=False)
145139
except SignatureError as e:
146140
raise SignatureError("scikit-learn kernels should always "
147141
"specify their parameters in the signature"
148142
" of their __init__ (no varargs)."
149143
" %s with constructor %s doesn't "
150144
" follow this convention."
151145
% (self.__class__, e.signature))
152-
for key in param_names:
153-
params[key] = getattr(self, key, None)
154-
return params
155146

156147
def set_params(self, **params):
157148
"""Set the parameters of this kernel.
@@ -164,31 +155,13 @@ def set_params(self, **params):
164155
-------
165156
self
166157
"""
167-
if not params:
168-
# Simple optimisation to gain speed (inspect is slow)
169-
return self
170-
valid_params = self.get_params(deep=True)
171-
for key, value in params.items():
172-
split = key.split('__', 1)
173-
if len(split) > 1:
174-
# nested objects case
175-
name, sub_name = split
176-
if name not in valid_params:
177-
raise ValueError('Invalid parameter %s for kernel %s. '
178-
'Check the list of available parameters '
179-
'with `kernel.get_params().keys()`.' %
180-
(name, self))
181-
sub_object = valid_params[name]
182-
sub_object.set_params(**{sub_name: value})
183-
else:
184-
# simple objects case
185-
if key not in valid_params:
186-
raise ValueError('Invalid parameter %s for kernel %s. '
187-
'Check the list of available parameters '
188-
'with `kernel.get_params().keys()`.' %
189-
(key, self.__class__.__name__))
190-
setattr(self, key, value)
191-
return self
158+
try:
159+
return super().set_params(**params)
160+
except InvalidParameterError as e:
161+
raise InvalidParameterError('Invalid parameter %s for kernel %s. '
162+
'Check the list of available parameters '
163+
'with `kernel.get_params().keys()`.' %
164+
(e.param_name, self))
192165

193166
def clone_with_theta(self, theta):
194167
"""Returns a clone of self with given hyperparameters theta.

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.