Commit cc63e1c

ENH slight improvement of common tests.

1 parent 5b822a2 commit cc63e1c
4 files changed: 105 additions, 51 deletions
doc/developers/utilities.rst (3 additions, 0 deletions)

@@ -256,6 +256,9 @@ Testing Functions
 - :class:`mock_urllib2`: Object which mocks the urllib2 module to fake
   requests of mldata. Used in tests of :mod:`sklearn.datasets`.

+- :func:`testing.all_estimators` : returns a list of all estimators in
+  sklearn to test for consistent behavior and interfaces.
+

 Helper Functions
 ================
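
A minimal sketch of the helper the new bullet documents (the names printed are illustrative; the estimator inventory depends on the installed version):

    from sklearn.utils.testing import all_estimators

    # Returns sorted (class name, class object) pairs for every
    # non-abstract BaseEstimator subclass found in sklearn.
    for name, Est in all_estimators():
        print(name)  # e.g. 'ARDRegression', 'AffinityPropagation', ...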

sklearn/ensemble/base.py (4 additions, 0 deletions)

@@ -5,6 +5,8 @@
 # Authors: Gilles Louppe
 # License: BSD 3

+from abc import ABCMeta, abstractmethod
+
 from ..base import clone
 from ..base import BaseEstimator
 from ..base import MetaEstimatorMixin
@@ -28,7 +30,9 @@ class BaseEnsemble(BaseEstimator, MetaEstimatorMixin):
     The list of attributes to use as parameters when instantiating a
     new base estimator. If none are given, default parameters are used.
     """
+    __metaclass__ = ABCMeta

+    @abstractmethod
     def __init__(self, base_estimator, n_estimators=10,
                  estimator_params=tuple()):

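
The two added lines make BaseEnsemble abstract, so the estimator crawler can skip it via its `__abstractmethods__` attribute instead of a hand-maintained exclusion list. A self-contained sketch of the Python 2 idiom used here (illustrative class names, not scikit-learn's):

    from abc import ABCMeta, abstractmethod

    class Base(object):
        # Python 2 spelling; Python 3 would use `class Base(metaclass=ABCMeta)`
        __metaclass__ = ABCMeta

        @abstractmethod
        def __init__(self, base_estimator):
            self.base_estimator = base_estimator

    class Concrete(Base):
        def __init__(self, base_estimator):
            super(Concrete, self).__init__(base_estimator)

    # Base('clf')     -> TypeError: can't instantiate abstract class Base
    # Concrete('clf') -> fine, the abstract __init__ is overridden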

sklearn/tests/test_common.py (27 additions, 49 deletions)

@@ -20,6 +20,7 @@
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import assert_array_almost_equal
 from sklearn.utils.testing import all_estimators
+from sklearn.utils.testing import meta_estimators
 from sklearn.utils.testing import set_random_state
 from sklearn.utils.testing import assert_greater

@@ -34,14 +35,10 @@
 from sklearn.svm.base import BaseLibSVM

 # import "special" estimators
-from sklearn.grid_search import GridSearchCV
 from sklearn.decomposition import SparseCoder
-from sklearn.pipeline import Pipeline, FeatureUnion
 from sklearn.pls import _PLS, PLSCanonical, PLSRegression, CCA, PLSSVD
-from sklearn.ensemble import BaseEnsemble, RandomTreesEmbedding
-from sklearn.multiclass import (OneVsOneClassifier, OneVsRestClassifier,
-                                OutputCodeClassifier)
-from sklearn.feature_selection import RFE, RFECV, SelectKBest
+from sklearn.ensemble import RandomTreesEmbedding
+from sklearn.feature_selection import SelectKBest
 from sklearn.dummy import DummyClassifier, DummyRegressor
 from sklearn.naive_bayes import MultinomialNB, BernoulliNB
 from sklearn.covariance import EllipticEnvelope, EllipticEnvelop
@@ -56,19 +53,16 @@
 from sklearn.random_projection import (GaussianRandomProjection,
                                        SparseRandomProjection)

-dont_test = [Pipeline, FeatureUnion, GridSearchCV, SparseCoder,
-             EllipticEnvelope, EllipticEnvelop, DictVectorizer, LabelBinarizer,
-             LabelEncoder, TfidfTransformer, IsotonicRegression, OneHotEncoder,
-             RandomTreesEmbedding, FeatureHasher, DummyClassifier,
-             DummyRegressor]
-meta_estimators = [BaseEnsemble, OneVsOneClassifier, OutputCodeClassifier,
-                   OneVsRestClassifier, RFE, RFECV]
+dont_test = [SparseCoder, EllipticEnvelope, EllipticEnvelop, DictVectorizer,
+             LabelBinarizer, LabelEncoder, TfidfTransformer,
+             IsotonicRegression, OneHotEncoder, RandomTreesEmbedding,
+             FeatureHasher, DummyClassifier, DummyRegressor]


 def test_all_estimators():
     # Test that estimators are default-constructible, clonable
     # and have working repr.
-    estimators = all_estimators()
+    estimators = all_estimators(include_meta_estimators=True)
     clf = LDA()

     for name, E in estimators:
@@ -78,7 +72,7 @@ def test_all_estimators():
         # test default-constructibility
         # get rid of deprecation warnings
         with warnings.catch_warnings(record=True):
-            if E in meta_estimators:
+            if name in meta_estimators:
                 e = E(clf)
             else:
                 e = E()
@@ -101,7 +95,7 @@ def test_all_estimators():
            # true for mixins
            continue
        params = e.get_params()
-        if E in meta_estimators:
+        if name in meta_estimators:
            # they need a non-default argument
            args = args[2:]
        else:
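
meta_estimators now lives in sklearn.utils.testing and holds class names (strings) rather than class objects, so the test module no longer imports RFE, OneVsRestClassifier and friends just to recognize them. A condensed sketch of the construction loop above (a sketch, not the verbatim test; sklearn.lda was the LDA module path at the time):

    import warnings
    from sklearn.lda import LDA
    from sklearn.utils.testing import all_estimators, meta_estimators

    clf = LDA()
    with warnings.catch_warnings(record=True):  # silence deprecation warnings
        for name, E in all_estimators(include_meta_estimators=True):
            # meta-estimators need a base estimator as their first argument
            e = E(clf) if name in meta_estimators else E()
            repr(e)  # every estimator must have a working repr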
@@ -130,7 +124,7 @@ def test_estimators_sparse_data():
     estimators = [(name, E) for name, E in estimators
                   if issubclass(E, (ClassifierMixin, RegressorMixin))]
     for name, Clf in estimators:
-        if Clf in dont_test or Clf in meta_estimators:
+        if Clf in dont_test:
             continue
         # catch deprecation warnings
         with warnings.catch_warnings(record=True):
@@ -154,9 +148,7 @@ def test_estimators_sparse_data():
 def test_transformers():
     # test if transformers do something sensible on training set
     # also test all shapes / shape errors
-    estimators = all_estimators()
-    transformers = [(name, E) for name, E in estimators
-                    if issubclass(E, TransformerMixin)]
+    transformers = all_estimators(type_filter='transformer')
     X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]],
                       random_state=0, n_features=2, cluster_std=0.1)
     n_samples, n_features = X.shape
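
type_filter replaces the two-line mixin filter that each test previously spelled out by hand. The two forms below are intended to be equivalent under the defaults introduced in this commit (a sketch for illustration):

    from sklearn.base import TransformerMixin
    from sklearn.utils.testing import all_estimators

    # old style: crawl everything, then filter by mixin by hand
    by_hand = [(name, E) for name, E in all_estimators()
               if issubclass(E, TransformerMixin)]

    # new style: let all_estimators apply the filter
    transformers = all_estimators(type_filter='transformer')

    assert transformers == by_hand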
@@ -168,7 +160,7 @@ def test_transformers():
     for name, Trans in transformers:
         trans = None

-        if Trans in dont_test or Trans in meta_estimators:
+        if Trans in dont_test:
             continue
         # these don't actually fit the data:
         if Trans in [AdditiveChi2Sampler, Binarizer, Normalizer]:
@@ -244,11 +236,9 @@ def test_transformers_sparse_data():
     X[X < .8] = 0
     X = sparse.csr_matrix(X)
     y = (4 * rng.rand(40)).astype(np.int)
-    estimators = all_estimators()
-    estimators = [(name, E) for name, E in estimators
-                  if issubclass(E, TransformerMixin)]
+    estimators = all_estimators(type_filter='transformer')
     for name, Trans in estimators:
-        if Trans in dont_test or Trans in meta_estimators:
+        if Trans in dont_test:
             continue
         # catch deprecation warnings
         with warnings.catch_warnings(record=True):
@@ -302,7 +292,7 @@ def test_estimators_nan_inf():
                         " transform.")
     for X_train in [X_train_nan, X_train_inf]:
         for name, Est in estimators:
-            if Est in dont_test or Est in meta_estimators:
+            if Est in dont_test:
                 continue
             if Est in (_PLS, PLSCanonical, PLSRegression, CCA, PLSSVD):
                 continue
@@ -383,14 +373,12 @@ def test_classifiers_one_label():
     X_train = rnd.uniform(size=(10, 3))
     X_test = rnd.uniform(size=(10, 3))
     y = np.ones(10)
-    estimators = all_estimators()
-    classifiers = [(name, E) for name, E in estimators
-                   if issubclass(E, ClassifierMixin)]
+    classifiers = all_estimators(type_filter='classifier')
     error_string_fit = "Classifier can't train when only one class is present."
     error_string_predict = ("Classifier can't predict when only one class is "
                             "present.")
     for name, Clf in classifiers:
-        if Clf in dont_test or Clf in meta_estimators:
+        if Clf in dont_test:
             continue
         # catch deprecation warnings
         with warnings.catch_warnings(record=True):
@@ -420,9 +408,7 @@ def test_classifiers_one_label():
 def test_clustering():
     # test if clustering algorithms do something sensible
     # also test all shapes / shape errors
-    estimators = all_estimators()
-    clustering = [(name, E) for name, E in estimators
-                  if issubclass(E, ClusterMixin)]
+    clustering = all_estimators(type_filter='cluster')
     iris = load_iris()
     X, y = iris.data, iris.target
     X, y = shuffle(X, y, random_state=7)
@@ -460,9 +446,7 @@ def test_clustering():
 def test_classifiers_train():
     # test if classifiers do something sensible on training set
     # also test all shapes / shape errors
-    estimators = all_estimators()
-    classifiers = [(name, E) for name, E in estimators
-                   if issubclass(E, ClassifierMixin)]
+    classifiers = all_estimators(type_filter='classifier')
     X_m, y_m = make_blobs(random_state=0)
     X_m, y_m = shuffle(X_m, y_m, random_state=7)
     X_m = StandardScaler().fit_transform(X_m)
@@ -475,7 +459,7 @@ def test_classifiers_train():
     n_classes = len(classes)
     n_samples, n_features = X.shape
     for name, Clf in classifiers:
-        if Clf in dont_test or Clf in meta_estimators:
+        if Clf in dont_test:
             continue
         if Clf in [MultinomialNB, BernoulliNB]:
             # TODO also test these!
@@ -538,17 +522,15 @@ def test_classifiers_train():

 def test_classifiers_classes():
     # test if classifiers can cope with non-consecutive classes
-    estimators = all_estimators()
-    classifiers = [(name, E) for name, E in estimators
-                   if issubclass(E, ClassifierMixin)]
+    classifiers = all_estimators(type_filter='classifier')
     X, y = make_blobs(random_state=12345)
     X, y = shuffle(X, y, random_state=7)
     X = StandardScaler().fit_transform(X)
     y = 2 * y + 1
     # TODO: make work with next line :)
     #y = y.astype(np.str)
     for name, Clf in classifiers:
-        if Clf in dont_test or Clf in meta_estimators:
+        if Clf in dont_test:
             continue
         if Clf in [MultinomialNB, BernoulliNB]:
             # TODO also test these!
@@ -569,16 +551,14 @@ def test_classifiers_classes():
 def test_regressors_int():
     # test if regressors can cope with integer labels (by converting them to
     # float)
-    estimators = all_estimators()
-    regressors = [(name, E) for name, E in estimators
-                  if issubclass(E, RegressorMixin)]
+    regressors = all_estimators(type_filter='regressor')
     boston = load_boston()
     X, y = boston.data, boston.target
     X, y = shuffle(X, y, random_state=0)
     X = StandardScaler().fit_transform(X)
     y = np.random.randint(2, size=X.shape[0])
     for name, Reg in regressors:
-        if Reg in dont_test or Reg in meta_estimators or Reg in (CCA,):
+        if Reg in dont_test or Reg in (CCA,):
             continue
         # catch deprecation warnings
         with warnings.catch_warnings(record=True):
@@ -603,9 +583,7 @@ def test_regressors_int():


 def test_regressors_train():
-    estimators = all_estimators()
-    regressors = [(name, E) for name, E in estimators
-                  if issubclass(E, RegressorMixin)]
+    regressors = all_estimators(type_filter='regressor')
     boston = load_boston()
     X, y = boston.data, boston.target
     X, y = shuffle(X, y, random_state=0)
@@ -615,7 +593,7 @@ def test_regressors_train():
     y = StandardScaler().fit_transform(y)
     succeeded = True
     for name, Reg in regressors:
-        if Reg in dont_test or Reg in meta_estimators:
+        if Reg in dont_test:
             continue
         # catch deprecation warnings
         with warnings.catch_warnings(record=True):

sklearn/utils/testing.py (71 additions, 2 deletions)

@@ -34,6 +34,9 @@
 from numpy.testing import assert_array_almost_equal
 from numpy.testing import assert_array_less

+from sklearn.base import (ClassifierMixin, RegressorMixin, TransformerMixin,
+                          ClusterMixin)
+
 __all__ = ["assert_equal", "assert_not_equal", "assert_raises", "raises",
            "with_setup", "assert_true", "assert_false", "assert_almost_equal",
            "assert_array_equal", "assert_array_almost_equal",
@@ -160,8 +163,48 @@ def urlopen(self, urlname):
     def quote(self, string, safe='/'):
         return urllib2.quote(string, safe)

+# Meta estimators need another estimator to be instantiated.
+meta_estimators = ["OneVsOneClassifier",
+                   "OutputCodeClassifier", "OneVsRestClassifier", "RFE",
+                   "RFECV"]
+# estimators that there is no way to default-construct sensibly
+other = ["Pipeline", "FeatureUnion", "GridSearchCV"]
+
+
+def all_estimators(include_meta_estimators=False, include_other=False,
+                   type_filter=None):
+    """Get a list of all estimators from sklearn.

-def all_estimators():
+    This function crawls the module and gets all classes that inherit
+    from BaseEstimator. Classes that are defined in test-modules are not
+    included.
+    By default meta_estimators such as GridSearchCV are also not included.
+
+    Parameters
+    ----------
+    include_meta_estimators : boolean, default=False
+        Whether to include meta-estimators that can be constructed using
+        an estimator as their first argument. These are currently
+        OneVsOneClassifier, OutputCodeClassifier, OneVsRestClassifier, RFE,
+        RFECV.
+
+    include_other : boolean, default=False
+        Whether to include meta-estimators that are somehow special and can
+        not be default-constructed sensibly. These are currently
+        Pipeline, FeatureUnion and GridSearchCV.
+
+    type_filter : string or None, default=None
+        Which kind of estimators should be returned. If None, no filter is
+        applied and all estimators are returned. Possible values are
+        'classifier', 'regressor', 'cluster' and 'transformer' to get
+        estimators only of these specific types.
+
+    Returns
+    -------
+    estimators : list of tuples
+        List of (name, class), where ``name`` is the class name as string
+        and ``class`` is the actual type of the class.
+    """
     def is_abstract(c):
         if not(hasattr(c, '__abstractmethods__')):
             return False
@@ -182,9 +225,35 @@ def is_abstract(c):

     all_classes = set(all_classes)

-    estimators = [c for c in all_classes if issubclass(c[1], BaseEstimator)]
+    estimators = [c for c in all_classes
+                  if (issubclass(c[1], BaseEstimator)
+                      and c[0] != 'BaseEstimator')]
     # get rid of abstract base classes
     estimators = [c for c in estimators if not is_abstract(c[1])]
+
+    if not include_other:
+        estimators = [c for c in estimators if not c[0] in other]
+    # possibly get rid of meta estimators
+    if not include_meta_estimators:
+        estimators = [c for c in estimators if not c[0] in meta_estimators]
+
+    if type_filter == 'classifier':
+        estimators = [est for est in estimators
+                      if issubclass(est[1], ClassifierMixin)]
+    elif type_filter == 'regressor':
+        estimators = [est for est in estimators
+                      if issubclass(est[1], RegressorMixin)]
+    elif type_filter == 'transformer':
+        estimators = [est for est in estimators
+                      if issubclass(est[1], TransformerMixin)]
+    elif type_filter == 'cluster':
+        estimators = [est for est in estimators
+                      if issubclass(est[1], ClusterMixin)]
+    elif type_filter is not None:
+        raise ValueError("Parameter type_filter must be 'classifier', "
+                         "'regressor', 'transformer', 'cluster' or None, got"
+                         " %s." % repr(type_filter))
+
     # We sort in order to have reproducible test failures
     return sorted(estimators)

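A usage sketch of the full signature defined above (estimator counts and names vary by sklearn version; illustrative only):

    from sklearn.utils.testing import all_estimators

    # Default: meta-estimators (RFE, OneVsRestClassifier, ...) and special
    # cases (Pipeline, FeatureUnion, GridSearchCV) are excluded.
    default = all_estimators()
    with_meta = all_estimators(include_meta_estimators=True)
    with_other = all_estimators(include_other=True)
    assert len(with_meta) >= len(default)

    # Keep only classes inheriting from RegressorMixin:
    for name, Reg in all_estimators(type_filter='regressor'):
        print(name)

    # Any other filter string hits the ValueError branch above:
    try:
        all_estimators(type_filter='density')
    except ValueError as e:
        print(e)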