diff --git a/examples/model_selection/plot_cv_indices.py b/examples/model_selection/plot_cv_indices.py index e6c3580c787f0..d456546891069 100644 --- a/examples/model_selection/plot_cv_indices.py +++ b/examples/model_selection/plot_cv_indices.py @@ -99,9 +99,10 @@ def visualize_groups(classes, groups, name): def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10): """Create a sample plot for indices of a cross-validation object.""" - + use_groups = "Group" in type(cv).__name__ + groups = group if use_groups else None # Generate the training/testing visualizations for each CV split - for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)): + for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=groups)): # Fill in indices with the training/test groups indices = np.array([np.nan] * len(X)) indices[tt] = 1