3
3
from joblib import Parallel
4
4
5
5
from ..metrics import check_scoring
6
+ from ..metrics ._scorer import _check_multimetric_scoring , _MultimetricScorer
7
+ from ..model_selection ._validation import _aggregate_score_dicts
6
8
from ..utils import Bunch
7
9
from ..utils import check_random_state
8
10
from ..utils import check_array
@@ -28,24 +30,56 @@ def _calculate_permutation_scores(estimator, X, y, sample_weight, col_idx,
28
30
# (memmap). X.copy() on the other hand is always guaranteed to return a
29
31
# writable data-structure whose columns can be shuffled inplace.
30
32
X_permuted = X .copy ()
31
- scores = np .zeros (n_repeats )
33
+
34
+ scores = []
32
35
shuffling_idx = np .arange (X .shape [0 ])
33
- for n_round in range (n_repeats ):
36
+ for _ in range (n_repeats ):
34
37
random_state .shuffle (shuffling_idx )
35
38
if hasattr (X_permuted , "iloc" ):
36
39
col = X_permuted .iloc [shuffling_idx , col_idx ]
37
40
col .index = X_permuted .index
38
41
X_permuted .iloc [:, col_idx ] = col
39
42
else :
40
43
X_permuted [:, col_idx ] = X_permuted [shuffling_idx , col_idx ]
41
- feature_score = _weights_scorer (
42
- scorer , estimator , X_permuted , y , sample_weight
44
+ scores . append (
45
+ _weights_scorer ( scorer , estimator , X_permuted , y , sample_weight )
43
46
)
44
- scores [n_round ] = feature_score
47
+
48
+ if isinstance (scores [0 ], dict ):
49
+ scores = _aggregate_score_dicts (scores )
50
+ else :
51
+ scores = np .array (scores )
45
52
46
53
return scores
47
54
48
55
56
def _create_importances_bunch(baseline_score, permuted_score):
    """Compute the importances as the decrease in score.

    Parameters
    ----------
    baseline_score : float
        The baseline score without permutation. A single scalar: the call
        site passes one metric's unpermuted score, and a scalar is the only
        shape that broadcasts against `permuted_score` below.
    permuted_score : ndarray of shape (n_features, n_repeats)
        The permuted scores for the `n` repetitions.

    Returns
    -------
    importances : :class:`~sklearn.utils.Bunch`
        Dictionary-like object, with the following attributes.
        importances_mean : ndarray, shape (n_features, )
            Mean of feature importance over `n_repeats`.
        importances_std : ndarray, shape (n_features, )
            Standard deviation over `n_repeats`.
        importances : ndarray, shape (n_features, n_repeats)
            Raw permutation importance scores.
    """
    # A permutation that degrades the score yields a positive importance.
    importances = baseline_score - permuted_score
    return Bunch(importances_mean=np.mean(importances, axis=1),
                 importances_std=np.std(importances, axis=1),
                 importances=importances)
81
+
82
+
49
83
@_deprecate_positional_args
50
84
def permutation_importance (estimator , X , y , * , scoring = None , n_repeats = 5 ,
51
85
n_jobs = None , random_state = None , sample_weight = None ):
@@ -74,10 +108,25 @@ def permutation_importance(estimator, X, y, *, scoring=None, n_repeats=5,
74
108
y : array-like or None, shape (n_samples, ) or (n_samples, n_classes)
75
109
Targets for supervised or `None` for unsupervised.
76
110
77
- scoring : string, callable or None, default=None
78
- Scorer to use. It can be a single
79
- string (see :ref:`scoring_parameter`) or a callable (see
80
- :ref:`scoring`). If None, the estimator's default scorer is used.
111
+ scoring : str, callable, list, tuple, or dict, default=None
112
+ Scorer to use.
113
+ If `scoring` represents a single score, one can use:
114
+
115
+ - a single string (see :ref:`scoring_parameter`);
116
+ - a callable (see :ref:`scoring`) that returns a single value.
117
+
118
+ If `scoring` represents multiple scores, one can use:
119
+
120
+ - a list or tuple of unique strings;
121
+ - a callable returning a dictionary where the keys are the metric
122
+ names and the values are the metric scores;
123
+ - a dictionary with metric names as keys and callables as values.
124
+
125
+ Passing multiple scores to `scoring` is more efficient than calling
126
+ `permutation_importance` for each of the scores as it reuses
127
+ predictions to avoid redundant computation.
128
+
129
+ If None, the estimator's default scorer is used.
81
130
82
131
n_repeats : int, default=5
83
132
Number of times to permute a feature.
@@ -102,16 +151,20 @@ def permutation_importance(estimator, X, y, *, scoring=None, n_repeats=5,
102
151
103
152
Returns
104
153
-------
105
- result : :class:`~sklearn.utils.Bunch`
154
+ result : :class:`~sklearn.utils.Bunch` or dict of such instances
106
155
Dictionary-like object, with the following attributes.
107
156
108
- importances_mean : ndarray, shape (n_features, )
157
+ importances_mean : ndarray of shape (n_features, )
109
158
Mean of feature importance over `n_repeats`.
110
- importances_std : ndarray, shape (n_features, )
159
+ importances_std : ndarray of shape (n_features, )
111
160
Standard deviation over `n_repeats`.
112
- importances : ndarray, shape (n_features, n_repeats)
161
+ importances : ndarray of shape (n_features, n_repeats)
113
162
Raw permutation importance scores.
114
163
164
+ If there are multiple scoring metrics in the scoring parameter
165
+ `result` is a dict with scorer names as keys (e.g. 'roc_auc') and
166
+ `Bunch` objects like above as values.
167
+
115
168
References
116
169
----------
117
170
.. [BRE] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32,
@@ -143,14 +196,33 @@ def permutation_importance(estimator, X, y, *, scoring=None, n_repeats=5,
143
196
random_state = check_random_state (random_state )
144
197
random_seed = random_state .randint (np .iinfo (np .int32 ).max + 1 )
145
198
146
- scorer = check_scoring (estimator , scoring = scoring )
147
- baseline_score = _weights_scorer (scorer , estimator , X , y , sample_weight )
148
-
149
- scores = Parallel (n_jobs = n_jobs )(delayed (_calculate_permutation_scores )(
150
- estimator , X , y , sample_weight , col_idx , random_seed , n_repeats , scorer
151
- ) for col_idx in range (X .shape [1 ]))
152
-
153
- importances = baseline_score - np .array (scores )
154
- return Bunch (importances_mean = np .mean (importances , axis = 1 ),
155
- importances_std = np .std (importances , axis = 1 ),
156
- importances = importances )
199
+ if callable (scoring ):
200
+ scorer = scoring
201
+ elif scoring is None or isinstance (scoring , str ):
202
+ scorer = check_scoring (estimator , scoring = scoring )
203
+ else :
204
+ scorers_dict = _check_multimetric_scoring (estimator , scoring )
205
+ scorer = _MultimetricScorer (** scorers_dict )
206
+
207
+ baseline_score = _weights_scorer (scorer , estimator , X , y ,
208
+ sample_weight )
209
+
210
+ scores = Parallel (n_jobs = n_jobs )(
211
+ delayed (_calculate_permutation_scores )(
212
+ estimator , X , y , sample_weight , col_idx , random_seed ,
213
+ n_repeats , scorer
214
+ ) for col_idx in range (X .shape [1 ]))
215
+
216
+ if isinstance (baseline_score , dict ):
217
+ return {
218
+ name : _create_importances_bunch (
219
+ baseline_score [name ],
220
+ # unpack the permuted scores
221
+ np .array ([
222
+ scores [col_idx ][name ] for col_idx in range (X .shape [1 ])
223
+ ])
224
+ )
225
+ for name in baseline_score
226
+ }
227
+ else :
228
+ return _create_importances_bunch (baseline_score , np .array (scores ))
0 commit comments