Commit 18cdea7

adrinjalali, lorentzenchr, and ogrisel authored

DOC improve plot_grid_search_refit_callable.py and add links (#30990)

Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>

1 parent 9b40cbc commit 18cdea7

File tree: 5 files changed, +322 −35 lines changed

doc/whats_new/v0.20.rst

+1 −1 (1 addition, 1 deletion)
@@ -445,7 +445,7 @@ Miscellaneous

 - |API| Removed all mentions of ``sklearn.externals.joblib``, and deprecated
   joblib methods exposed in ``sklearn.utils``, except for
-  :func:`utils.parallel_backend` and :func:`utils.register_parallel_backend`,
+  `utils.parallel_backend` and `utils.register_parallel_backend`,
   which allow users to configure parallel computation in scikit-learn.
   Other functionalities are part of `joblib <https://joblib.readthedocs.io/>`_.
   package and should be used directly, by installing it.

doc/whats_new/v1.5.rst

+1 −1 (1 addition, 1 deletion)
@@ -656,7 +656,7 @@ Changelog

 - |API| :func:`utils.tosequence` is deprecated and will be removed in version 1.7.
   :pr:`28763` by :user:`Jérémie du Boisberranger <jeremiedbb>`.

-- |API| :class:`utils.parallel_backend` and :func:`utils.register_parallel_backend` are
+- |API| `utils.parallel_backend` and `utils.register_parallel_backend` are
   deprecated and will be removed in version 1.7. Use `joblib.parallel_backend` and
   `joblib.register_parallel_backend` instead.
   :pr:`28847` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
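For context, the replacement recommended by the entry above is joblib's own context manager. A minimal sketch (not part of this diff; `search`, `X`, and `y` are hypothetical stand-ins for any scikit-learn estimator and data):

import joblib

# Instead of the deprecated sklearn.utils.parallel_backend alias:
with joblib.parallel_backend("loky", n_jobs=2):
    search.fit(X, y)  # hypothetical estimator; fit runs under the chosen backend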

examples/model_selection/plot_grid_search_refit_callable.py

+280 −27 (280 additions, 27 deletions)
@@ -3,32 +3,54 @@
 Balance model complexity and cross-validated score
 ==================================================

-This example balances model complexity and cross-validated score by
-finding a decent accuracy within 1 standard deviation of the best accuracy
-score while minimising the number of PCA components [1].
+This example demonstrates how to balance model complexity and cross-validated score by
+finding a decent accuracy within 1 standard deviation of the best accuracy score while
+minimising the number of :class:`~sklearn.decomposition.PCA` components [1]. It uses
+:class:`~sklearn.model_selection.GridSearchCV` with a custom refit callable to select
+the optimal model.

 The figure shows the trade-off between cross-validated score and the number
-of PCA components. The balanced case is when n_components=10 and accuracy=0.88,
+of PCA components. The balanced case is when `n_components=10` and `accuracy=0.88`,
 which falls into the range within 1 standard deviation of the best accuracy
 score.

 [1] Hastie, T., Tibshirani, R.,, Friedman, J. (2001). Model Assessment and
 Selection. The Elements of Statistical Learning (pp. 219-260). New York,
 NY, USA: Springer New York Inc..
-
 """

 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause

 import matplotlib.pyplot as plt
 import numpy as np
+import polars as pl

 from sklearn.datasets import load_digits
 from sklearn.decomposition import PCA
-from sklearn.model_selection import GridSearchCV
+from sklearn.linear_model import LogisticRegression
+from sklearn.model_selection import GridSearchCV, ShuffleSplit
 from sklearn.pipeline import Pipeline
-from sklearn.svm import LinearSVC
+
+# %%
+# Introduction
+# ------------
+#
+# When tuning hyperparameters, we often want to balance model complexity and
+# performance. The "one-standard-error" rule is a common approach: select the simplest
+# model whose performance is within one standard error of the best model's performance.
+# This helps to avoid overfitting by preferring simpler models when their performance is
+# statistically comparable to more complex ones.
+
+# %%
+# Helper functions
+# ----------------
+#
+# We define two helper functions:
+# 1. `lower_bound`: Calculates the threshold for acceptable performance
+#    (best score - 1 std)
+# 2. `best_low_complexity`: Selects the model with the fewest PCA components that
+#    exceeds this threshold


 def lower_bound(cv_results):
@@ -79,49 +101,280 @@ def best_low_complexity(cv_results):
     return best_idx


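The bodies of `lower_bound` and `best_low_complexity` are collapsed out of this diff (only the trailing `return best_idx` shows as context). For reference, a minimal sketch consistent with how the example uses them, assuming `cv_results` is the `cv_results_` dict of a fitted search:

import numpy as np


def lower_bound(cv_results):
    # Best mean test score minus that candidate's standard deviation (sketch).
    best_idx = np.argmax(cv_results["mean_test_score"])
    return (
        cv_results["mean_test_score"][best_idx]
        - cv_results["std_test_score"][best_idx]
    )


def best_low_complexity(cv_results):
    # Among candidates scoring above the threshold, pick the one with the
    # fewest PCA components (sketch).
    threshold = lower_bound(cv_results)
    candidate_idx = np.flatnonzero(cv_results["mean_test_score"] >= threshold)
    n_components = np.asarray(
        cv_results["param_reduce_dim__n_components"], dtype=int
    )
    best_idx = candidate_idx[np.argmin(n_components[candidate_idx])]
    return best_idx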
+# %%
+# Set up the pipeline and parameter grid
+# --------------------------------------
+#
+# We create a pipeline with two steps:
+# 1. Dimensionality reduction using PCA
+# 2. Classification using LogisticRegression
+#
+# We'll search over different numbers of PCA components to find the optimal complexity.
+
 pipe = Pipeline(
     [
         ("reduce_dim", PCA(random_state=42)),
-        ("classify", LinearSVC(random_state=42, C=0.01)),
+        ("classify", LogisticRegression(random_state=42, C=0.01, max_iter=1000)),
     ]
 )

-param_grid = {"reduce_dim__n_components": [6, 8, 10, 12, 14]}
+param_grid = {"reduce_dim__n_components": [6, 8, 10, 15, 20, 25, 35, 45, 55]}
+
+# %%
+# Perform the search with GridSearchCV
+# ------------------------------------
+#
+# We use `GridSearchCV` with our custom `best_low_complexity` function as the refit
+# parameter. This function will select the model with the fewest PCA components that
+# still performs within one standard deviation of the best model.

 grid = GridSearchCV(
     pipe,
-    cv=10,
-    n_jobs=1,
+    # Use a non-stratified CV strategy to make sure that the inter-fold
+    # standard deviation of the test scores is informative.
+    cv=ShuffleSplit(n_splits=30, random_state=0),
+    n_jobs=1,  # increase this on your machine to use more physical cores
     param_grid=param_grid,
     scoring="accuracy",
     refit=best_low_complexity,
+    return_train_score=True,
 )
+
+# %%
+# Load the digits dataset and fit the model
+# -----------------------------------------
+
 X, y = load_digits(return_X_y=True)
 grid.fit(X, y)

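An aside on the API (not part of this diff): with a callable `refit`, the fitted search exposes the callable's choice via `best_index_` and `best_params_`, while `best_score_` is left unset. A small usage sketch with the `grid` fitted above:

print(grid.best_index_)   # index into cv_results_ chosen by best_low_complexity
print(grid.best_params_)  # the corresponding parameter setting

# best_score_ is unavailable when refit is a callable; recover it instead via:
print(grid.cv_results_["mean_test_score"][grid.best_index_])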
+# %%
+# Visualize the results
+# ---------------------
+#
+# We'll create a bar chart showing the test scores for different numbers of PCA
+# components, along with horizontal lines indicating the best score and the
+# one-standard-deviation threshold.
+
 n_components = grid.cv_results_["param_reduce_dim__n_components"]
 test_scores = grid.cv_results_["mean_test_score"]

-plt.figure()
-plt.bar(n_components, test_scores, width=1.3, color="b")
+# Create a polars DataFrame for better data manipulation and visualization
+results_df = pl.DataFrame(
+    {
+        "n_components": n_components,
+        "mean_test_score": test_scores,
+        "std_test_score": grid.cv_results_["std_test_score"],
+        "mean_train_score": grid.cv_results_["mean_train_score"],
+        "std_train_score": grid.cv_results_["std_train_score"],
+        "mean_fit_time": grid.cv_results_["mean_fit_time"],
+        "rank_test_score": grid.cv_results_["rank_test_score"],
+    }
+)

-lower = lower_bound(grid.cv_results_)
-plt.axhline(np.max(test_scores), linestyle="--", color="y", label="Best score")
-plt.axhline(lower, linestyle="--", color=".5", label="Best score - 1 std")
+# Sort by number of components
+results_df = results_df.sort("n_components")

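A side note (not part of the commit): since `cv_results_` is a plain dict of arrays, the whole table can also be built in one call; a sketch for readers who prefer pandas:

import pandas as pd

# Each cv_results_ key (params, per-split scores, timings, ranks) becomes a column.
results = pd.DataFrame(grid.cv_results_).sort_values("param_reduce_dim__n_components")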
-plt.title("Balance model complexity and cross-validated score")
-plt.xlabel("Number of PCA components used")
-plt.ylabel("Digit classification accuracy")
-plt.xticks(n_components.tolist())
-plt.ylim((0, 1.0))
-plt.legend(loc="upper left")
+# Calculate the lower bound threshold
+lower = lower_bound(grid.cv_results_)

+# Get the best model information
 best_index_ = grid.best_index_
+best_components = n_components[best_index_]
+best_score = grid.cv_results_["mean_test_score"][best_index_]
+
+# Add a column to mark the selected model
+results_df = results_df.with_columns(
+    pl.when(pl.col("n_components") == best_components)
+    .then(pl.lit("Selected"))
+    .otherwise(pl.lit("Regular"))
+    .alias("model_type")
+)
+
+# Get the number of CV splits from the results
+n_splits = sum(
+    1
+    for key in grid.cv_results_.keys()
+    if key.startswith("split") and key.endswith("test_score")
+)
+
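Counting the `split<i>_test_score` keys works; as an aside (not in the commit), the fitted search also stores the split count directly, and the per-split scores can be stacked without nested loops:

# Equivalent to the key-counting above:
n_splits = grid.n_splits_

# Shape (n_candidates, n_splits), matching the arrays built just below:
test_scores_alt = np.column_stack(
    [grid.cv_results_[f"split{i}_test_score"] for i in range(n_splits)]
)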
+# Extract individual scores for each split
+test_scores = np.array(
+    [
+        [grid.cv_results_[f"split{i}_test_score"][j] for i in range(n_splits)]
+        for j in range(len(n_components))
+    ]
+)
+train_scores = np.array(
+    [
+        [grid.cv_results_[f"split{i}_train_score"][j] for i in range(n_splits)]
+        for j in range(len(n_components))
+    ]
+)
+
+# Calculate mean and std of test scores
+mean_test_scores = np.mean(test_scores, axis=1)
+std_test_scores = np.std(test_scores, axis=1)
+
+# Find best score and threshold
+best_mean_score = np.max(mean_test_scores)
+threshold = best_mean_score - std_test_scores[np.argmax(mean_test_scores)]
+
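An aside: `threshold` re-derives from the stacked per-split scores what `lower_bound` already computes from `cv_results_`, so the two values should agree to floating-point precision:

# Sanity check: both derivations of the one-std threshold coincide.
assert np.isclose(threshold, lower)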
+# Create a single figure for visualization
+fig, ax = plt.subplots(figsize=(12, 8))

-print("The best_index_ is %d" % best_index_)
-print("The n_components selected is %d" % n_components[best_index_])
-print(
-    "The corresponding accuracy score is %.2f"
-    % grid.cv_results_["mean_test_score"][best_index_]
+# Plot individual points
+for i, comp in enumerate(n_components):
+    # Plot individual test points
+    plt.scatter(
+        [comp] * n_splits,
+        test_scores[i],
+        alpha=0.2,
+        color="blue",
+        s=20,
+        label="Individual test scores" if i == 0 else "",
+    )
+    # Plot individual train points
+    plt.scatter(
+        [comp] * n_splits,
+        train_scores[i],
+        alpha=0.2,
+        color="green",
+        s=20,
+        label="Individual train scores" if i == 0 else "",
+    )
+
+# Plot mean lines with error bands
+plt.plot(
+    n_components,
+    np.mean(test_scores, axis=1),
+    "-",
+    color="blue",
+    linewidth=2,
+    label="Mean test score",
+)
+plt.fill_between(
+    n_components,
+    np.mean(test_scores, axis=1) - np.std(test_scores, axis=1),
+    np.mean(test_scores, axis=1) + np.std(test_scores, axis=1),
+    alpha=0.15,
+    color="blue",
+)
+
+plt.plot(
+    n_components,
+    np.mean(train_scores, axis=1),
+    "-",
+    color="green",
+    linewidth=2,
+    label="Mean train score",
+)
+plt.fill_between(
+    n_components,
+    np.mean(train_scores, axis=1) - np.std(train_scores, axis=1),
+    np.mean(train_scores, axis=1) + np.std(train_scores, axis=1),
+    alpha=0.15,
+    color="green",
 )
+
+# Add threshold lines
+plt.axhline(
+    best_mean_score,
+    color="#9b59b6",  # Purple
+    linestyle="--",
+    label="Best score",
+    linewidth=2,
+)
+plt.axhline(
+    threshold,
+    color="#e67e22",  # Orange
+    linestyle="--",
+    label="Best score - 1 std",
+    linewidth=2,
+)
+
+# Highlight selected model
+plt.axvline(
+    best_components,
+    color="#9b59b6",  # Purple
+    alpha=0.2,
+    linewidth=8,
+    label="Selected model",
+)
+
+# Set titles and labels
+plt.xlabel("Number of PCA components", fontsize=12)
+plt.ylabel("Score", fontsize=12)
+plt.title("Model Selection: Balancing Complexity and Performance", fontsize=14)
+plt.grid(True, linestyle="--", alpha=0.7)
+plt.legend(
+    bbox_to_anchor=(1.02, 1),
+    loc="upper left",
+    borderaxespad=0,
+)
+
+# Set axis properties
+plt.xticks(n_components)
+plt.ylim((0.85, 1.0))
+
+# Adjust layout
+plt.tight_layout()
+
+# %%
+# Print the results
+# -----------------
+#
+# We print information about the selected model, including its complexity and
+# performance. We also show a summary table of all models using polars.
+
+print("Best model selected by the one-standard-error rule:")
+print(f"Number of PCA components: {best_components}")
+print(f"Accuracy score: {best_score:.4f}")
+print(f"Best possible accuracy: {np.max(test_scores):.4f}")
+print(f"Accuracy threshold (best - 1 std): {lower:.4f}")
+
+# Create a summary table with polars
+summary_df = results_df.select(
+    pl.col("n_components"),
+    pl.col("mean_test_score").round(4).alias("test_score"),
+    pl.col("std_test_score").round(4).alias("test_std"),
+    pl.col("mean_train_score").round(4).alias("train_score"),
+    pl.col("std_train_score").round(4).alias("train_std"),
+    pl.col("mean_fit_time").round(3).alias("fit_time"),
+    pl.col("rank_test_score").alias("rank"),
+)
+
+# Add a column to mark the selected model
+summary_df = summary_df.with_columns(
+    pl.when(pl.col("n_components") == best_components)
+    .then(pl.lit("*"))
+    .otherwise(pl.lit(""))
+    .alias("selected")
+)
+
+print("\nModel comparison table:")
+print(summary_df)
+
+# %%
+# Conclusion
+# ----------
+#
+# The one-standard-error rule helps us select a simpler model (fewer PCA components)
+# while maintaining performance statistically comparable to the best model.
+# This approach can help prevent overfitting and improve model interpretability
+# and efficiency.
+#
+# In this example, we've seen how to implement this rule using a custom refit
+# callable with :class:`~sklearn.model_selection.GridSearchCV`.
+#
+# Key takeaways:
+# 1. The one-standard-error rule provides a good rule of thumb to select simpler models
+# 2. Custom refit callables in :class:`~sklearn.model_selection.GridSearchCV` allow for
+#    flexible model selection strategies
+# 3. Visualizing both train and test scores helps identify potential overfitting
+#
+# This approach can be applied to other model selection scenarios where balancing
+# complexity and performance is important, or in cases where a use-case specific
+# selection of the "best" model is desired.
+
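A closing aside (not part of the commit): the refit callable is not specific to grid search; it plugs into :class:`~sklearn.model_selection.RandomizedSearchCV` unchanged. A sketch reusing the objects defined above:

from sklearn.model_selection import RandomizedSearchCV

# Same pipeline, CV strategy, and selection rule; candidates are sampled instead.
rand_search = RandomizedSearchCV(
    pipe,
    param_distributions=param_grid,
    n_iter=5,
    cv=ShuffleSplit(n_splits=30, random_state=0),
    scoring="accuracy",
    refit=best_low_complexity,
    random_state=0,
)
# rand_search.fit(X, y) would then select via best_low_complexity as well.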
+# Display the figure
 plt.show()
