Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

feat: add stratify param support to ml.model_selection.train_test_split method #815

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions 40 bigframes/ml/model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


import typing
from typing import List, Union
from typing import cast, List, Union

from bigframes.ml import utils
import bigframes.pandas as bpd
Expand All @@ -29,6 +29,7 @@ def train_test_split(
test_size: Union[float, None] = None,
train_size: Union[float, None] = None,
random_state: Union[int, None] = None,
stratify: Union[bpd.Series, None] = None,
) -> List[Union[bpd.DataFrame, bpd.Series]]:
"""Splits dataframes or series into random train and test subsets.

Expand All @@ -46,6 +47,10 @@ def train_test_split(
random_state (default None):
A seed to use for randomly choosing the rows of the split. If not
set, a random split will be generated each time.
stratify: (bigframes.series.Series or None, default None):
If not None, data is split in a stratified fashion, using this as the class labels. Each split has the same distribution of the class labels as the original dataset.
Defaults to None.
Note: setting the stratify parameter makes memory consumption and the generated SQL grow linearly with the number of unique values in the Series. It may return errors if the number of unique values is too large.

Returns:
List[Union[bigframes.dataframe.DataFrame, bigframes.series.Series]]: A list of BigQuery DataFrames or Series.
Expand Down Expand Up @@ -76,7 +81,38 @@ def train_test_split(

dfs = list(utils.convert_to_dataframe(*arrays))

split_dfs = dfs[0]._split(fracs=(train_size, test_size), random_state=random_state)
def _stratify_split(df: bpd.DataFrame, stratify: bpd.Series) -> List[bpd.DataFrame]:
    """Split a single DataFrame according to the stratify Series.

    Rows are grouped by each unique label in ``stratify``; every group is
    split with the requested train/test fractions, and the per-label
    pieces are concatenated so both outputs preserve the label
    distribution of the input.

    Args:
        df: the frame to split; aligned with ``stratify`` by index.
        stratify: class labels used for stratification.

    Returns:
        ``[train_df, test_df]`` as BigQuery DataFrames.
    """
    # Rename to a reserved column name to avoid clashes with df's columns.
    stratify = stratify.rename("bigframes_stratify_col")
    merged_df = df.join(stratify.to_frame(), how="outer")

    train_dfs, test_dfs = [], []
    for value in stratify.unique():
        # Recurse without stratify: a plain random split within one label.
        cur = merged_df[merged_df["bigframes_stratify_col"] == value]
        train, test = train_test_split(
            cur,
            test_size=test_size,
            train_size=train_size,
            random_state=random_state,
        )
        train_dfs.append(train)
        test_dfs.append(test)

    # NOTE: concat grows the generated SQL (and memory use) linearly with
    # the number of unique labels; very high-cardinality stratify columns
    # may hit SQL-size or OOM limits.
    train_df = cast(
        bpd.DataFrame, bpd.concat(train_dfs).drop(columns="bigframes_stratify_col")
    )
    test_df = cast(
        bpd.DataFrame, bpd.concat(test_dfs).drop(columns="bigframes_stratify_col")
    )
    return [train_df, test_df]

# Choose the split strategy: plain random split vs. stratified split.
if stratify is None:
# Random split: DataFrame._split samples both fractions in one shot.
split_dfs = dfs[0]._split(
fracs=(train_size, test_size), random_state=random_state
)
else:
# Stratified split: per-label random splits, concatenated back together.
split_dfs = _stratify_split(dfs[0], stratify)
# The two result indexes are reused to slice every input array identically,
# keeping all returned frames/series row-aligned across arrays.
train_index = split_dfs[0].index
test_index = split_dfs[1].index

Expand Down
62 changes: 62 additions & 0 deletions 62 tests/system/small/ml/test_model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,65 @@ def test_train_test_split_value_error(penguins_df_default_index, train_size, tes
model_selection.train_test_split(
X, y, train_size=train_size, test_size=test_size
)


def test_train_test_split_stratify(penguins_df_default_index):
    """Stratified split keeps each species' share in both halves."""
    df = penguins_df_default_index
    X = df[["species", "island", "culmen_length_mm"]]
    y = df[["species"]]
    X_train, X_test, y_train, y_test = model_selection.train_test_split(
        X, y, stratify=df["species"]
    )

    # Original distribution is [152, 124, 68]. All the categories follow
    # the default 75/25 split.
    species_index = pd.Index(
        [
            "Adelie Penguin (Pygoscelis adeliae)",
            "Gentoo penguin (Pygoscelis papua)",
            "Chinstrap penguin (Pygoscelis antarctica)",
        ],
        name="species",
    )

    def expected_counts(values):
        # Per-species row counts as BigQuery DataFrames' value_counts emits them.
        return pd.Series(values, index=species_index, dtype="Int64", name="count")

    train_counts = expected_counts([114, 93, 51])
    test_counts = expected_counts([38, 31, 17])

    # Every returned split must match the expected per-class distribution.
    for split, expected in (
        (X_train, train_counts),
        (X_test, test_counts),
        (y_train, train_counts),
        (y_test, test_counts),
    ):
        pd.testing.assert_series_equal(
            split["species"].value_counts().to_pandas(),
            expected,
            check_index_type=False,
        )
Morty Proxy This is a proxified and sanitized view of the page, visit original site.