13
13
import numpy as np
14
14
15
15
from . import __version__
16
- from .exceptions import SignatureError
16
+ from .exceptions import InvalidParameterError , SignatureError
17
17
from .utils import _IS_32BIT , get_param_names_from_constructor
18
18
19
19
_DEFAULT_TAGS = {
@@ -141,18 +141,9 @@ def _update_if_consistent(dict1, dict2):
141
141
return dict1
142
142
143
143
144
- class BaseEstimator :
145
- """Base class for all estimators in scikit-learn
146
-
147
- Notes
148
- -----
149
- All estimators should specify all the parameters that can be set
150
- at the class level in their ``__init__`` as explicit keyword
151
- arguments (no ``*args`` or ``**kwargs``).
152
- """
153
-
144
+ class GetSetParamsMixin :
154
145
def _get_params_from (self , param_names , deep = True ):
155
- """Get parameters for this estimator from the given names"""
146
+ """Get parameters for this object from the given names"""
156
147
out = dict ()
157
148
for key in param_names :
158
149
value = getattr (self , key , None )
@@ -163,35 +154,26 @@ def _get_params_from(self, param_names, deep=True):
163
154
return out
164
155
165
156
def get_params (self , deep = True ):
166
- """Get parameters for this estimator .
157
+ """Get parameters for this object .
167
158
168
159
Parameters
169
160
----------
170
161
deep : boolean, optional
171
- If True, will return the parameters for this estimator and
172
- contained subobjects that are estimators .
162
+ If True, will return the parameters for this object and
163
+ contained subobjects that implement ``get_params`` .
173
164
174
165
Returns
175
166
-------
176
167
params : mapping of string to any
177
168
Parameter names mapped to their values.
178
169
"""
179
- try :
180
- param_names = get_param_names_from_constructor (self .__class__ )
181
- except SignatureError as e :
182
- raise SignatureError ("scikit-learn estimators should always "
183
- "specify their parameters in the signature"
184
- " of their __init__ (no varargs)."
185
- " %s with constructor %s doesn't "
186
- " follow this convention."
187
- % (self .__class__ , e .signature ))
188
-
170
+ param_names = get_param_names_from_constructor (self .__class__ )
189
171
return self ._get_params_from (param_names , deep = deep )
190
172
191
173
def set_params (self , ** params ):
192
- """Set the parameters of this estimator .
174
+ """Set the parameters of this object .
193
175
194
- The method works on simple estimators as well as on nested objects
176
+ The method works on simple objects as well as on nested objects
195
177
(such as pipelines). The latter have parameters of the form
196
178
``<component>__<parameter>`` so that it's possible to update each
197
179
component of a nested object.
@@ -209,10 +191,7 @@ def set_params(self, **params):
209
191
for key , value in params .items ():
210
192
key , delim , sub_key = key .partition ('__' )
211
193
if key not in valid_params :
212
- raise ValueError ('Invalid parameter %s for estimator %s. '
213
- 'Check the list of available parameters '
214
- 'with `estimator.get_params().keys()`.' %
215
- (key , self ))
194
+ raise InvalidParameterError (key )
216
195
217
196
if delim :
218
197
nested_params [key ][sub_key ] = value
@@ -225,6 +204,61 @@ def set_params(self, **params):
225
204
226
205
return self
227
206
207
+
208
+ class BaseEstimator (GetSetParamsMixin ):
209
+ """Base class for all estimators in scikit-learn
210
+
211
+ Notes
212
+ -----
213
+ All estimators should specify all the parameters that can be set
214
+ at the class level in their ``__init__`` as explicit keyword
215
+ arguments (no ``*args`` or ``**kwargs``).
216
+ """
217
+
218
+ def get_params (self , deep = True ):
219
+ """Get parameters for this estimator.
220
+
221
+ Parameters
222
+ ----------
223
+ deep : boolean, optional
224
+ If True, will return the parameters for this estimator and
225
+ contained subobjects that are estimators.
226
+
227
+ Returns
228
+ -------
229
+ params : mapping of string to any
230
+ Parameter names mapped to their values.
231
+ """
232
+ try :
233
+ return super ().get_params (deep = deep )
234
+ except SignatureError as e :
235
+ raise SignatureError ("scikit-learn estimators should always "
236
+ "specify their parameters in the signature"
237
+ " of their __init__ (no varargs)."
238
+ " %s with constructor %s doesn't "
239
+ " follow this convention."
240
+ % (self .__class__ , e .signature ))
241
+
242
+ def set_params (self , ** params ):
243
+ """Set the parameters of this estimator.
244
+
245
+ The method works on simple estimators as well as on nested objects
246
+ (such as pipelines). The latter have parameters of the form
247
+ ``<component>__<parameter>`` so that it's possible to update each
248
+ component of a nested object.
249
+
250
+ Returns
251
+ -------
252
+ self
253
+ """
254
+ try :
255
+ return super ().set_params (** params )
256
+ except InvalidParameterError as e :
257
+ raise InvalidParameterError ('Invalid parameter %s for estimator %s. '
258
+ 'Check the list of available parameters '
259
+ 'with `estimator.get_params().keys()`.' %
260
+ (e .param_name , self ))
261
+
228
262
def __repr__ (self , N_CHAR_MAX = 700 ):
229
263
# N_CHAR_MAX is the (approximate) maximum number of non-blank
230
264
# characters to render. We pass it as an optional parameter to ease
0 commit comments