
Adding a pruning method to the tree #941

Closed
sgenoud wants to merge 14 commits

Conversation


@sgenoud commented Jul 10, 2012

I have added a pruning method to the decision tree module. The idea with decision trees is to build a huge one and then prune it via a weakest-link algorithm until it reaches a size that is reasonable (neither overfitting nor underfitting).

I have also built a helper function, cv_scores_vs_n_leaves, that computes the cross-validated scores for different sizes of the tree. This can be plotted with a function such as

import numpy as np
import matplotlib.pyplot as plt


def plot_cross_validated_scores(scores):
    """Plot the cross-validated scores versus the number of leaves of the trees."""
    # Mean and spread of the CV scores for each tree size
    means = np.array([np.mean(s) for s in scores])
    stds = np.array([np.std(s) for s in scores])

    # scores[0] corresponds to the largest tree, so the x axis runs from
    # len(scores) + 1 leaves down to 2 leaves
    x = range(len(scores) + 1, 1, -1)

    plt.plot(x, means)
    plt.plot(x, means + stds, lw=1, c='0.7')
    plt.plot(x, means - stds, lw=1, c='0.7')

    plt.xlabel('Number of leaves')
    plt.ylabel('Cross validated score')

Then we choose the size that works best for the data.
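
For instance, one way to drive that plot end to end (this is only a sketch: max_leaf_nodes stands in here for the branch's cv_scores_vs_n_leaves helper, which prunes a single large tree instead of refitting trees of each size):

import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeRegressor

X, y = make_regression(n_samples=300, n_features=4, noise=10.0, random_state=0)

# Score trees of decreasing size with cross-validation; sizes run from
# 20 leaves down to 2 leaves to match the x axis used by the plot helper.
sizes = range(20, 1, -1)
scores = [cross_val_score(DecisionTreeRegressor(max_leaf_nodes=n, random_state=0),
                          X, y, cv=5)
          for n in sizes]

plot_cross_validated_scores(scores)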

Just a couple of notes:

  • I need to add some tests (I am not sure how to make them)
  • This also needs to be documented

Before doing that work, I would appreciate some feedback on these modifications.

----------
tree : binary tree object
The binary tree for which to compute the complexity costs.

Member

The object doesn't need to be listed as a parameter here

Author

sorry, I forgot to remove that when I refactored my code.

@jakevdp
Member

jakevdp commented Jul 10, 2012

Nice addition. I haven't had a chance to try it out in detail, but I read over the code and it looks good.
One suggestion: in the pruning_order function, it would be more efficient if we were able to pass an argument like max_to_prune which would allow you to specify the maximum number of nodes you're interested in pruning. As it's currently written, I think the function will sort all the nodes each time it's called.
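
To illustrate the suggestion, a rough sketch (the real pruning_order in this branch works on the Tree structure; the flat-array interface and the max_to_prune name here are only for illustration):

import numpy as np


def pruning_order(weakest_link_costs, max_to_prune=None):
    """Return node indices in the order they would be pruned.

    If max_to_prune is given, avoid a full O(n log n) sort and only
    select (and then order) the cheapest max_to_prune nodes.
    """
    costs = np.asarray(weakest_link_costs)
    if max_to_prune is None or max_to_prune >= costs.shape[0]:
        return np.argsort(costs)
    # argpartition selects the k smallest costs in O(n); only that small
    # prefix is then sorted.
    idx = np.argpartition(costs, max_to_prune)[:max_to_prune]
    return idx[np.argsort(costs[idx])]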

@@ -266,6 +291,105 @@ def _add_leaf(self, parent, is_left_child, value, error, n_samples):

return node_id

def _copy(self):
Member

Is there a reason not to use clone here? Not sure if that copies the arrays, though. But I'm sure there is a method that does (for serialization for example).

Author

clone copies the parameters of an Estimator object; the Tree is the result of fitting a DecisionTree and would not be copied. Also, a DecisionTree has a Tree and is an Estimator, but a Tree itself is not an estimator (it inherits directly from object).

If someone knows a copying function, I will gladly use it.

Contributor

copy or deepcopy from the copy module.
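
A minimal illustration of that suggestion (the toy estimator and data here are only placeholders):

from copy import deepcopy

import numpy as np
from sklearn.tree import DecisionTreeRegressor

X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0.0, 0.0, 1.0, 1.0])
clf = DecisionTreeRegressor().fit(X, y)

# deepcopy also copies the tree's underlying arrays, so the copy can be
# pruned without mutating the original estimator's tree.
tree_copy = deepcopy(clf.tree_)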

@amueller
Member

Thanks for your contribution.
I didn't look at the code in detail, so just some very general remarks.

It would be nice to have some mention of this in the narrative docs, and maybe to extend an already existing example to show how the pruning works and what effect it has.

I'm not sure what would be a good test for the pruning but maybe you can come up with something, for example testing on some toy data with known expected outcome.
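
One possible shape for such a toy-data test, sketched here against the stock scikit-learn API (max_leaf_nodes=2 stands in for the branch's pruning; a real pruning test would instead fit the full tree and then call the proposed prune(2) on it):

import numpy as np
from sklearn.tree import DecisionTreeRegressor


def test_small_tree_recovers_step_function():
    # Noisy step function: a fully grown tree overfits the noise, but a
    # 2-leaf tree captures the underlying structure.
    rng = np.random.RandomState(0)
    X = rng.uniform(size=(200, 1))
    y_true = (X[:, 0] > 0.5).astype(np.float64)
    y = y_true + 0.1 * rng.randn(200)

    full = DecisionTreeRegressor(random_state=0).fit(X, y)
    assert full.get_n_leaves() > 2          # the unpruned tree overfits

    small = DecisionTreeRegressor(max_leaf_nodes=2, random_state=0).fit(X, y)
    assert small.get_n_leaves() == 2
    assert small.score(X, y_true) > 0.8     # the step is still recovered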

Looking forward to seeing this in action :)

t_nodes = _get_terminal_nodes(children)
g_i = tree.init_error[t_nodes] - tree.best_error[t_nodes]

#alpha_i = np.min(g_i) / len(t_nodes)
Contributor

If this isn't needed, should remove it...

@jakevdp
Member

jakevdp commented Jul 10, 2012

I agree narrative docs and an example would be very helpful. You could modify these:

You could add a fit to the plot which uses a pruned tree. Hopefully it would do better than the green line, but not over-fit as much as the red line.

Also, a new example with your plot_cross_validated_scores() function would be very useful. It could go in the same directory, and be included in the same narrative document.

Sorry to ask for more work when you've already done so much! Let us know if you need any help. I'm excited to see the results! 😁


Parameters
----------
n_leaves : binary tree object
Contributor

Isn't this an int? What would it mean for it to be a binary tree object?

@sgenoud
Author

sgenoud commented Jul 11, 2012

Thanks for all your feedback!

I will definitely write some docs, but I wanted to see whether what I did was worth pulling before doing more work.

About my helper functions, one little detail: I use _get_leaves in the new leaves property that I added to Tree (and use it to compute the number of nodes to prune). I usually like to have small helper functions; it makes things easier for me to read afterwards. Nevertheless, if merging them in seems a better design to most of you, I'll do it.

Also, it would be nice if one of you played a bit with the feature, as an external "test" of the function. I have used it for my own needs (and compared it in one case with an equivalent R function), but external confirmation is usually a good thing.

Steve Genoud added 3 commits July 11, 2012 14:19

def prune(self, n_leaves):
"""
Prunes the tree in order to obtain the optimal tree with n_leaves
Contributor

I would rather say the optimal subtree.

Member

+1

@amueller
Member

I just had a quick look at the plot_prune_boston.py example.
I noticed a couple of things:

  • You should add plt.show() at the end, so that people that just call the example also see something.
  • From the plot it is not immediately clear what one should see. I think you should write a small paragraph about what the example shows.
  • If I interpret the example correctly, the more leaves the better. This means not pruning is better than pruning (maybe 20 leaves is still pruned but I don't see performance deteriorating), right?
  • Finally, please run pep8 over it, there are some minor issues.

@amueller
Member

The synthetic example is pretty cool on the other hand and nicely illustrates the effect of pruning.

It leaves me with one question, though: are there cases when pruning is better than regularizing via "min_leaf" or "max_depth"?
The example shows nicely that regularizing is better than not regularizing, but it is not immediately clear what the benefit of one or the other way of regularizing is. Do you think it is possible to include this in the example somehow? At the moment, the example looks like "if you don't prune, you overfit" to me.

@amueller
Member

I don't know what the other tree-growers think, but I feel having a separate "prune" method breaks the API pretty badly. I think n_leaves should just be an additional parameter and prune should be _prune. If one wants more details, one can use cv_scores_vs_n_leaves, which I would, by the way, rename to something like prune_path, as it seems quite related to the other regularization path functions.

@amueller
Member

I just realized that having a separate prune method means that the pruning can not be used together with GridSearchCV, which is basically a no-go.
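
To make the constraint concrete: GridSearchCV only searches over constructor parameters, so a post-fit prune() call has no place in the parameter grid. In the sketch below (written against the current scikit-learn API), max_leaf_nodes stands in for the proposed n_leaves parameter:

from sklearn.datasets import make_regression
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeRegressor

X, y = make_regression(n_samples=200, n_features=5, random_state=0)

# A constructor parameter (here max_leaf_nodes, standing in for n_leaves)
# can be tuned like any other hyperparameter; a separate prune() step,
# which runs only after fit(), could not be selected this way.
search = GridSearchCV(DecisionTreeRegressor(random_state=0),
                      param_grid={'max_leaf_nodes': [2, 5, 10, 20]},
                      cv=5)
search.fit(X, y)
print(search.best_params_)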

@sgenoud
Author

sgenoud commented Jul 12, 2012

If I interpret the example correctly, the more leaves the better. This means not pruning is better than pruning (maybe 20 leaves is still pruned but I don't see performance deteriorating), right?

I think the heuristic in this case is to go for the smallest tree that reaches the plateau value. Actually, R automatically applies such a heuristic. As you said, I should explain better how to read the graphs.
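
One common, concrete version of that heuristic is the "one standard error" rule: take the smallest tree whose mean CV score is within one standard error of the best mean score. A sketch (the argument names are illustrative, not this branch's interface):

import numpy as np


def smallest_tree_within_one_se(scores, n_leaves_grid):
    """Smallest tree whose mean CV score is within 1 SE of the best.

    scores[i] holds the CV scores of the tree with n_leaves_grid[i] leaves.
    """
    means = np.array([np.mean(s) for s in scores])
    ses = np.array([np.std(s) / np.sqrt(len(s)) for s in scores])
    best = np.argmax(means)
    threshold = means[best] - ses[best]
    ok = [n for n, m in zip(n_leaves_grid, means) if m >= threshold]
    return min(ok)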

It leaves me with one question, though: are there cases when pruning is better than regularizing via "min_leaf" or "max_depth"?

I should probably show how the trees can differ depending on whether we prune or not. Via pruning, the resulting tree is "more optimal". It is possible that, while growing, we do not go down a branch because the next split does not improve the score much, even though the split after it would improve it greatly -- but that will never happen, because the path is not chosen.

I think n_leaves should just be an additional parameter and prune should be _prune.

What do you think of a mixed approach? We could add an auto_prune argument to the fit method (in addition to n_leaves) and keep prune a method. This would mean that by default trees are grown and pruned, but users who want to be more picky about it can prune the trees themselves. Also, GridSearchCV would work.

Note that this would change the interface of this particular method (by that I mean that the default behaviour would become linked with pruning, while it was not so far). We would need to discuss a bit further what the default behaviour should be (do we ask only for n_leaves and by default grow a tree with 3 more depth levels?).

OK, to sum up: to better follow the API, we need to change prune from an action performed on a fitted Estimator into the default way in which trees are sized. This would have consequences on how we present it in the docs. I would propose the following:

  • Introducing a section that discusses methods of sizing a tree
  • Another section about the prune_path with the synthetic example

@amueller
Member

@sgenoud Thanks for your quick feedback. Before you get to work, I would prefer that someone else also voice their opinion, so that you don't do any unnecessary refactoring work.

If you could find an example that illustrates the benefits of pruning compared to other regularization methods, that would be great. For the synthetic example that you made, I fear that "more optimal" means "more overfitted" and for example "min_samples_leaf" does better. Maybe you should also mention somewhere that the pruning fits more strongly to the data.

As a default behavior, I would have imagined to have no pruning, to be backward compatible and also because the ensemble module doesn't need pruning.

@amueller
Member

@glouppe @pprett @bdholt1 comments on this?

@glouppe
Contributor

glouppe commented Jul 23, 2012

@amueller Not that much; basically, instead of stopping when the tree has exactly n_leaves, you instead pick the subtree that minimizes C(alpha, T). The same sequence of subtrees can be used, if I recall correctly. (I admit, though, that I am not really an expert on post-pruning methods; I never use them.)
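
(For reference, the criterion being discussed is the cost-complexity measure from CART, where $R(T)$ is the total error of a subtree $T$ and $|\widetilde{T}|$ its number of leaves:

    C_\alpha(T) = R(T) + \alpha \, |\widetilde{T}|

Weakest-link pruning generates the nested sequence of subtrees that minimize $C_\alpha$ as $\alpha$ grows, and each subtree in that sequence has a fixed number of leaves, which is why selecting a subtree by alpha or by n_leaves walks the same sequence.)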

@glouppe
Contributor

glouppe commented Jul 23, 2012

@sgenoud Yes, I fully agree with you. However, don't you think it would be better to implement the method that textbooks describe? If only so as not to confuse those who may have some background knowledge.

@sgenoud
Author

sgenoud commented Jul 23, 2012

@glouppe It is probably a question of taste in the end. Is there a scikit-learn policy for this kind of choice?

I would propose to make the point in the documentation that alpha and n_leaves are equivalent, but to keep a simpler interface for the function, based only on the number of leaves. Users familiar with the method will quickly understand that there is not much difference, and would at most be confused until they read the docs.

Finally, one could also argue that it would make pre- and post-pruning more similar.

@amueller
Member

I would favor n_leaves as it is much more intuitive. Also it is clear that the smallest change that has an effect is 1, and as @sgenoud said, it is more interpretable in the context of the other parameters.

@glouppe
Contributor

glouppe commented Jul 23, 2012

Okay then, I am fine with that.

@sgenoud
Author

sgenoud commented Jul 24, 2012

I have tried to integrate your feedback; there are still two things to do:

  • edit the documentation
  • merge it with the master (the Tree object has been moved to a cython file I think)

@sgenoud
Author

sgenoud commented Jul 25, 2012

@glouppe would you mind merging my modifications of the Tree object with your refactoring? I am no Cython expert and you are the one who refactored the object (and therefore know it well).

@glouppe
Contributor

glouppe commented Jul 26, 2012

@sgenoud I can do it, but it'll unfortunately have to wait until mid-August. I am soon leaving for holidays and have a few remaining things to handle before then.

@amueller
Member

@sgenoud maybe you can work on the doc in the meantime?

@amueller
Member

Anything new?

@mrjbq7
Contributor

mrjbq7 commented Nov 5, 2012

Need help with this? @sgenoud, what is the status of the patch?

@sgenoud
Author

sgenoud commented Nov 11, 2012

Hi guys, sorry, I have started a new job that takes up a lot of my time. I should normally have more time for this in December. I'll keep you posted.

@sgenoud sgenoud closed this Nov 11, 2012
@sgenoud sgenoud reopened this Nov 11, 2012
@erg
Contributor

erg commented Apr 3, 2013

This patch is basically completely busted now, because it works on the old tree from before things were ported to Cython trees.

@aflaxman
Contributor

aflaxman commented May 2, 2013

I would like to do some pruning, so I've made an attempt to update the code in this pull request. I'm not sure if it is the best way to do things now, but at least it brings us close to where we were before.

The updated code is here: https://github.com/aflaxman/scikit-learn/tree/tree_pruning

Is there still interest in this?

@amueller
Member

amueller commented May 7, 2013

Yes, I think there is interest as long as it doesn't add overhead that can't be avoided for the forests.
I think having a pruning option for single trees would be great.
Maybe you should submit your own pull request with a reference to this one?

@jessrosenfield

This seems to have gone stale. As a user of sklearn, I'd love to see pruning for decision trees. Are there any updates?

@thomas4g

thomas4g commented Feb 5, 2016

Echoing @jessrosenfield on this (I'm guessing she's working on the same machine learning project I am...) - any updates?

@jakevdp
Member

jakevdp commented Feb 5, 2016

The tree and forest code has been changed considerably since this work was started; any new effort toward pruning would essentially have to start from scratch. Perhaps we should close this PR?

@amueller
Member

amueller commented Oct 7, 2016

Closing in favor of #6557

@amueller amueller closed this Oct 7, 2016