Adding a pruning method to the tree #941
Closed
Commits (14)
4a75a4a  Introduced pruning to decision trees
48fad0c  Merge branch 'master' of git://github.com/scikit-learn/scikit-learn
c51698a  Fixed documentation of pruning_order function
c139f6a  Added a max_to_prune argument to pruning_order
35a8b3d  Incorporated feedback
4310876  Merge branch 'master' of git://github.com/scikit-learn/scikit-learn
0409c05  Made n_output an optional value
89af373  Cleaned the function documentation (again)
45adf40  Use shuffle and split to cross validate
9e41d88  First draft of the documentation
92c6afb  pep8 corrections
acd4048  Renamed cv_scores_vs_n_leaves as pruned_path
98255d8  Added n_leaves to the DecisionTree estimator (sgenoud)
ea8e460  Moved the helper functions inside the Tree object (sgenoud)
New file (72 lines added):
""" | ||
==================================================== | ||
Comparison of cross validated score with overfitting | ||
==================================================== | ||
|
||
These two plots compare the cross validated score of a the regression of | ||
a simple function. We see that before the maximum value of 7 the regression is | ||
far for the real function. On the other hand, for higher number of leaves we | ||
clearly overfit. | ||
|
||
""" | ||
print __doc__ | ||
|
||
import numpy as np | ||
from sklearn import tree | ||
|
||
|
||
def plot_pruned_path(scores, with_std=True): | ||
"""Plots the cross validated scores versus the number of leaves of trees""" | ||
import matplotlib.pyplot as plt | ||
means = np.array([np.mean(s) for s in scores]) | ||
stds = np.array([np.std(s) for s in scores]) / np.sqrt(len(scores[1])) | ||
|
||
x = range(len(scores) + 1, 1, -1) | ||
|
||
plt.plot(x, means) | ||
if with_std: | ||
plt.plot(x, means + 2 * stds, lw=1, c='0.7') | ||
plt.plot(x, means - 2 * stds, lw=1, c='0.7') | ||
|
||
plt.xlabel('Number of leaves') | ||
plt.ylabel('Cross validated score') | ||
|
||
|
||
# Create a random dataset | ||
rng = np.random.RandomState(1) | ||
X = np.sort(5 * rng.rand(80, 1), axis=0) | ||
y = np.sin(X).ravel() | ||
y[1::5] += 3 * (0.5 - rng.rand(16)) | ||
|
||
|
||
clf = tree.DecisionTreeRegressor(max_depth=20) | ||
scores = tree.prune_path(clf, X, y, max_n_leaves=20, | ||
n_iterations=100, random_state=0) | ||
plot_pruned_path(scores) | ||
|
||
clf = tree.DecisionTreeRegressor(max_depth=20, n_leaves=15) | ||
clf.fit(X, y) | ||
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis] | ||
|
||
#Prepare the different pruned level | ||
y_15 = clf.predict(X_test) | ||
|
||
clf = clf.prune(6) | ||
y_7 = clf.predict(X_test) | ||
|
||
clf = clf.prune(2) | ||
y_2 = clf.predict(X_test) | ||
|
||
# Plot the results | ||
import pylab as pl | ||
|
||
pl.figure() | ||
pl.scatter(X, y, c="k", label="data") | ||
pl.plot(X_test, y_2, c="g", label="n_leaves=2", linewidth=2) | ||
pl.plot(X_test, y_7, c="b", label="n_leaves=7", linewidth=2) | ||
pl.plot(X_test, y_15, c="r", label="n_leaves=15", linewidth=2) | ||
pl.xlabel("data") | ||
pl.ylabel("target") | ||
pl.title("Decision Tree Regression with levels of pruning") | ||
pl.legend() | ||
pl.show() |
New file (39 lines added):
""" | ||
============================================ | ||
Cross validated scores of the boston dataset | ||
============================================ | ||
|
||
""" | ||
print __doc__ | ||
|
||
import numpy as np | ||
from sklearn.datasets import load_boston | ||
from sklearn import tree | ||
|
||
|
||
def plot_pruned_path(scores, with_std=True): | ||
"""Plots the cross validated scores versus the number of leaves of trees""" | ||
import matplotlib.pyplot as plt | ||
means = np.array([np.mean(s) for s in scores]) | ||
stds = np.array([np.std(s) for s in scores]) / np.sqrt(len(scores[1])) | ||
|
||
x = range(len(scores) + 1, 1, -1) | ||
|
||
plt.plot(x, means) | ||
if with_std: | ||
plt.plot(x, means + 2 * stds, lw=1, c='0.7') | ||
plt.plot(x, means - 2 * stds, lw=1, c='0.7') | ||
|
||
plt.xlabel('Number of leaves') | ||
plt.ylabel('Cross validated score') | ||
|
||
|
||
boston = load_boston() | ||
clf = tree.DecisionTreeRegressor(max_depth=8) | ||
|
||
#Compute the cross validated scores | ||
scores = tree.prune_path(clf, boston.data, boston.target, | ||
max_n_leaves=20, n_iterations=10, | ||
random_state=0) | ||
|
||
plot_pruned_path(scores) |
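The `tree.prune_path` helper above is specific to this proposal and never shipped in this form. For comparison, the same idea — scoring a range of pruning strengths with shuffle-and-split cross-validation — can be sketched with today's scikit-learn API (`cost_complexity_pruning_path` and `ccp_alpha`). Note this is a rough modern equivalent, not the PR's code, and it uses a synthetic dataset because `load_boston` was removed from recent scikit-learn releases:

```python
import numpy as np
from sklearn.datasets import make_regression  # stand-in for the removed load_boston
from sklearn.model_selection import ShuffleSplit, cross_val_score
from sklearn.tree import DecisionTreeRegressor

X, y = make_regression(n_samples=200, n_features=5, noise=10.0, random_state=0)

# Candidate pruning strengths come from the cost-complexity pruning path
# of a tree grown on the full data.
path = DecisionTreeRegressor(random_state=0).cost_complexity_pruning_path(X, y)
# Drop the last alpha (prunes to the root) and clip tiny negative
# round-off values, since ccp_alpha must be non-negative.
alphas = np.clip(path.ccp_alphas[:-1], 0.0, None)

# Shuffle-and-split cross-validation, echoing the PR's
# "Use shuffle and split to cross validate" commit.
cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0)
mean_scores = [
    cross_val_score(DecisionTreeRegressor(random_state=0, ccp_alpha=a),
                    X, y, cv=cv).mean()
    for a in alphas
]

best_alpha = alphas[int(np.argmax(mean_scores))]
print("best ccp_alpha:", best_alpha)
```

Instead of sweeping a number of leaves as `prune_path` does, the modern API sweeps the cost-complexity parameter alpha; larger alphas yield smaller trees.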
Review discussion

"To me, this section looks written as if post-pruning was the only way to prune a tree. I would rather introduce both pre-pruning (using min_samples_split or min_samples_leaf) and post-pruning and compare them both. What do you think?"

"I wouldn't call min_samples_split pruning, but I think comparing these different regularization methods would be good."

"Well, it is called a pre-pruning method in the literature."

"Really? Ok then."
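The comparison the reviewers ask for can be illustrated with current scikit-learn, where post-pruning eventually landed as minimal cost-complexity pruning (`ccp_alpha`) rather than this PR's `n_leaves`/`prune` API. A minimal sketch on the same noisy sine data as the example above — the choice of `min_samples_leaf=8` and of the third-largest alpha are illustrative, not prescribed by the PR:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[1::5] += 3 * (0.5 - rng.rand(16))  # noisy targets, as in the example

# No regularization: the tree grows until it memorizes the noise.
full = DecisionTreeRegressor(random_state=0).fit(X, y)

# Pre-pruning: stop splitting early by requiring at least 8 samples per leaf.
pre = DecisionTreeRegressor(min_samples_leaf=8, random_state=0).fit(X, y)

# Post-pruning: grow the full tree, then cut back the weakest links
# via minimal cost-complexity pruning; a large alpha keeps few leaves.
path = full.cost_complexity_pruning_path(X, y)
post = DecisionTreeRegressor(ccp_alpha=path.ccp_alphas[-3],
                             random_state=0).fit(X, y)

print("leaves (full / pre-pruned / post-pruned):",
      full.get_n_leaves(), pre.get_n_leaves(), post.get_n_leaves())
```

Both approaches shrink the tree; the difference is that pre-pruning never considers splits it stops early on, while post-pruning evaluates the fully grown tree before deciding what to remove.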