Commit c3facb9

azihna (Alihan Zihna) authored and glemaitre committed
EXA improve example of forest feature importances (scikit-learn#19377)
Co-authored-by: Alihan Zihna <a.zihna@ckhgbdp.onmicrosoft.com>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 1fe2dc5 commit c3facb9

1 file changed: 98 additions, 48 deletions (+98, -48)
@@ -1,60 +1,110 @@
 """
-=========================================
-Feature importances with forests of trees
-=========================================
+==========================================
+Feature importances with a forest of trees
+==========================================
 
-This examples shows the use of forests of trees to evaluate the importance of
-features on an artificial classification task. The red bars are
-the impurity-based feature importances of the forest,
-along with their inter-trees variability.
+This example shows the use of a forest of trees to evaluate the importance of
+features on an artificial classification task. The blue bars are the feature
+importances of the forest, along with their inter-trees variability represented
+by the error bars.
 
 As expected, the plot suggests that 3 features are informative, while the
 remaining are not.
-
-.. warning::
-    Impurity-based feature importances can be misleading for high cardinality
-    features (many unique values). See
-    :func:`sklearn.inspection.permutation_importance` as an alternative.
 """
 print(__doc__)
-
-import numpy as np
 import matplotlib.pyplot as plt
 
+# %%
+# Data generation and model fitting
+# ---------------------------------
+# We generate a synthetic dataset with only 3 informative features. We will
+# explicitly not shuffle the dataset to ensure that the informative features
+# will correspond to the three first columns of X. In addition, we will split
+# our dataset into training and testing subsets.
 from sklearn.datasets import make_classification
-from sklearn.ensemble import ExtraTreesClassifier
-
-# Build a classification task using 3 informative features
-X, y = make_classification(n_samples=1000,
-                           n_features=10,
-                           n_informative=3,
-                           n_redundant=0,
-                           n_repeated=0,
-                           n_classes=2,
-                           random_state=0,
-                           shuffle=False)
-
-# Build a forest and compute the impurity-based feature importances
-forest = ExtraTreesClassifier(n_estimators=250,
-                              random_state=0)
-
-forest.fit(X, y)
+from sklearn.model_selection import train_test_split
+
+X, y = make_classification(
+    n_samples=1000, n_features=10, n_informative=3, n_redundant=0,
+    n_repeated=0, n_classes=2, random_state=0, shuffle=False)
+X_train, X_test, y_train, y_test = train_test_split(
+    X, y, stratify=y, random_state=42)
+
+# %%
+# A random forest classifier will be fitted to compute the feature importances.
+from sklearn.ensemble import RandomForestClassifier
+
+feature_names = [f'feature {i}' for i in range(X.shape[1])]
+forest = RandomForestClassifier(random_state=0)
+forest.fit(X_train, y_train)
+
+# %%
+# Feature importance based on mean decrease in impurity
+# -----------------------------------------------------
+# Feature importances are provided by the fitted attribute
+# `feature_importances_` and they are computed as the mean and standard
+# deviation of accumulation of the impurity decrease within each tree.
+#
+# .. warning::
+#     Impurity-based feature importances can be misleading for high cardinality
+#     features (many unique values). See :ref:`permutation_importance` as
+#     an alternative below.
+import time
+import numpy as np
+
+start_time = time.time()
 importances = forest.feature_importances_
-std = np.std([tree.feature_importances_ for tree in forest.estimators_],
-             axis=0)
-indices = np.argsort(importances)[::-1]
-
-# Print the feature ranking
-print("Feature ranking:")
-
-for f in range(X.shape[1]):
-    print("%d. feature %d (%f)" % (f + 1, indices[f], importances[indices[f]]))
-
-# Plot the impurity-based feature importances of the forest
-plt.figure()
-plt.title("Feature importances")
-plt.bar(range(X.shape[1]), importances[indices],
-        color="r", yerr=std[indices], align="center")
-plt.xticks(range(X.shape[1]), indices)
-plt.xlim([-1, X.shape[1]])
+std = np.std([
+    tree.feature_importances_ for tree in forest.estimators_], axis=0)
+elapsed_time = time.time() - start_time
+
+print(f"Elapsed time to compute the importances: "
+      f"{elapsed_time:.3f} seconds")
+
+# %%
+# Let's plot the impurity-based importance.
+import pandas as pd
+forest_importances = pd.Series(importances, index=feature_names)
+
+fig, ax = plt.subplots()
+forest_importances.plot.bar(yerr=std, ax=ax)
+ax.set_title("Feature importances using MDI")
+ax.set_ylabel("Mean decrease in impurity")
+fig.tight_layout()
+
+# %%
+# We observe that, as expected, the three first features are found important.
+#
+# Feature importance based on feature permutation
+# -----------------------------------------------
+# Permutation feature importance overcomes limitations of the impurity-based
+# feature importance: they do not have a bias toward high-cardinality features
+# and can be computed on a left-out test set.
+from sklearn.inspection import permutation_importance
+
+start_time = time.time()
+result = permutation_importance(
+    forest, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2)
+elapsed_time = time.time() - start_time
+print(f"Elapsed time to compute the importances: "
+      f"{elapsed_time:.3f} seconds")
+
+forest_importances = pd.Series(result.importances_mean, index=feature_names)
+
+# %%
+# The computation for full permutation importance is more costly. Features are
+# shuffled n times and the model refitted to estimate the importance of it.
+# Please see :ref:`permutation_importance` for more details. We can now plot
+# the importance ranking.
+
+fig, ax = plt.subplots()
+forest_importances.plot.bar(yerr=result.importances_std, ax=ax)
+ax.set_title("Feature importances using permutation on full model")
+ax.set_ylabel("Mean accuracy decrease")
+fig.tight_layout()
 plt.show()
+
+# %%
+# The same features are detected as most important using both methods. Although
+# the relative importances vary. As seen on the plots, MDI is less likely than
+# permutation importance to fully omit a feature.
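For readers who want the textual feature ranking that the old version of the example printed, the snippet below is a condensed, self-contained sketch (not part of the commit) that reuses the dataset, RandomForestClassifier, and permutation_importance settings from the diff above and prints both importance measures side by side; the variable names mdi and perm are illustrative.

import numpy as np
import pandas as pd

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# Same synthetic task as in the example: 3 informative features, no shuffling,
# so the informative features are the first three columns of X.
X, y = make_classification(
    n_samples=1000, n_features=10, n_informative=3, n_redundant=0,
    n_repeated=0, n_classes=2, random_state=0, shuffle=False)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, random_state=42)

feature_names = [f"feature {i}" for i in range(X.shape[1])]
forest = RandomForestClassifier(random_state=0).fit(X_train, y_train)

# Mean decrease in impurity (MDI): averaged per-tree impurity reductions,
# exposed by the fitted attribute `feature_importances_`.
mdi = pd.Series(forest.feature_importances_, index=feature_names)

# Permutation importance on the held-out test set: drop in accuracy when a
# column is shuffled, averaged over n_repeats permutations.
result = permutation_importance(
    forest, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2)
perm = pd.Series(result.importances_mean, index=feature_names)

# Textual ranking, analogous to the "Feature ranking:" printout removed above.
print(pd.DataFrame({"MDI": mdi, "permutation": perm})
      .sort_values("permutation", ascending=False))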

0 comments