Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit f58d1eb

Browse filesBrowse files
simonamaggioglemaitreogrisel
authored
ENH Allow multiple scorers input to permutation_importance (#19411)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 4aff385 commit f58d1eb
Copy full SHA for f58d1eb

File tree

4 files changed

+203
-33
lines changed
Filter options

4 files changed

+203
-33
lines changed

‎doc/modules/permutation_importance.rst

Copy file name to clipboardExpand all lines: doc/modules/permutation_importance.rst
+52-8Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,16 @@ indicative of how much the model depends on the feature. This technique
1616
benefits from being model agnostic and can be calculated many times with
1717
different permutations of the feature.
1818

19+
.. warning::
20+
21+
Features that are deemed of **low importance for a bad model** (low
22+
cross-validation score) could be **very important for a good model**.
23+
Therefore it is always important to evaluate the predictive power of a model
24+
using a held-out set (or better with cross-validation) prior to computing
25+
importances. Permutation importance does not reflect to the intrinsic
26+
predictive value of a feature by itself but **how important this feature is
27+
for a particular model**.
28+
1929
The :func:`permutation_importance` function calculates the feature importance
2030
of :term:`estimators` for a given dataset. The ``n_repeats`` parameter sets the
2131
number of times a feature is randomly shuffled and returns a sample of feature
@@ -64,15 +74,49 @@ highlight which features contribute the most to the generalization power of the
6474
inspected model. Features that are important on the training set but not on the
6575
held-out set might cause the model to overfit.
6676

67-
.. warning::
77+
The permutation feature importance is the decrease in a model score when a single
78+
feature value is randomly shuffled. The score function to be used for the
79+
computation of importances can be specified with the `scoring` argument,
80+
which also accepts multiple scorers. Using multiple scorers is more computationally
81+
efficient than sequentially calling :func:`permutation_importance` several times
82+
with a different scorer, as it reuses model predictions.
6883

69-
Features that are deemed of **low importance for a bad model** (low
70-
cross-validation score) could be **very important for a good model**.
71-
Therefore it is always important to evaluate the predictive power of a model
72-
using a held-out set (or better with cross-validation) prior to computing
73-
importances. Permutation importance does not reflect to the intrinsic
74-
predictive value of a feature by itself but **how important this feature is
75-
for a particular model**.
84+
An example of using multiple scorers is shown below, employing a list of metrics,
85+
but more input formats are possible, as documented in :ref:`multimetric_scoring`.
86+
87+
>>> scoring = ['r2', 'neg_mean_absolute_percentage_error', 'neg_mean_squared_error']
88+
>>> r_multi = permutation_importance(
89+
... model, X_val, y_val, n_repeats=30, random_state=0, scoring=scoring)
90+
...
91+
>>> for metric in r_multi:
92+
... print(f"{metric}")
93+
... r = r_multi[metric]
94+
... for i in r.importances_mean.argsort()[::-1]:
95+
... if r.importances_mean[i] - 2 * r.importances_std[i] > 0:
96+
... print(f" {diabetes.feature_names[i]:<8}"
97+
... f"{r.importances_mean[i]:.3f}"
98+
... f" +/- {r.importances_std[i]:.3f}")
99+
...
100+
r2
101+
s5 0.204 +/- 0.050
102+
bmi 0.176 +/- 0.048
103+
bp 0.088 +/- 0.033
104+
sex 0.056 +/- 0.023
105+
neg_mean_absolute_percentage_error
106+
s5 0.081 +/- 0.020
107+
bmi 0.064 +/- 0.015
108+
bp 0.029 +/- 0.010
109+
neg_mean_squared_error
110+
s5 1013.903 +/- 246.460
111+
bmi 872.694 +/- 240.296
112+
bp 438.681 +/- 163.025
113+
sex 277.382 +/- 115.126
114+
115+
The ranking of the features is approximately the same for different metrics even
116+
if the scales of the importance values are very different. However, this is not
117+
guaranteed and different metrics might lead to significantly different feature
118+
importances, in particular for models trained for imbalanced classification problems,
119+
for which the choice of the classification metric can be critical.
76120

77121
Outline of the permutation importance algorithm
78122
-----------------------------------------------

‎doc/whats_new/v1.0.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.0.rst
+7Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,13 @@ Changelog
102102
input strings would result in negative indices in the transformed data.
103103
:pr:`19035` by :user:`Liu Yu <ly648499246>`.
104104

105+
:mod:`sklearn.inspection`
106+
.........................
107+
108+
- |Fix| Allow multiple scorers input to
109+
:func:`~sklearn.inspection.permutation_importance`.
110+
:pr:`19411` by :user:`Simona Maggio <simonamaggio>`.
111+
105112
:mod:`sklearn.linear_model`
106113
...........................
107114

‎sklearn/inspection/_permutation_importance.py

Copy file name to clipboardExpand all lines: sklearn/inspection/_permutation_importance.py
+96-24Lines changed: 96 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from joblib import Parallel
44

55
from ..metrics import check_scoring
6+
from ..metrics._scorer import _check_multimetric_scoring, _MultimetricScorer
7+
from ..model_selection._validation import _aggregate_score_dicts
68
from ..utils import Bunch
79
from ..utils import check_random_state
810
from ..utils import check_array
@@ -28,24 +30,56 @@ def _calculate_permutation_scores(estimator, X, y, sample_weight, col_idx,
2830
# (memmap). X.copy() on the other hand is always guaranteed to return a
2931
# writable data-structure whose columns can be shuffled inplace.
3032
X_permuted = X.copy()
31-
scores = np.zeros(n_repeats)
33+
34+
scores = []
3235
shuffling_idx = np.arange(X.shape[0])
33-
for n_round in range(n_repeats):
36+
for _ in range(n_repeats):
3437
random_state.shuffle(shuffling_idx)
3538
if hasattr(X_permuted, "iloc"):
3639
col = X_permuted.iloc[shuffling_idx, col_idx]
3740
col.index = X_permuted.index
3841
X_permuted.iloc[:, col_idx] = col
3942
else:
4043
X_permuted[:, col_idx] = X_permuted[shuffling_idx, col_idx]
41-
feature_score = _weights_scorer(
42-
scorer, estimator, X_permuted, y, sample_weight
44+
scores.append(
45+
_weights_scorer(scorer, estimator, X_permuted, y, sample_weight)
4346
)
44-
scores[n_round] = feature_score
47+
48+
if isinstance(scores[0], dict):
49+
scores = _aggregate_score_dicts(scores)
50+
else:
51+
scores = np.array(scores)
4552

4653
return scores
4754

4855

56+
def _create_importances_bunch(baseline_score, permuted_score):
57+
"""Compute the importances as the decrease in score.
58+
59+
Parameters
60+
----------
61+
baseline_score : ndarray of shape (n_features,)
62+
The baseline score without permutation.
63+
permuted_score : ndarray of shape (n_features, n_repeats)
64+
The permuted scores for the `n` repetitions.
65+
66+
Returns
67+
-------
68+
importances : :class:`~sklearn.utils.Bunch`
69+
Dictionary-like object, with the following attributes.
70+
importances_mean : ndarray, shape (n_features, )
71+
Mean of feature importance over `n_repeats`.
72+
importances_std : ndarray, shape (n_features, )
73+
Standard deviation over `n_repeats`.
74+
importances : ndarray, shape (n_features, n_repeats)
75+
Raw permutation importance scores.
76+
"""
77+
importances = baseline_score - permuted_score
78+
return Bunch(importances_mean=np.mean(importances, axis=1),
79+
importances_std=np.std(importances, axis=1),
80+
importances=importances)
81+
82+
4983
@_deprecate_positional_args
5084
def permutation_importance(estimator, X, y, *, scoring=None, n_repeats=5,
5185
n_jobs=None, random_state=None, sample_weight=None):
@@ -74,10 +108,25 @@ def permutation_importance(estimator, X, y, *, scoring=None, n_repeats=5,
74108
y : array-like or None, shape (n_samples, ) or (n_samples, n_classes)
75109
Targets for supervised or `None` for unsupervised.
76110
77-
scoring : string, callable or None, default=None
78-
Scorer to use. It can be a single
79-
string (see :ref:`scoring_parameter`) or a callable (see
80-
:ref:`scoring`). If None, the estimator's default scorer is used.
111+
scoring : str, callable, list, tuple, or dict, default=None
112+
Scorer to use.
113+
If `scoring` represents a single score, one can use:
114+
115+
- a single string (see :ref:`scoring_parameter`);
116+
- a callable (see :ref:`scoring`) that returns a single value.
117+
118+
If `scoring` reprents multiple scores, one can use:
119+
120+
- a list or tuple of unique strings;
121+
- a callable returning a dictionary where the keys are the metric
122+
names and the values are the metric scores;
123+
- a dictionary with metric names as keys and callables a values.
124+
125+
Passing multiple scores to `scoring` is more efficient than calling
126+
`permutation_importance` for each of the scores as it reuses
127+
predictions to avoid redundant computation.
128+
129+
If None, the estimator's default scorer is used.
81130
82131
n_repeats : int, default=5
83132
Number of times to permute a feature.
@@ -102,16 +151,20 @@ def permutation_importance(estimator, X, y, *, scoring=None, n_repeats=5,
102151
103152
Returns
104153
-------
105-
result : :class:`~sklearn.utils.Bunch`
154+
result : :class:`~sklearn.utils.Bunch` or dict of such instances
106155
Dictionary-like object, with the following attributes.
107156
108-
importances_mean : ndarray, shape (n_features, )
157+
importances_mean : ndarray of shape (n_features, )
109158
Mean of feature importance over `n_repeats`.
110-
importances_std : ndarray, shape (n_features, )
159+
importances_std : ndarray of shape (n_features, )
111160
Standard deviation over `n_repeats`.
112-
importances : ndarray, shape (n_features, n_repeats)
161+
importances : ndarray of shape (n_features, n_repeats)
113162
Raw permutation importance scores.
114163
164+
If there are multiple scoring metrics in the scoring parameter
165+
`result` is a dict with scorer names as keys (e.g. 'roc_auc') and
166+
`Bunch` objects like above as values.
167+
115168
References
116169
----------
117170
.. [BRE] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32,
@@ -143,14 +196,33 @@ def permutation_importance(estimator, X, y, *, scoring=None, n_repeats=5,
143196
random_state = check_random_state(random_state)
144197
random_seed = random_state.randint(np.iinfo(np.int32).max + 1)
145198

146-
scorer = check_scoring(estimator, scoring=scoring)
147-
baseline_score = _weights_scorer(scorer, estimator, X, y, sample_weight)
148-
149-
scores = Parallel(n_jobs=n_jobs)(delayed(_calculate_permutation_scores)(
150-
estimator, X, y, sample_weight, col_idx, random_seed, n_repeats, scorer
151-
) for col_idx in range(X.shape[1]))
152-
153-
importances = baseline_score - np.array(scores)
154-
return Bunch(importances_mean=np.mean(importances, axis=1),
155-
importances_std=np.std(importances, axis=1),
156-
importances=importances)
199+
if callable(scoring):
200+
scorer = scoring
201+
elif scoring is None or isinstance(scoring, str):
202+
scorer = check_scoring(estimator, scoring=scoring)
203+
else:
204+
scorers_dict = _check_multimetric_scoring(estimator, scoring)
205+
scorer = _MultimetricScorer(**scorers_dict)
206+
207+
baseline_score = _weights_scorer(scorer, estimator, X, y,
208+
sample_weight)
209+
210+
scores = Parallel(n_jobs=n_jobs)(
211+
delayed(_calculate_permutation_scores)(
212+
estimator, X, y, sample_weight, col_idx, random_seed,
213+
n_repeats, scorer
214+
) for col_idx in range(X.shape[1]))
215+
216+
if isinstance(baseline_score, dict):
217+
return {
218+
name: _create_importances_bunch(
219+
baseline_score[name],
220+
# unpack the permuted scores
221+
np.array([
222+
scores[col_idx][name] for col_idx in range(X.shape[1])
223+
])
224+
)
225+
for name in baseline_score
226+
}
227+
else:
228+
return _create_importances_bunch(baseline_score, np.array(scores))

‎sklearn/inspection/tests/test_permutation_importance.py

Copy file name to clipboardExpand all lines: sklearn/inspection/tests/test_permutation_importance.py
+48-1Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
from sklearn.impute import SimpleImputer
1717
from sklearn.inspection import permutation_importance
1818
from sklearn.model_selection import train_test_split
19+
from sklearn.metrics import (
20+
get_scorer,
21+
mean_squared_error,
22+
r2_score,
23+
)
1924
from sklearn.pipeline import make_pipeline
2025
from sklearn.preprocessing import KBinsDiscretizer
2126
from sklearn.preprocessing import OneHotEncoder
@@ -25,7 +30,6 @@
2530
from sklearn.utils._testing import _convert_container
2631

2732

28-
2933
@pytest.mark.parametrize("n_jobs", [1, 2])
3034
def test_permutation_importance_correlated_feature_regression(n_jobs):
3135
# Make sure that feature highly correlated to the target have a higher
@@ -435,3 +439,46 @@ def my_scorer(estimator, X, y):
435439
scoring=my_scorer,
436440
n_repeats=1,
437441
sample_weight=w)
442+
443+
444+
@pytest.mark.parametrize(
445+
"list_single_scorer, multi_scorer",
446+
[
447+
(["r2", "neg_mean_squared_error"], ["r2", "neg_mean_squared_error"]),
448+
(
449+
["r2", "neg_mean_squared_error"],
450+
{
451+
"r2": get_scorer("r2"),
452+
"neg_mean_squared_error": get_scorer("neg_mean_squared_error"),
453+
},
454+
),
455+
(
456+
["r2", "neg_mean_squared_error"],
457+
lambda estimator, X, y: {
458+
"r2": r2_score(y, estimator.predict(X)),
459+
"neg_mean_squared_error": -mean_squared_error(
460+
y, estimator.predict(X)
461+
),
462+
},
463+
),
464+
],
465+
)
466+
def test_permutation_importance_multi_metric(list_single_scorer, multi_scorer):
467+
# Test permutation importance when scoring contains multiple scorers
468+
469+
# Creating some data and estimator for the permutation test
470+
x, y = make_regression(n_samples=500, n_features=10, random_state=0)
471+
lr = LinearRegression().fit(x, y)
472+
473+
multi_importance = permutation_importance(
474+
lr, x, y, random_state=1, scoring=multi_scorer, n_repeats=2
475+
)
476+
assert set(multi_importance.keys()) == set(list_single_scorer)
477+
478+
for scorer in list_single_scorer:
479+
multi_result = multi_importance[scorer]
480+
single_result = permutation_importance(
481+
lr, x, y, random_state=1, scoring=scorer, n_repeats=2
482+
)
483+
484+
assert_allclose(multi_result.importances, single_result.importances)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.