DOC improve plot_grid_search_refit_callable.py and add links #30990


Merged
merged 19 commits into from
May 20, 2025

Conversation

adrinjalali (Member):

Towards #30621

This adds links to the example and improves the example itself.

I wonder if the plots could be produced more easily, with polars, matplotlib, or other libraries. I'm not really a plotting person.

Maybe @lucyleeow or @MarcoGorelli would have an idea (this uses polars).

cc @StefanieSenger

It also makes the docstrings for refit more consistent with one another (which is an interesting case for the docstring-comparison efforts, @lucyleeow).

github-actions bot commented Mar 13, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 8193ce6.

@adrinjalali adrinjalali marked this pull request as ready for review March 13, 2025 15:06
Review thread on sklearn/model_selection/_search_successive_halving.py (outdated, resolved).
@adrinjalali (Member, Author) commented Mar 14, 2025:

These are the kinds of plots which I really think we should have much easier ways to do, either as Displays in sklearn, or in skore, not sure.

I've changed the plots a bit, and I think they're better and more informative about what's happening.

[screenshot: updated plots]

cc @glemaitre

@ogrisel (Member) left a comment:

Thanks for giving this example some love. Here is a first pass of suggestions related to the use of stratification when estimating the standard deviation of cross-validation score. I plan to do a second pass tomorrow.

#
# We use GridSearchCV with our custom `best_low_complexity` function as the refit
# parameter. This function will select the model with the fewest PCA components that
# still performs within one standard deviation of the best model.

grid = GridSearchCV(
pipe,
cv=10,
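The best_low_complexity refit callable described in the hunk above could be sketched roughly as follows (a minimal sketch, not the PR's actual code; the cv_results_ key param_reduce_dim__n_components assumes a pipeline step named reduce_dim, which may differ from the example):

```python
import numpy as np


def best_low_complexity(cv_results):
    """One-standard-deviation rule over GridSearchCV's cv_results_ dict.

    Returns the index of the candidate with the fewest PCA components whose
    mean test score is within one standard deviation of the best mean score.
    """
    mean = np.asarray(cv_results["mean_test_score"])
    std = np.asarray(cv_results["std_test_score"])
    # Threshold: best mean score minus the std observed at that best candidate.
    threshold = mean.max() - std[mean.argmax()]
    candidates = np.flatnonzero(mean >= threshold)
    # Hypothetical parameter key; depends on the pipeline step name.
    n_components = np.asarray(cv_results["param_reduce_dim__n_components"])
    # Among the candidates within the threshold, pick the simplest model.
    return candidates[n_components[candidates].argmin()]
```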
@ogrisel (Member) commented Mar 17, 2025:

Let's recommend that users use non-stratified CV for this use case. Using stratification can make the standard deviation of the validation scores degenerate on imbalanced data. Here the dataset is balanced, so stratification should have no impact. However, since this example might be copy-pasted and reused on imbalanced data, I think it's safer to advise a less brittle way to estimate epistemic uncertainty.

Suggested change:
-    cv=10,
+    # Use a non-stratified CV strategy to make sure that the inter-fold
+    # standard deviation of the test scores is informative.
+    cv=ShuffleSplit(n_splits=10, random_state=0),

Member:

BTW, using more iterations yields smoother curves that look better and should also lead to a more stable selection of the best number of PCA components:

cv=ShuffleSplit(n_splits=30, test_size=0.1, random_state=42)

but it makes the example run a bit slower.

Member Author:

yeah 30 is too slow for the CI I'd say.

#
# We create a pipeline with two steps:
# 1. Dimensionality reduction using PCA
# 2. Classification using LinearSVC
Member:

I think I would always recommend LogisticRegression over LinearSVC nowadays. The models have similar ROC-AUC capabilities, but only LogisticRegression can output interpretable confidence scores via predict_proba (which can be evaluated with a proper scoring rule such as the Brier score or log loss).

Furthermore, Liblinear sample_weight support seems to be broken in subtle ways that might be difficult to fix, so I would rather stop implicitly recommending this model in our examples.

Member:

I tried the example with LogisticRegression and the results are very similar, but a larger max_iter value (e.g. max_iter=1000) is required to avoid convergence warnings.
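A minimal sketch of what such a swap could look like (illustrative only: the dataset, step names, and parameter grid below are assumptions, not the PR's actual diff):

```python
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, ShuffleSplit
from sklearn.pipeline import Pipeline

X, y = load_iris(return_X_y=True)

# LogisticRegression in place of LinearSVC; the larger max_iter avoids
# convergence warnings, as noted above.
pipe = Pipeline(
    [
        ("reduce_dim", PCA(random_state=0)),
        ("classify", LogisticRegression(max_iter=1000, random_state=0)),
    ]
)

grid = GridSearchCV(
    pipe,
    param_grid={"reduce_dim__n_components": [1, 2, 3]},
    # Non-stratified CV, per the suggestion earlier in this thread.
    cv=ShuffleSplit(n_splits=5, random_state=0),
)
grid.fit(X, y)
```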

@glemaitre (Member):

> These are the kinds of plots which I really think we should have much easier ways to do, either as Displays in sklearn, or in skore, not sure.

At the end of the day, I'm under the impression that we are doing a validation curve here. Right now, ValidationCurveDisplay.from_estimator takes an estimator and performs the cross-validation itself. I'm under the impression that a ValidationCurveDisplay.from_cv_results should take the search_cv.cv_results_ and create this plot.

@adrinjalali (Member, Author):

@glemaitre That's a very different plot though, here we're simply plotting the measured metric(s). I do agree that ValidationCurveDisplay.from_cv_results should exist, but that's a separate issue I'd say.

@lorentzenchr (Member) left a comment:

LGTM
Just wondering if one could make more use of results_df and shorten the code.

@ogrisel (Member) commented Mar 18, 2025:

> At the end of the day, I'm under the impression that we are doing a validation curve here. Right now, the ValidationCurveDisplay.from_estimator takes an estimator and perform the cross-validation. I'm under the impression that ValidationCurveDisplay.from_cv_results should take the search_cv.cv_results_ and create this plot.

Since the grid search results can include combinations of more than one hyperparameter at once, I am not sure how that would work. I agree, let's keep this discussion for a follow-up issue to avoid side-tracking the review of this example.

@ogrisel (Member) left a comment:

Some more feedback. LGTM overall.

# selection of the "best" model is desired.

# Adjust layout and display the figure
plt.tight_layout()
plt.show()
Member:

I think we don't need that last cell now that the example has been converted to a notebook-style example.

Also, passing constrained_layout=True to the plt.subplots call above is likely a better solution for overlapping label and axis issues in general.
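A sketch of that suggestion (assuming a two-panel layout like the example's; the axis labels here are illustrative):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch

import matplotlib.pyplot as plt

# constrained_layout is applied when the figure is created, so a trailing
# plt.tight_layout() call is no longer needed to fix overlapping labels.
fig, axes = plt.subplots(1, 2, figsize=(10, 4), constrained_layout=True)
for ax in axes:
    ax.set_xlabel("n_components")
    ax.set_ylabel("mean test accuracy")
plt.show()
```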

Member Author:

Need to keep the plt.show() to actually show the plot when running the example

Member:

I think with 'new' sphinx-gallery (>0.5.0) you don't need it for the plot to show, but it can be useful to avoid the text output. You could also use _ = plt.tight_layout() for that.

ref: https://sphinx-gallery.github.io/stable/faq.html#why-am-i-getting-text-output-for-matplotlib-functions

Member Author:

When running locally, that doesn't show the plot.


@ogrisel (Member) commented Mar 18, 2025:

Also, I would not be opposed to collapsing the 2 subplots into one that displays everything at once, using a bigger figure size.

adrinjalali and others added 5 commits March 18, 2025 16:52
Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
@adrinjalali (Member, Author) left a comment:

New plot:

[screenshot: new plot]

@adrinjalali (Member, Author):

Are we happy with the new plot?

@ogrisel (Member) commented Mar 24, 2025:

The new plot looks good, but there is an HTML rendering problem at the end of the example:

[screenshot: HTML rendering issue]

Otherwise, LGTM. Thanks.

@adrinjalali (Member, Author) left a comment:

Merging since there are no more unresolved comments.


@adrinjalali adrinjalali merged commit 18cdea7 into scikit-learn:main May 20, 2025
36 checks passed
@adrinjalali adrinjalali deleted the adrin-11 branch May 20, 2025 11:44