From 8615e373a94327901bcf1ddf1300b7955aa16b6d Mon Sep 17 00:00:00 2001 From: Garrett Wu Date: Tue, 2 Jul 2024 20:08:28 +0000 Subject: [PATCH 1/3] feat: add stratify param to ml.model_selection.train_test_split --- bigframes/ml/model_selection.py | 33 +++++++++- tests/system/small/ml/test_model_selection.py | 62 +++++++++++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) diff --git a/bigframes/ml/model_selection.py b/bigframes/ml/model_selection.py index 48eb5a93a7..9ab4290cf0 100644 --- a/bigframes/ml/model_selection.py +++ b/bigframes/ml/model_selection.py @@ -29,6 +29,7 @@ def train_test_split( test_size: Union[float, None] = None, train_size: Union[float, None] = None, random_state: Union[int, None] = None, + stratify: Union[bpd.Series, None] = None, ) -> List[Union[bpd.DataFrame, bpd.Series]]: """Splits dataframes or series into random train and test subsets. @@ -46,6 +47,9 @@ def train_test_split( random_state (default None): A seed to use for randomly choosing the rows of the split. If not set, a random split will be generated each time. + stratify: (bigframes.series.Series or None, default None): + If not None, data is split in a stratified fashion, using this as the class labels. Each split has the same distribution of the class labels with the original dataset. + Default to None. Returns: List[Union[bigframes.dataframe.DataFrame, bigframes.series.Series]]: A list of BigQuery DataFrames or Series. @@ -76,7 +80,34 @@ def train_test_split( dfs = list(utils.convert_to_dataframe(*arrays)) - split_dfs = dfs[0]._split(fracs=(train_size, test_size), random_state=random_state) + def _stratify_split(df: bpd.DataFrame, stratify: bpd.Series) -> List[bpd.DataFrame]: + """Split a single DF accoding to the stratify Series.""" + stratify = stratify.rename("bigframes_stratify_col") # avoid name conflicts + merged_df = df.join(stratify.to_frame(), how="outer") + + train_dfs, test_dfs = [], [] + uniq = stratify.unique() + for value in uniq: + cur = merged_df[merged_df["bigframes_stratify_col"] == value] + train, test = train_test_split( + cur, + test_size=test_size, + train_size=train_size, + random_state=random_state, + ) + train_dfs.append(train) + test_dfs.append(test) + + train_df = bpd.concat(train_dfs).drop(columns="bigframes_stratify_col") + test_df = bpd.concat(test_dfs).drop(columns="bigframes_stratify_col") + return [train_df, test_df] + + if stratify is None: + split_dfs = dfs[0]._split( + fracs=(train_size, test_size), random_state=random_state + ) + else: + split_dfs = _stratify_split(dfs[0], stratify) train_index = split_dfs[0].index test_index = split_dfs[1].index diff --git a/tests/system/small/ml/test_model_selection.py b/tests/system/small/ml/test_model_selection.py index 63d0840d29..ea9220feb4 100644 --- a/tests/system/small/ml/test_model_selection.py +++ b/tests/system/small/ml/test_model_selection.py @@ -234,3 +234,65 @@ def test_train_test_split_value_error(penguins_df_default_index, train_size, tes model_selection.train_test_split( X, y, train_size=train_size, test_size=test_size ) + + +def test_train_test_split_stratify(penguins_df_default_index): + X = penguins_df_default_index[ + [ + "species", + "island", + "culmen_length_mm", + ] + ] + y = penguins_df_default_index[["species"]] + X_train, X_test, y_train, y_test = model_selection.train_test_split( + X, y, stratify=penguins_df_default_index["species"] + ) + + # Original distribution is [152, 124, 68]. All the categories follow 75/25 split + train_counts = pd.Series( + [114, 93, 51], + index=pd.Index( + [ + "Adelie Penguin (Pygoscelis adeliae)", + "Gentoo penguin (Pygoscelis papua)", + "Chinstrap penguin (Pygoscelis antarctica)", + ], + name="species", + ), + dtype="Int64", + name="count", + ) + test_counts = pd.Series( + [38, 31, 17], + index=pd.Index( + [ + "Adelie Penguin (Pygoscelis adeliae)", + "Gentoo penguin (Pygoscelis papua)", + "Chinstrap penguin (Pygoscelis antarctica)", + ], + name="species", + ), + dtype="Int64", + name="count", + ) + pd.testing.assert_series_equal( + X_train["species"].value_counts().to_pandas(), + train_counts, + check_index_type=False, + ) + pd.testing.assert_series_equal( + X_test["species"].value_counts().to_pandas(), + test_counts, + check_index_type=False, + ) + pd.testing.assert_series_equal( + y_train["species"].value_counts().to_pandas(), + train_counts, + check_index_type=False, + ) + pd.testing.assert_series_equal( + y_test["species"].value_counts().to_pandas(), + test_counts, + check_index_type=False, + ) From 82d34d3d5e0e387bbf465933b24ebe28c70f91b3 Mon Sep 17 00:00:00 2001 From: Garrett Wu Date: Tue, 2 Jul 2024 20:30:50 +0000 Subject: [PATCH 2/3] fix mypy --- bigframes/ml/model_selection.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/bigframes/ml/model_selection.py b/bigframes/ml/model_selection.py index 9ab4290cf0..a308592345 100644 --- a/bigframes/ml/model_selection.py +++ b/bigframes/ml/model_selection.py @@ -18,7 +18,7 @@ import typing -from typing import List, Union +from typing import cast, List, Union from bigframes.ml import utils import bigframes.pandas as bpd @@ -98,8 +98,12 @@ def _stratify_split(df: bpd.DataFrame, stratify: bpd.Series) -> List[bpd.DataFra train_dfs.append(train) test_dfs.append(test) - train_df = bpd.concat(train_dfs).drop(columns="bigframes_stratify_col") - test_df = bpd.concat(test_dfs).drop(columns="bigframes_stratify_col") + train_df = cast( + bpd.DataFrame, bpd.concat(train_dfs).drop(columns="bigframes_stratify_col") + ) + test_df = cast( + bpd.DataFrame, bpd.concat(test_dfs).drop(columns="bigframes_stratify_col") + ) return [train_df, test_df] if stratify is None: From 2e64e0d2b3ca0ae4561145fbe30fde62233c7319 Mon Sep 17 00:00:00 2001 From: Garrett Wu Date: Wed, 10 Jul 2024 18:51:08 +0000 Subject: [PATCH 3/3] add notes for limit --- bigframes/ml/model_selection.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bigframes/ml/model_selection.py b/bigframes/ml/model_selection.py index a308592345..6220e899ae 100644 --- a/bigframes/ml/model_selection.py +++ b/bigframes/ml/model_selection.py @@ -50,6 +50,7 @@ def train_test_split( stratify: (bigframes.series.Series or None, default None): If not None, data is split in a stratified fashion, using this as the class labels. Each split has the same distribution of the class labels with the original dataset. Default to None. + Note: By setting the stratify parameter, the memory consumption and generated SQL will be linear to the unique values in the Series. May return errors if the unique values size is too large. Returns: List[Union[bigframes.dataframe.DataFrame, bigframes.series.Series]]: A list of BigQuery DataFrames or Series.