20
20
from sklearn .utils .testing import assert_array_equal
21
21
from sklearn .utils .testing import assert_array_almost_equal
22
22
from sklearn .utils .testing import all_estimators
23
+ from sklearn .utils .testing import meta_estimators
23
24
from sklearn .utils .testing import set_random_state
24
25
from sklearn .utils .testing import assert_greater
25
26
34
35
from sklearn .svm .base import BaseLibSVM
35
36
36
37
# import "special" estimators
37
- from sklearn .grid_search import GridSearchCV
38
38
from sklearn .decomposition import SparseCoder
39
- from sklearn .pipeline import Pipeline , FeatureUnion
40
39
from sklearn .pls import _PLS , PLSCanonical , PLSRegression , CCA , PLSSVD
41
- from sklearn .ensemble import BaseEnsemble , RandomTreesEmbedding
42
- from sklearn .multiclass import (OneVsOneClassifier , OneVsRestClassifier ,
43
- OutputCodeClassifier )
44
- from sklearn .feature_selection import RFE , RFECV , SelectKBest
40
+ from sklearn .ensemble import RandomTreesEmbedding
41
+ from sklearn .feature_selection import SelectKBest
45
42
from sklearn .dummy import DummyClassifier , DummyRegressor
46
43
from sklearn .naive_bayes import MultinomialNB , BernoulliNB
47
44
from sklearn .covariance import EllipticEnvelope , EllipticEnvelop
56
53
from sklearn .random_projection import (GaussianRandomProjection ,
57
54
SparseRandomProjection )
58
55
59
- dont_test = [Pipeline , FeatureUnion , GridSearchCV , SparseCoder ,
60
- EllipticEnvelope , EllipticEnvelop , DictVectorizer , LabelBinarizer ,
61
- LabelEncoder , TfidfTransformer , IsotonicRegression , OneHotEncoder ,
62
- RandomTreesEmbedding , FeatureHasher , DummyClassifier ,
63
- DummyRegressor ]
64
- meta_estimators = [BaseEnsemble , OneVsOneClassifier , OutputCodeClassifier ,
65
- OneVsRestClassifier , RFE , RFECV ]
56
+ dont_test = [SparseCoder , EllipticEnvelope , EllipticEnvelop , DictVectorizer ,
57
+ LabelBinarizer , LabelEncoder , TfidfTransformer ,
58
+ IsotonicRegression , OneHotEncoder , RandomTreesEmbedding ,
59
+ FeatureHasher , DummyClassifier , DummyRegressor ]
66
60
67
61
68
62
def test_all_estimators ():
69
63
# Test that estimators are default-constructible, clonable
70
64
# and have working repr.
71
- estimators = all_estimators ()
65
+ estimators = all_estimators (include_meta_estimators = True )
72
66
clf = LDA ()
73
67
74
68
for name , E in estimators :
@@ -78,7 +72,7 @@ def test_all_estimators():
78
72
# test default-constructibility
79
73
# get rid of deprecation warnings
80
74
with warnings .catch_warnings (record = True ):
81
- if E in meta_estimators :
75
+ if name in meta_estimators :
82
76
e = E (clf )
83
77
else :
84
78
e = E ()
@@ -101,7 +95,7 @@ def test_all_estimators():
101
95
# true for mixins
102
96
continue
103
97
params = e .get_params ()
104
- if E in meta_estimators :
98
+ if name in meta_estimators :
105
99
# they need a non-default argument
106
100
args = args [2 :]
107
101
else :
@@ -130,7 +124,7 @@ def test_estimators_sparse_data():
130
124
estimators = [(name , E ) for name , E in estimators
131
125
if issubclass (E , (ClassifierMixin , RegressorMixin ))]
132
126
for name , Clf in estimators :
133
- if Clf in dont_test or Clf in meta_estimators :
127
+ if Clf in dont_test :
134
128
continue
135
129
# catch deprecation warnings
136
130
with warnings .catch_warnings (record = True ):
@@ -154,9 +148,7 @@ def test_estimators_sparse_data():
154
148
def test_transformers ():
155
149
# test if transformers do something sensible on training set
156
150
# also test all shapes / shape errors
157
- estimators = all_estimators ()
158
- transformers = [(name , E ) for name , E in estimators
159
- if issubclass (E , TransformerMixin )]
151
+ transformers = all_estimators (type_filter = 'transformer' )
160
152
X , y = make_blobs (n_samples = 30 , centers = [[0 , 0 , 0 ], [1 , 1 , 1 ]],
161
153
random_state = 0 , n_features = 2 , cluster_std = 0.1 )
162
154
n_samples , n_features = X .shape
@@ -168,7 +160,7 @@ def test_transformers():
168
160
for name , Trans in transformers :
169
161
trans = None
170
162
171
- if Trans in dont_test or Trans in meta_estimators :
163
+ if Trans in dont_test :
172
164
continue
173
165
# these don't actually fit the data:
174
166
if Trans in [AdditiveChi2Sampler , Binarizer , Normalizer ]:
@@ -244,11 +236,9 @@ def test_transformers_sparse_data():
244
236
X [X < .8 ] = 0
245
237
X = sparse .csr_matrix (X )
246
238
y = (4 * rng .rand (40 )).astype (np .int )
247
- estimators = all_estimators ()
248
- estimators = [(name , E ) for name , E in estimators
249
- if issubclass (E , TransformerMixin )]
239
+ estimators = all_estimators (type_filter = 'transformer' )
250
240
for name , Trans in estimators :
251
- if Trans in dont_test or Trans in meta_estimators :
241
+ if Trans in dont_test :
252
242
continue
253
243
# catch deprecation warnings
254
244
with warnings .catch_warnings (record = True ):
@@ -302,7 +292,7 @@ def test_estimators_nan_inf():
302
292
" transform." )
303
293
for X_train in [X_train_nan , X_train_inf ]:
304
294
for name , Est in estimators :
305
- if Est in dont_test or Est in meta_estimators :
295
+ if Est in dont_test :
306
296
continue
307
297
if Est in (_PLS , PLSCanonical , PLSRegression , CCA , PLSSVD ):
308
298
continue
@@ -383,14 +373,12 @@ def test_classifiers_one_label():
383
373
X_train = rnd .uniform (size = (10 , 3 ))
384
374
X_test = rnd .uniform (size = (10 , 3 ))
385
375
y = np .ones (10 )
386
- estimators = all_estimators ()
387
- classifiers = [(name , E ) for name , E in estimators
388
- if issubclass (E , ClassifierMixin )]
376
+ classifiers = all_estimators (type_filter = 'classifier' )
389
377
error_string_fit = "Classifier can't train when only one class is present."
390
378
error_string_predict = ("Classifier can't predict when only one class is "
391
379
"present." )
392
380
for name , Clf in classifiers :
393
- if Clf in dont_test or Clf in meta_estimators :
381
+ if Clf in dont_test :
394
382
continue
395
383
# catch deprecation warnings
396
384
with warnings .catch_warnings (record = True ):
@@ -420,9 +408,7 @@ def test_classifiers_one_label():
420
408
def test_clustering ():
421
409
# test if clustering algorithms do something sensible
422
410
# also test all shapes / shape errors
423
- estimators = all_estimators ()
424
- clustering = [(name , E ) for name , E in estimators
425
- if issubclass (E , ClusterMixin )]
411
+ clustering = all_estimators (type_filter = 'cluster' )
426
412
iris = load_iris ()
427
413
X , y = iris .data , iris .target
428
414
X , y = shuffle (X , y , random_state = 7 )
@@ -460,9 +446,7 @@ def test_clustering():
460
446
def test_classifiers_train ():
461
447
# test if classifiers do something sensible on training set
462
448
# also test all shapes / shape errors
463
- estimators = all_estimators ()
464
- classifiers = [(name , E ) for name , E in estimators
465
- if issubclass (E , ClassifierMixin )]
449
+ classifiers = all_estimators (type_filter = 'classifier' )
466
450
X_m , y_m = make_blobs (random_state = 0 )
467
451
X_m , y_m = shuffle (X_m , y_m , random_state = 7 )
468
452
X_m = StandardScaler ().fit_transform (X_m )
@@ -475,7 +459,7 @@ def test_classifiers_train():
475
459
n_classes = len (classes )
476
460
n_samples , n_features = X .shape
477
461
for name , Clf in classifiers :
478
- if Clf in dont_test or Clf in meta_estimators :
462
+ if Clf in dont_test :
479
463
continue
480
464
if Clf in [MultinomialNB , BernoulliNB ]:
481
465
# TODO also test these!
@@ -538,17 +522,15 @@ def test_classifiers_train():
538
522
539
523
def test_classifiers_classes ():
540
524
# test if classifiers can cope with non-consecutive classes
541
- estimators = all_estimators ()
542
- classifiers = [(name , E ) for name , E in estimators
543
- if issubclass (E , ClassifierMixin )]
525
+ classifiers = all_estimators (type_filter = 'classifier' )
544
526
X , y = make_blobs (random_state = 12345 )
545
527
X , y = shuffle (X , y , random_state = 7 )
546
528
X = StandardScaler ().fit_transform (X )
547
529
y = 2 * y + 1
548
530
# TODO: make work with next line :)
549
531
#y = y.astype(np.str)
550
532
for name , Clf in classifiers :
551
- if Clf in dont_test or Clf in meta_estimators :
533
+ if Clf in dont_test :
552
534
continue
553
535
if Clf in [MultinomialNB , BernoulliNB ]:
554
536
# TODO also test these!
@@ -569,16 +551,14 @@ def test_classifiers_classes():
569
551
def test_regressors_int ():
570
552
# test if regressors can cope with integer labels (by converting them to
571
553
# float)
572
- estimators = all_estimators ()
573
- regressors = [(name , E ) for name , E in estimators
574
- if issubclass (E , RegressorMixin )]
554
+ regressors = all_estimators (type_filter = 'regressor' )
575
555
boston = load_boston ()
576
556
X , y = boston .data , boston .target
577
557
X , y = shuffle (X , y , random_state = 0 )
578
558
X = StandardScaler ().fit_transform (X )
579
559
y = np .random .randint (2 , size = X .shape [0 ])
580
560
for name , Reg in regressors :
581
- if Reg in dont_test or Reg in meta_estimators or Reg in (CCA ,):
561
+ if Reg in dont_test or Reg in (CCA ,):
582
562
continue
583
563
# catch deprecation warnings
584
564
with warnings .catch_warnings (record = True ):
@@ -603,9 +583,7 @@ def test_regressors_int():
603
583
604
584
605
585
def test_regressors_train ():
606
- estimators = all_estimators ()
607
- regressors = [(name , E ) for name , E in estimators
608
- if issubclass (E , RegressorMixin )]
586
+ regressors = all_estimators (type_filter = 'regressor' )
609
587
boston = load_boston ()
610
588
X , y = boston .data , boston .target
611
589
X , y = shuffle (X , y , random_state = 0 )
@@ -615,7 +593,7 @@ def test_regressors_train():
615
593
y = StandardScaler ().fit_transform (y )
616
594
succeeded = True
617
595
for name , Reg in regressors :
618
- if Reg in dont_test or Reg in meta_estimators :
596
+ if Reg in dont_test :
619
597
continue
620
598
# catch deprecation warnings
621
599
with warnings .catch_warnings (record = True ):
0 commit comments