DOC improve plot_grid_search_refit_callable.py and add links #30990
Conversation
These are the kinds of plots which I really think we should have much easier ways to do, either as Displays in sklearn, or in skore, not sure. I've changed the plots a bit, and I think they're better/more informative of what's happening. cc @glemaitre
Thanks for giving this example some love. Here is a first pass of suggestions related to the use of stratification when estimating the standard deviation of cross-validation score. I plan to do a second pass tomorrow.
```python
# We use GridSearchCV with our custom `best_low_complexity` function as the refit
# parameter. This function will select the model with the fewest PCA components that
# still performs within one standard deviation of the best model.

grid = GridSearchCV(
    pipe,
    cv=10,
```
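As a standalone sketch of what such a one-standard-deviation refit callable might look like (the real example's implementation may differ; the toy `cv_results` dict below is made up for illustration):

```python
import numpy as np


def best_low_complexity(cv_results):
    """Index of the model with the fewest PCA components whose mean test
    score is within one standard deviation of the best mean test score."""
    mean = np.asarray(cv_results["mean_test_score"])
    std = np.asarray(cv_results["std_test_score"])
    best = np.argmax(mean)
    threshold = mean[best] - std[best]
    # All candidates whose mean score clears the one-std threshold.
    candidates = np.flatnonzero(mean >= threshold)
    n_components = np.asarray(
        cv_results["param_reduce_dim__n_components"], dtype=int
    )
    # Among the candidates, keep the simplest model.
    return candidates[np.argmin(n_components[candidates])]


# Toy cv_results: the 8-component model is within one std of the best.
toy = {
    "mean_test_score": [0.80, 0.90, 0.91],
    "std_test_score": [0.02, 0.02, 0.02],
    "param_reduce_dim__n_components": [4, 8, 16],
}
print(best_low_complexity(toy))  # picks index 1 (8 components)
```

Such a callable is then passed as `refit=best_low_complexity` to `GridSearchCV`, which refits the estimator at the returned index.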
Let's recommend users to use non-stratified CV for such a use case. Using stratification can make the standard deviation of the validation scores degenerate on imbalanced data. Here the dataset is balanced, so stratification should have no impact. However, since this example might be copy-pasted to be reused on imbalanced data, I think it's safer to advise a less brittle way to estimate epistemic uncertainty.
Suggested change:

```diff
-    cv=10,
+    # Use a non-stratified CV strategy to make sure that the inter-fold
+    # standard deviation of the test scores is informative.
+    cv=ShuffleSplit(n_splits=10, random_state=0),
```
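To illustrate the point about stratification, here is a standalone sketch (the dataset parameters are made up, not taken from the example): on imbalanced data, stratified folds all share the same class ratio, which can shrink the fold-to-fold spread of scores and understate uncertainty.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import ShuffleSplit, StratifiedKFold, cross_val_score

# Imbalanced toy problem (hypothetical parameters, for illustration only).
X, y = make_classification(n_samples=600, weights=[0.95, 0.05], random_state=0)
clf = LogisticRegression(max_iter=1000)

strat_scores = cross_val_score(clf, X, y, cv=StratifiedKFold(n_splits=10))
shuffle_scores = cross_val_score(clf, X, y, cv=ShuffleSplit(n_splits=10, random_state=0))

# With stratification every fold sees the same class ratio, so the
# inter-fold standard deviation can collapse on imbalanced data.
print(f"stratified std: {strat_scores.std():.4f}")
print(f"shuffled   std: {shuffle_scores.std():.4f}")
```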
BTW, using more iterations yields smoother curves that look better and should also lead to a more stable selection of the best number of PCA components: `cv=ShuffleSplit(n_splits=30, test_size=0.1, random_state=42)`. But it makes the example run a bit slower.
yeah 30 is too slow for the CI I'd say.
```python
# We create a pipeline with two steps:
# 1. Dimensionality reduction using PCA
# 2. Classification using LinearSVC
```
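A minimal sketch of such a two-step pipeline (the step names and the use of `load_digits` are my assumptions, not necessarily the example's exact code):

```python
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC

X, y = load_digits(return_X_y=True)

# Hypothetical step names; grid-search parameters would then be addressed
# as e.g. "reduce_dim__n_components".
pipe = Pipeline(
    [
        ("reduce_dim", PCA(random_state=42)),
        ("classify", LinearSVC(dual=False, random_state=42)),
    ]
)
pipe.fit(X, y)
print(f"training accuracy: {pipe.score(X, y):.3f}")
```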
I think I would always recommend using `LogisticRegression` over linear SVC nowadays. Those models have similar ROC-AUC capabilities, but only LR can output interpretable confidence scores with `predict_proba` (and be evaluated with a proper scoring rule such as Brier score or log loss). Furthermore, liblinear's `sample_weight` support seems to be broken in subtle ways that might be difficult to fix, so I would rather stop implicitly recommending this model in our examples.
I tried the example with `LogisticRegression` and the results are very similar, but it's required to pass a larger `max_iter` value to avoid warnings (e.g. `max_iter=1000`).
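A sketch of the suggested swap (step names and `n_components` are hypothetical), using `max_iter=1000` to avoid convergence warnings and `predict_proba` with log loss for evaluation:

```python
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline

X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# max_iter=1000 avoids lbfgs convergence warnings, as noted above.
pipe = Pipeline(
    [
        ("reduce_dim", PCA(n_components=32, random_state=42)),
        ("classify", LogisticRegression(max_iter=1000)),
    ]
)
pipe.fit(X_train, y_train)

# Unlike LinearSVC, LogisticRegression exposes predict_proba, so the model
# can be scored with a proper scoring rule such as log loss.
proba = pipe.predict_proba(X_test)
print(f"log loss: {log_loss(y_test, proba):.3f}")
```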
At the end of the day, I'm under the impression that we are doing a validation curve here. Right now, the …
@glemaitre That's a very different plot though, here we're simply plotting the measured metric(s). I do agree that …
LGTM. Just wondering if one could make more use of `results_df` and shorten the code.
Since the grid search results can include combinations of more than one hyperparameter at once, I am not sure how that would work. I agree, let's keep this discussion for a follow-up issue to avoid side-tracking the review of this example.
Some more feedback. LGTM overall.
```python
# selection of the "best" model is desired.

# Adjust layout and display the figure
plt.tight_layout()
plt.show()
```
I think we don't need that last cell now that the example has been converted to a notebook-style example. Also, passing `constrained_layout=True` to the `plt.subplots` call above is likely a better solution to fix overlapping label and axis issues in general.
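A standalone sketch of that suggestion (the data and panel contents are made up, just to show the layout mechanism):

```python
import io

import matplotlib

matplotlib.use("Agg")  # headless backend so this sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

# constrained_layout=True asks matplotlib to manage spacing itself,
# which makes a trailing plt.tight_layout() call unnecessary.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4), constrained_layout=True)

n_components = np.array([1, 2, 4, 8, 16, 32, 64])
scores = np.log(n_components) / np.log(64)  # made-up scores, illustration only
ax1.plot(n_components, scores, marker="o")
ax1.set_xlabel("n_components")
ax1.set_ylabel("mean test score")

ax2.bar(["best", "1-std rule"], [0.92, 0.91])  # made-up values
ax2.set_ylabel("accuracy")

buf = io.BytesIO()
fig.savefig(buf, format="png")
```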
Need to keep the `plt.show()` to actually show the plot when running the example.
I think with 'new' sphinx-gallery (>0.5.0) you don't need it for the plot to show, but it can be useful to avoid the text output. You could also use `_ = plt.tight_layout()` to avoid the text output.
When running locally, that doesn't show the plot.
Also, I would not be opposed to collapsing the 2 subplots into one that displays everything at once but using a bigger figure size.
Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Are we happy with the new plot?
Merging since there are no more unresolved comments.
Towards #30621
This adds links to the example, as well as improving the example itself.
I wonder if the plots can be done more easily, either with polars or matplotlib, or other libs. I'm not really a plotting person. Maybe @lucyleeow or @MarcoGorelli would have an idea (this uses `polars`). cc @StefanieSenger
It also makes the docstrings for `refit` more consistent with one another (which is an interesting case for comparing docstring efforts @lucyleeow).