diff --git a/bigframes/ml/model_selection.py b/bigframes/ml/model_selection.py
index 48eb5a93a7..6220e899ae 100644
--- a/bigframes/ml/model_selection.py
+++ b/bigframes/ml/model_selection.py
@@ -18,7 +18,7 @@
 
 import typing
-from typing import List, Union
+from typing import cast, List, Union
 
 from bigframes.ml import utils
 import bigframes.pandas as bpd
 
@@ -29,6 +29,7 @@ def train_test_split(
     test_size: Union[float, None] = None,
     train_size: Union[float, None] = None,
     random_state: Union[int, None] = None,
+    stratify: Union[bpd.Series, None] = None,
 ) -> List[Union[bpd.DataFrame, bpd.Series]]:
     """Splits dataframes or series into random train and test subsets.
 
@@ -46,6 +47,10 @@ def train_test_split(
         random_state (default None):
             A seed to use for randomly choosing the rows of the split. If not
             set, a random split will be generated each time.
+        stratify (bigframes.series.Series or None, default None):
+            If not None, the data is split in a stratified fashion, using this as the class labels. Each split has the same distribution of class labels as the original dataset.
+            Defaults to None.
+            Note: when stratify is set, memory consumption and the generated SQL grow linearly with the number of unique values in the Series, and the split may fail if there are too many unique values.
 
     Returns:
         List[Union[bigframes.dataframe.DataFrame, bigframes.series.Series]]: A list of BigQuery DataFrames or Series.
@@ -76,7 +81,38 @@ def train_test_split(
 
     dfs = list(utils.convert_to_dataframe(*arrays))
 
-    split_dfs = dfs[0]._split(fracs=(train_size, test_size), random_state=random_state)
+    def _stratify_split(df: bpd.DataFrame, stratify: bpd.Series) -> List[bpd.DataFrame]:
+        """Split a single DataFrame according to the stratify Series."""
+        stratify = stratify.rename("bigframes_stratify_col")  # avoid name conflicts
+        merged_df = df.join(stratify.to_frame(), how="outer")
+
+        train_dfs, test_dfs = [], []
+        uniq = stratify.unique()
+        for value in uniq:
+            cur = merged_df[merged_df["bigframes_stratify_col"] == value]
+            train, test = train_test_split(
+                cur,
+                test_size=test_size,
+                train_size=train_size,
+                random_state=random_state,
+            )
+            train_dfs.append(train)
+            test_dfs.append(test)
+
+        train_df = cast(
+            bpd.DataFrame, bpd.concat(train_dfs).drop(columns="bigframes_stratify_col")
+        )
+        test_df = cast(
+            bpd.DataFrame, bpd.concat(test_dfs).drop(columns="bigframes_stratify_col")
+        )
+        return [train_df, test_df]
+
+    if stratify is None:
+        split_dfs = dfs[0]._split(
+            fracs=(train_size, test_size), random_state=random_state
+        )
+    else:
+        split_dfs = _stratify_split(dfs[0], stratify)
 
     train_index = split_dfs[0].index
     test_index = split_dfs[1].index
diff --git a/tests/system/small/ml/test_model_selection.py b/tests/system/small/ml/test_model_selection.py
index 63d0840d29..ea9220feb4 100644
--- a/tests/system/small/ml/test_model_selection.py
+++ b/tests/system/small/ml/test_model_selection.py
@@ -234,3 +234,65 @@ def test_train_test_split_value_error(penguins_df_default_index, train_size, tes
         model_selection.train_test_split(
             X, y, train_size=train_size, test_size=test_size
         )
+
+
+def test_train_test_split_stratify(penguins_df_default_index):
+    X = penguins_df_default_index[
+        [
+            "species",
+            "island",
+            "culmen_length_mm",
+        ]
+    ]
+    y = penguins_df_default_index[["species"]]
+    X_train, X_test, y_train, y_test = model_selection.train_test_split(
+        X, y, stratify=penguins_df_default_index["species"]
+    )
+
+    # Original distribution is [152, 124, 68]. All the categories follow a 75/25 split.
+    train_counts = pd.Series(
+        [114, 93, 51],
+        index=pd.Index(
+            [
+                "Adelie Penguin (Pygoscelis adeliae)",
+                "Gentoo penguin (Pygoscelis papua)",
+                "Chinstrap penguin (Pygoscelis antarctica)",
+            ],
+            name="species",
+        ),
+        dtype="Int64",
+        name="count",
+    )
+    test_counts = pd.Series(
+        [38, 31, 17],
+        index=pd.Index(
+            [
+                "Adelie Penguin (Pygoscelis adeliae)",
+                "Gentoo penguin (Pygoscelis papua)",
+                "Chinstrap penguin (Pygoscelis antarctica)",
+            ],
+            name="species",
+        ),
+        dtype="Int64",
+        name="count",
+    )
+    pd.testing.assert_series_equal(
+        X_train["species"].value_counts().to_pandas(),
+        train_counts,
+        check_index_type=False,
+    )
+    pd.testing.assert_series_equal(
+        X_test["species"].value_counts().to_pandas(),
+        test_counts,
+        check_index_type=False,
+    )
+    pd.testing.assert_series_equal(
+        y_train["species"].value_counts().to_pandas(),
+        train_counts,
+        check_index_type=False,
+    )
+    pd.testing.assert_series_equal(
+        y_test["species"].value_counts().to_pandas(),
+        test_counts,
+        check_index_type=False,
+    )
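
Usage sketch: with the new stratify parameter, a typical call mirrors the test above; pass the label Series alongside X and y, and each class keeps the same train/test proportions (75/25 when no sizes are given). The project, table, and column names below are illustrative assumptions, not taken from this change.

import bigframes.pandas as bpd
from bigframes.ml import model_selection

# Hypothetical source table; "label" stands in for whatever class column is stratified on.
df = bpd.read_gbq("my_project.my_dataset.my_table")
X = df[["feature_a", "feature_b"]]
y = df[["label"]]

# Each class in df["label"] is split with the same train/test ratio (75/25 by default).
X_train, X_test, y_train, y_test = model_selection.train_test_split(
    X, y, stratify=df["label"]
)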