Stratified Group KFold implementation #18649
Conversation
Parameters are the same as for StratifiedKFold to ensure similar behavior given n_groups == n_samples
The idea is to ensure similar behavior when groups are trivial (n_groups == n_samples)
Required to produce balanced size folds when the distribution of y is more or less the same
Thank you so much for a thorough review, @jnothman! The issues you've outlined seem to be fixed now.
can you also update https://scikit-learn.org/stable/auto_examples/model_selection/plot_cv_indices.html to share a visual intuition of what's going on?
thx a lot @marrodion
sklearn/model_selection/_split.py
Outdated
The implementation is designed to:

* Mimic the behavior of StratifiedKFold as much as possible for trivial
  groups (e.g. when each group contain only one sample).
Suggested change:
- groups (e.g. when each group contain only one sample).
+ groups (e.g. when each group contains only one sample).
Force-pushed from cf58f75 to 6c8d5da
@agramfort Thank you for your comment, I've updated the visualization; however, the data in the example makes the current algorithm perform identically to
yes I would add something to the example so we can visually explain what it does
I've added it as an example comparing
thx @marrodion !
this isn't rendering right: images are in the wrong sections. https://134931-843222-gh.circle-artifacts.com/0/doc/modules/cross_validation.html
I think the change to the example affected the numbering of images exported and used in the user guide.
The image is actually a very helpful demonstration of how it differs. Thanks both!
Seems to be fixed now: https://134972-843222-gh.circle-artifacts.com/0/doc/modules/cross_validation.html
This is really nice work. I love the tests. Well done.
- This split is suboptimal in a sense that it might produce imbalanced splits
  even if perfect stratification is possible. If you have relatively close
  distribution of classes in each group, using :class:`GroupKFold` is better.
Insert the visualisation here?
Done, thank you.
Wasn't sure if it was needed; not every CV has a visualization on this documentation page.
Perhaps not needed, but helpful. Happy for you to make the docs more consistent in another pr! ;)
# Check that stratified kfold preserves class ratios in individual splits
# Repeat with shuffling turned off and on
n_samples = 1000
X = np.ones(n_samples)
y = np.array([4] * int(0.10 * n_samples) +
             [0] * int(0.89 * n_samples) +
             [1] * int(0.01 * n_samples))
groups = np.arange(len(y))
Suggested change:
- groups = np.arange(len(y))
+ groups = np.arange(len(y))  # ensure perfect stratification with StratifiedGroupKFold
# Check that stratified kfold gives the same indices regardless of labels
n_samples = 100
y = np.array([2] * int(0.10 * n_samples) +
             [0] * int(0.89 * n_samples) +
             [1] * int(0.01 * n_samples))
X = np.ones(len(y))
groups = np.arange(len(y))
Suggested change:
- groups = np.arange(len(y))
+ groups = np.arange(len(y))  # ensure perfect stratification with StratifiedGroupKFold
# Check that KFold returns folds with balanced sizes (only when
# stratification is possible)
# Repeat with shuffling turned off and on
X = np.ones(17)
y = [0] * 3 + [1] * 14
groups = np.arange(len(y))
Suggested change:
- groups = np.arange(len(y))
+ groups = np.arange(len(y))  # ensure perfect stratification with StratifiedGroupKFold
Force-pushed from 879d947 to 42a6b80
Thank you @marrodion @hermidalc! It's nice to finally have a solution for this case!!
Wow, it's great to see my kernel could be helpful, and I'm happy that the method is finally in scikit-learn. Thanks @marrodion and @hermidalc for doing this 🍻
Reference Issues/PRs
Fixes #13621.
What does this implement/fix? Explain your changes.
Implementation of StratifiedGroupKFold based on PR #15239 and a Kaggle kernel.
The implementation considers the distribution of labels within groups, without the restriction that each group contain only one class.
For trivial cases, it mimics the behavior of StratifiedKFold as much as possible; groups are never split across folds, as in GroupKFold.
Issues with current implementation:
sklearn/model_selection/tests/test_split.py:559 (test_startified_group_kfold_approximate)

Any other comments?
Given the outlined restrictions, I am hesitant about whether this should be included in scikit-learn; its usefulness seems limited. However, there seems to be constant interest in this feature, and I have failed to design a better solution that is not brute-force. I would like to hear any thoughts/ideas on which algorithm might produce better results.
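For context, the non-brute-force idea behind this family of implementations (including the Kaggle kernel) is a greedy assignment: process groups from most to least label-skewed, and drop each one into the fold where it least disturbs the per-class balance. The sketch below is an illustrative reimplementation under those assumptions, not the PR's code; it omits the count normalization, shuffling, and tie-breaking details of the real implementation:

```python
import numpy as np

def greedy_stratified_group_folds(y, groups, n_splits=5):
    """Greedy assignment of whole groups to folds, balancing class counts."""
    classes, y_idx = np.unique(y, return_inverse=True)
    uniq_groups, g_idx = np.unique(groups, return_inverse=True)

    # per-group class counts: g_counts[g, c] = samples of class c in group g
    g_counts = np.zeros((len(uniq_groups), len(classes)))
    np.add.at(g_counts, (g_idx, y_idx), 1)

    fold_counts = np.zeros((n_splits, len(classes)))
    fold_of_group = np.empty(len(uniq_groups), dtype=int)
    # most label-skewed groups first
    for g in np.argsort(-g_counts.std(axis=1)):
        best_fold, best_score = 0, np.inf
        for f in range(n_splits):
            fold_counts[f] += g_counts[g]
            # spread of each class across folds, averaged over classes
            score = fold_counts.std(axis=0).mean()
            fold_counts[f] -= g_counts[g]
            if score < best_score:
                best_fold, best_score = f, score
        fold_counts[best_fold] += g_counts[g]
        fold_of_group[g] = best_fold
    return fold_of_group[g_idx]  # fold index for every sample

y = np.array([0, 0, 1, 1, 0, 1, 0, 1])
groups = np.array([0, 0, 1, 1, 2, 2, 3, 3])
print(greedy_stratified_group_folds(y, groups, n_splits=2))
```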