Balance model complexity and cross-validated score
==================================================

- This example balances model complexity and cross-validated score by
- finding a decent accuracy within 1 standard deviation of the best accuracy
- score while minimising the number of PCA components [1].
+ This example demonstrates how to balance model complexity and cross-validated score by
+ finding a decent accuracy within 1 standard deviation of the best accuracy score while
+ minimising the number of :class:`~sklearn.decomposition.PCA` components [1]. It uses
+ :class:`~sklearn.model_selection.GridSearchCV` with a custom refit callable to select
+ the optimal model.

The figure shows the trade-off between cross-validated score and the number
- of PCA components. The balanced case is when n_components=10 and accuracy=0.88,
+ of PCA components. The balanced case is when `n_components=10` and `accuracy=0.88`,
which falls into the range within 1 standard deviation of the best accuracy
score.

[1] Hastie, T., Tibshirani, R., & Friedman, J. (2001). Model Assessment and
Selection. The Elements of Statistical Learning (pp. 219-260). New York,
NY, USA: Springer New York Inc.
-
"""

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import matplotlib.pyplot as plt
import numpy as np
+ import polars as pl

from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
- from sklearn.model_selection import GridSearchCV
+ from sklearn.linear_model import LogisticRegression
+ from sklearn.model_selection import GridSearchCV, ShuffleSplit
from sklearn.pipeline import Pipeline
- from sklearn.svm import LinearSVC
+
+ # %%
+ # Introduction
+ # ------------
+ #
+ # When tuning hyperparameters, we often want to balance model complexity and
+ # performance. The "one-standard-error" rule is a common approach: select the simplest
+ # model whose performance is within one standard error of the best model's performance.
+ # This helps to avoid overfitting by preferring simpler models when their performance is
+ # statistically comparable to more complex ones.
+
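
As an aside, the rule itself is only a few lines of code. A minimal sketch on made-up numbers (the scores and complexities below are purely illustrative, not from this example):

import numpy as np

# Hypothetical mean CV scores for models of increasing complexity.
mean_scores = np.array([0.82, 0.88, 0.90, 0.905, 0.91])
std_scores = np.array([0.020, 0.015, 0.012, 0.011, 0.012])
complexity = np.array([2, 4, 8, 16, 32])  # e.g. number of components

best = np.argmax(mean_scores)
threshold = mean_scores[best] - std_scores[best]  # best score - 1 std

# Simplest model whose mean score still clears the threshold.
candidates = np.flatnonzero(mean_scores >= threshold)
chosen = candidates[np.argmin(complexity[candidates])]
print(complexity[chosen])  # 8: much simpler than 32, within 1 std of the best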
+ # %%
+ # Helper functions
+ # ----------------
+ #
+ # We define two helper functions:
+ # 1. `lower_bound`: calculates the threshold for acceptable performance
+ #    (best score - 1 std)
+ # 2. `best_low_complexity`: selects the model with the fewest PCA components whose
+ #    score exceeds this threshold


def lower_bound(cv_results):
@@ -79,49 +101,280 @@ def best_low_complexity(cv_results):
    return best_idx

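The helper bodies are elided by the hunk above. A sketch consistent with their descriptions and with the `cv_results_` keys used later in the diff (the file's actual implementation may differ in detail):

def lower_bound(cv_results):
    """Best mean test score minus one standard deviation of that score."""
    best_idx = np.argmax(cv_results["mean_test_score"])
    return (
        cv_results["mean_test_score"][best_idx]
        - cv_results["std_test_score"][best_idx]
    )


def best_low_complexity(cv_results):
    """Index of the candidate with the fewest PCA components above the bound."""
    threshold = lower_bound(cv_results)
    candidates = np.flatnonzero(cv_results["mean_test_score"] >= threshold)
    n_comp = np.asarray(cv_results["param_reduce_dim__n_components"])
    return candidates[np.argmin(n_comp[candidates])]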
+ # %%
+ # Set up the pipeline and parameter grid
+ # --------------------------------------
+ #
+ # We create a pipeline with two steps:
+ # 1. Dimensionality reduction using PCA
+ # 2. Classification using LogisticRegression
+ #
+ # We'll search over different numbers of PCA components to find the optimal complexity.
+
pipe = Pipeline(
    [
        ("reduce_dim", PCA(random_state=42)),
-         ("classify", LinearSVC(random_state=42, C=0.01)),
+         ("classify", LogisticRegression(random_state=42, C=0.01, max_iter=1000)),
    ]
)

- param_grid = {"reduce_dim__n_components": [6, 8, 10, 12, 14]}
+ param_grid = {"reduce_dim__n_components": [6, 8, 10, 15, 20, 25, 35, 45, 55]}
+
+ # %%
+ # Perform the search with GridSearchCV
+ # ------------------------------------
+ #
+ # We use `GridSearchCV` with our custom `best_low_complexity` function as the refit
+ # parameter. This function will select the model with the fewest PCA components that
+ # still performs within one standard deviation of the best model.

grid = GridSearchCV(
    pipe,
-     cv=10,
-     n_jobs=1,
+     # Use a non-stratified CV strategy to make sure that the inter-fold
+     # standard deviation of the test scores is informative.
+     cv=ShuffleSplit(n_splits=30, random_state=0),
+     n_jobs=1,  # increase this on your machine to use more physical cores
    param_grid=param_grid,
    scoring="accuracy",
    refit=best_low_complexity,
+     return_train_score=True,
)
+
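A note on the `refit` argument: scikit-learn accepts a callable here that receives the `cv_results_` dict of arrays and must return an integer index into those arrays; the corresponding candidate is then refit on the whole dataset and exposed as `best_estimator_`. For instance, this hypothetical callable would reproduce the default `refit=True` behaviour:

def refit_best_score(cv_results):
    # Same choice as refit=True: the candidate with the highest mean test score.
    return int(np.argmax(cv_results["mean_test_score"]))

# GridSearchCV(pipe, param_grid=param_grid, refit=refit_best_score, ...)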
+ # %%
+ # Load the digits dataset and fit the model
+ # -----------------------------------------
+
X, y = load_digits(return_X_y=True)
grid.fit(X, y)

+ # %%
+ # Visualize the results
+ # ---------------------
+ #
+ # We'll plot the train and test scores for different numbers of PCA
+ # components, along with horizontal lines indicating the best score and the
+ # one-standard-deviation threshold.
+
n_components = grid.cv_results_["param_reduce_dim__n_components"]
test_scores = grid.cv_results_["mean_test_score"]

- plt.figure()
- plt.bar(n_components, test_scores, width=1.3, color="b")
+ # Create a polars DataFrame for better data manipulation and visualization
+ results_df = pl.DataFrame(
+     {
+         "n_components": n_components,
+         "mean_test_score": test_scores,
+         "std_test_score": grid.cv_results_["std_test_score"],
+         "mean_train_score": grid.cv_results_["mean_train_score"],
+         "std_train_score": grid.cv_results_["std_train_score"],
+         "mean_fit_time": grid.cv_results_["mean_fit_time"],
+         "rank_test_score": grid.cv_results_["rank_test_score"],
+     }
+ )

- lower = lower_bound(grid.cv_results_)
- plt.axhline(np.max(test_scores), linestyle="--", color="y", label="Best score")
- plt.axhline(lower, linestyle="--", color=".5", label="Best score - 1 std")
+ # Sort by number of components
+ results_df = results_df.sort("n_components")
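
One payoff of the polars detour: ad-hoc inspections become one-liners. For example, using only the columns defined above, the top three candidates by mean test score:

print(results_df.sort("mean_test_score", descending=True).head(3))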

- plt.title("Balance model complexity and cross-validated score")
- plt.xlabel("Number of PCA components used")
- plt.ylabel("Digit classification accuracy")
- plt.xticks(n_components.tolist())
- plt.ylim((0, 1.0))
- plt.legend(loc="upper left")
+ # Calculate the lower bound threshold
+ lower = lower_bound(grid.cv_results_)

+ # Get the best model information
best_index_ = grid.best_index_
+ best_components = n_components[best_index_]
+ best_score = grid.cv_results_["mean_test_score"][best_index_]
+
+ # Add a column to mark the selected model
+ results_df = results_df.with_columns(
+     pl.when(pl.col("n_components") == best_components)
+     .then(pl.lit("Selected"))
+     .otherwise(pl.lit("Regular"))
+     .alias("model_type")
+ )
+
+ # Get the number of CV splits from the results
+ n_splits = sum(
+     1
+     for key in grid.cv_results_.keys()
+     if key.startswith("split") and key.endswith("test_score")
+ )
+
+ # Extract individual scores for each split
+ test_scores = np.array(
+     [
+         [grid.cv_results_[f"split{i}_test_score"][j] for i in range(n_splits)]
+         for j in range(len(n_components))
+     ]
+ )
+ train_scores = np.array(
+     [
+         [grid.cv_results_[f"split{i}_train_score"][j] for i in range(n_splits)]
+         for j in range(len(n_components))
+     ]
+ )
+
+ # Calculate mean and std of test scores
+ mean_test_scores = np.mean(test_scores, axis=1)
+ std_test_scores = np.std(test_scores, axis=1)
+
+ # Find best score and threshold
+ best_mean_score = np.max(mean_test_scores)
+ threshold = best_mean_score - std_test_scores[np.argmax(mean_test_scores)]
+
+ # Create a single figure for visualization
+ fig, ax = plt.subplots(figsize=(12, 8))

- print("The best_index_ is %d" % best_index_)
- print("The n_components selected is %d" % n_components[best_index_])
- print(
-     "The corresponding accuracy score is %.2f"
-     % grid.cv_results_["mean_test_score"][best_index_]
+ # Plot individual points
+ for i, comp in enumerate(n_components):
+     # Plot individual test points
+     plt.scatter(
+         [comp] * n_splits,
+         test_scores[i],
+         alpha=0.2,
+         color="blue",
+         s=20,
+         label="Individual test scores" if i == 0 else "",
+     )
+     # Plot individual train points
+     plt.scatter(
+         [comp] * n_splits,
+         train_scores[i],
+         alpha=0.2,
+         color="green",
+         s=20,
+         label="Individual train scores" if i == 0 else "",
+     )
+
+ # Plot mean lines with error bands
+ plt.plot(
+     n_components,
+     np.mean(test_scores, axis=1),
+     "-",
+     color="blue",
+     linewidth=2,
+     label="Mean test score",
+ )
+ plt.fill_between(
+     n_components,
+     np.mean(test_scores, axis=1) - np.std(test_scores, axis=1),
+     np.mean(test_scores, axis=1) + np.std(test_scores, axis=1),
+     alpha=0.15,
+     color="blue",
+ )
+
+ plt.plot(
+     n_components,
+     np.mean(train_scores, axis=1),
+     "-",
+     color="green",
+     linewidth=2,
+     label="Mean train score",
+ )
+ plt.fill_between(
+     n_components,
+     np.mean(train_scores, axis=1) - np.std(train_scores, axis=1),
+     np.mean(train_scores, axis=1) + np.std(train_scores, axis=1),
+     alpha=0.15,
+     color="green",
)
+
+ # Add threshold lines
+ plt.axhline(
+     best_mean_score,
+     color="#9b59b6",  # purple
+     linestyle="--",
+     label="Best score",
+     linewidth=2,
+ )
+ plt.axhline(
+     threshold,
+     color="#e67e22",  # orange
+     linestyle="--",
+     label="Best score - 1 std",
+     linewidth=2,
+ )
+
+ # Highlight selected model
+ plt.axvline(
+     best_components,
+     color="#9b59b6",  # purple
+     alpha=0.2,
+     linewidth=8,
+     label="Selected model",
+ )
+
+ # Set titles and labels
+ plt.xlabel("Number of PCA components", fontsize=12)
+ plt.ylabel("Score", fontsize=12)
+ plt.title("Model Selection: Balancing Complexity and Performance", fontsize=14)
+ plt.grid(True, linestyle="--", alpha=0.7)
+ plt.legend(
+     bbox_to_anchor=(1.02, 1),
+     loc="upper left",
+     borderaxespad=0,
+ )
+
+ # Set axis properties
+ plt.xticks(n_components)
+ plt.ylim((0.85, 1.0))
+
+ # Adjust layout
+ plt.tight_layout()
+
+ # %%
+ # Print the results
+ # -----------------
+ #
+ # We print information about the selected model, including its complexity and
+ # performance. We also show a summary table of all models using polars.
+
+ print("Best model selected by the one-standard-error rule:")
+ print(f"Number of PCA components: {best_components}")
+ print(f"Accuracy score: {best_score:.4f}")
+ print(f"Best possible accuracy: {best_mean_score:.4f}")
+ print(f"Accuracy threshold (best - 1 std): {lower:.4f}")
+
+ # Create a summary table with polars
+ summary_df = results_df.select(
+     pl.col("n_components"),
+     pl.col("mean_test_score").round(4).alias("test_score"),
+     pl.col("std_test_score").round(4).alias("test_std"),
+     pl.col("mean_train_score").round(4).alias("train_score"),
+     pl.col("std_train_score").round(4).alias("train_std"),
+     pl.col("mean_fit_time").round(3).alias("fit_time"),
+     pl.col("rank_test_score").alias("rank"),
+ )
+
+ # Add a column to mark the selected model
+ summary_df = summary_df.with_columns(
+     pl.when(pl.col("n_components") == best_components)
+     .then(pl.lit("*"))
+     .otherwise(pl.lit(""))
+     .alias("selected")
+ )
+
+ print("\nModel comparison table:")
+ print(summary_df)
+
+ # %%
+ # Conclusion
+ # ----------
+ #
+ # The one-standard-error rule helps us select a simpler model (fewer PCA components)
+ # while maintaining performance statistically comparable to that of the best model.
+ # This approach can help prevent overfitting and improve model interpretability
+ # and efficiency.
+ #
+ # In this example, we've seen how to implement this rule using a custom refit
+ # callable with :class:`~sklearn.model_selection.GridSearchCV`.
+ #
+ # Key takeaways:
+ # 1. The one-standard-error rule provides a good rule of thumb for selecting simpler
+ #    models
+ # 2. Custom refit callables in :class:`~sklearn.model_selection.GridSearchCV` allow for
+ #    flexible model selection strategies, as sketched below
+ # 3. Visualizing both train and test scores helps identify potential overfitting
+ #
+ # This approach can be applied to other model selection scenarios where balancing
+ # complexity and performance is important, or in cases where a use-case specific
+ # selection of the "best" model is desired.
+
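To illustrate takeaway 2, here is a hypothetical alternative refit strategy (not part of this example): pick the fastest-fitting candidate whose score clears the same one-standard-deviation threshold.

def fastest_within_one_std(cv_results):
    # Candidates whose mean test score clears the threshold ...
    threshold = lower_bound(cv_results)
    candidates = np.flatnonzero(cv_results["mean_test_score"] >= threshold)
    # ... and, among those, the one that is cheapest to fit.
    fit_times = np.asarray(cv_results["mean_fit_time"])
    return int(candidates[np.argmin(fit_times[candidates])])

# Usage: GridSearchCV(pipe, param_grid=param_grid, refit=fastest_within_one_std, ...)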
+ # Display the figure
plt.show()