Adding a pruning method to the tree #941

Status: Closed. Wants to merge 14 commits.
doc/modules/tree.rst: 73 additions, 4 deletions
@@ -55,10 +55,10 @@ Some advantages of decision trees are:
 The disadvantages of decision trees include:

 - Decision-tree learners can create over-complex trees that do not
-  generalise the data well. This is called overfitting. Mechanisms
-  such as pruning (not currently supported), setting the minimum
-  number of samples required at a leaf node or setting the maximum
-  depth of the tree are necessary to avoid this problem.
+  generalise the data well. This is called overfitting. Mechanisms such as
+  pruning, setting the minimum number of samples required at a leaf node or
+  setting the maximum depth of the tree are necessary to avoid this
+  problem.

 - Decision trees can be unstable because small variations in the
   data might result in a completely different tree being generated.
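
As a concrete illustration of the pre-pruning mechanisms this passage lists (a
minimal sketch, not part of the diff; ``max_depth`` and ``min_samples_leaf``
are existing estimator parameters)::

    >>> from sklearn import tree
    >>> # Limit depth and require at least 10 samples per leaf (pre-pruning)
    >>> clf = tree.DecisionTreeClassifier(max_depth=5, min_samples_leaf=10)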
@@ -183,6 +183,75 @@ instead of integer values::
* :ref:`example_tree_plot_tree_regression.py`



.. _tree_pruning:

Pruning
=======

Review discussion on this section:

Contributor:

    To me, this section reads as if post-pruning were the only way to prune a
    tree. I would rather introduce both pre-pruning (using ``min_samples_split``
    or ``min_samples_leaf``) and post-pruning, and compare them both. What do
    you think?

Member:

    I wouldn't call ``min_samples_split`` pruning, but I think comparing these
    different regularization methods would be good.

Contributor:

    Well, it is called a pre-pruning method in the literature.

Member:

    Really? Ok then.

A common approach to getting the best possible tree is to grow a large tree
(for instance with ``max_depth=8``) and then prune it back to an optimal size.
In addition to providing a ``prune`` method for both
:class:`DecisionTreeRegressor` and :class:`DecisionTreeClassifier`, the
function ``prune_path`` is useful for finding the optimal size for a tree.

The ``prune`` method takes a single argument: the number of leaves (an int)
that the fitted tree should keep::

>>> from sklearn.datasets import load_boston
>>> from sklearn import tree
>>> boston = load_boston()
>>> clf = tree.DecisionTreeRegressor(max_depth=8)
>>> clf = clf.fit(boston.data, boston.target)
>>> clf = clf.prune(8)
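
The pruned estimator behaves like any other fitted tree; for instance (a
minimal continuation of the doctest above, with the output skipped since the
exact values depend on the fit)::

    >>> clf.predict(boston.data[:1])    # doctest: +SKIP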

To find the optimal number of leaves, we can use cross-validated scores on the
data::

>>> from sklearn.datasets import load_boston
>>> from sklearn import tree
>>> boston = load_boston()
>>> clf = tree.DecisionTreeRegressor(max_depth=8)
>>> scores = tree.prune_path(clf, boston.data, boston.target,
... max_n_leaves=20, n_iterations=10, random_state=0)

To plot the scores, one can use the following helper function::

    import numpy as np

    def plot_pruned_path(scores, with_std=True):
        """Plot the cross-validated scores versus the number of leaves of trees."""
        import matplotlib.pyplot as plt
        means = np.array([np.mean(s) for s in scores])
        # Standard error of the mean over the cross-validation iterations
        stds = np.array([np.std(s) / np.sqrt(len(s)) for s in scores])

        # Candidate tree sizes, from max_n_leaves down to 2
        x = range(len(scores) + 1, 1, -1)

        plt.plot(x, means)
        if with_std:
            plt.plot(x, means + 2 * stds, lw=1, c='0.7')
            plt.plot(x, means - 2 * stds, lw=1, c='0.7')

        plt.xlabel('Number of leaves')
        plt.ylabel('Cross-validated score')
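
Calling it on the scores computed above produces the figure shown below::

    plot_pruned_path(scores)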


For instance, using the Boston dataset we obtain the following graph:

.. figure:: ../auto_examples/tree/images/plot_prune_boston_1.png
:target: ../auto_examples/tree/plot_prune_boston.html
:align: center
:scale: 75

Here we see clearly that the optimal number of leaves is between 6 and 9.
Beyond that, additional leaves do not improve the cross-validated score, and
may even diminish it.
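
The best size can also be selected programmatically from the same scores. A
minimal sketch, assuming ``prune_path`` returns one array of cross-validated
scores per candidate size, ordered from ``max_n_leaves`` down to 2 (the
ordering ``plot_pruned_path`` assumes for its x axis)::

    import numpy as np

    clf = clf.fit(boston.data, boston.target)

    # Candidate sizes matching the x axis of plot_pruned_path
    n_leaves_grid = np.arange(len(scores) + 1, 1, -1)
    means = np.array([np.mean(s) for s in scores])

    # Prune the fitted tree to the size with the best mean score
    clf = clf.prune(n_leaves_grid[np.argmax(means)])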

.. topic:: Examples:

* :ref:`example_tree_plot_prune_boston.py`
* :ref:`example_tree_plot_overfitting_cv.py`



.. _tree_multioutput:

Multi-output problems
Expand Down
examples/tree/plot_overfitting_cv.py: 72 additions, 0 deletions (new file)
@@ -0,0 +1,72 @@
"""
====================================================
Comparison of cross validated score with overfitting
====================================================

These two plots compare the cross-validated scores of the regression of a
simple function. We see that below the score's maximum at about 7 leaves the
regression is far from the real function. On the other hand, for higher
numbers of leaves we clearly overfit.

"""
print __doc__

import numpy as np
from sklearn import tree


def plot_pruned_path(scores, with_std=True):
    """Plot the cross-validated scores versus the number of leaves of trees."""
    import matplotlib.pyplot as plt
    means = np.array([np.mean(s) for s in scores])
    # Standard error of the mean over the cross-validation iterations
    stds = np.array([np.std(s) / np.sqrt(len(s)) for s in scores])

    # Candidate tree sizes, from max_n_leaves down to 2
    x = range(len(scores) + 1, 1, -1)

    plt.plot(x, means)
    if with_std:
        plt.plot(x, means + 2 * stds, lw=1, c='0.7')
        plt.plot(x, means - 2 * stds, lw=1, c='0.7')

    plt.xlabel('Number of leaves')
    plt.ylabel('Cross-validated score')


# Create a random dataset
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
# Add noise to every fifth target value (16 of the 80 samples)
y[1::5] += 3 * (0.5 - rng.rand(16))


clf = tree.DecisionTreeRegressor(max_depth=20)
scores = tree.prune_path(clf, X, y, max_n_leaves=20,
                         n_iterations=100, random_state=0)
plot_pruned_path(scores)

clf = tree.DecisionTreeRegressor(max_depth=20, n_leaves=15)
clf.fit(X, y)
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]

# Prepare the different pruning levels
y_15 = clf.predict(X_test)

# Successively prune the same fitted tree to smaller sizes
clf = clf.prune(7)
y_7 = clf.predict(X_test)

clf = clf.prune(2)
y_2 = clf.predict(X_test)

# Plot the results
import pylab as pl

pl.figure()
pl.scatter(X, y, c="k", label="data")
pl.plot(X_test, y_2, c="g", label="n_leaves=2", linewidth=2)
pl.plot(X_test, y_7, c="b", label="n_leaves=7", linewidth=2)
pl.plot(X_test, y_15, c="r", label="n_leaves=15", linewidth=2)
pl.xlabel("data")
pl.ylabel("target")
pl.title("Decision Tree Regression with levels of pruning")
pl.legend()
pl.show()
examples/tree/plot_prune_boston.py: 39 additions, 0 deletions (new file)
@@ -0,0 +1,39 @@
"""
=============================================
Cross-validated scores on the Boston dataset
=============================================

"""
print __doc__

import numpy as np
from sklearn.datasets import load_boston
from sklearn import tree


def plot_pruned_path(scores, with_std=True):
    """Plot the cross-validated scores versus the number of leaves of trees."""
    import matplotlib.pyplot as plt
    means = np.array([np.mean(s) for s in scores])
    # Standard error of the mean over the cross-validation iterations
    stds = np.array([np.std(s) / np.sqrt(len(s)) for s in scores])

    # Candidate tree sizes, from max_n_leaves down to 2
    x = range(len(scores) + 1, 1, -1)

    plt.plot(x, means)
    if with_std:
        plt.plot(x, means + 2 * stds, lw=1, c='0.7')
        plt.plot(x, means - 2 * stds, lw=1, c='0.7')

    plt.xlabel('Number of leaves')
    plt.ylabel('Cross-validated score')


boston = load_boston()
clf = tree.DecisionTreeRegressor(max_depth=8)

# Compute the cross-validated scores for each candidate number of leaves
scores = tree.prune_path(clf, boston.data, boston.target,
                         max_n_leaves=20, n_iterations=10,
                         random_state=0)

plot_pruned_path(scores)

import pylab as pl
pl.show()
sklearn/tree/__init__.py: 1 addition, 0 deletions
@@ -8,3 +8,4 @@
 from .tree import ExtraTreeClassifier
 from .tree import ExtraTreeRegressor
 from .tree import export_graphviz
+from .tree import prune_path
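
With this export in place, the helper becomes importable from the package
namespace::

    >>> from sklearn.tree import prune_path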