diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index e1bb885558..647d174f44 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -2359,6 +2359,9 @@ def where(self, cond, other=None): result.columns.names = self.columns.names return result + def mask(self, cond, other=None): + return self.where(~cond, other=other) + def dropna( self, *, diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index 5b94df2446..a097f8c64d 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -331,6 +331,17 @@ def test_where_series_cond(scalars_df_index, scalars_pandas_df_index): pandas.testing.assert_frame_equal(bf_result, pd_result) +def test_mask_series_cond(scalars_df_index, scalars_pandas_df_index): + cond_bf = scalars_df_index["int64_col"] > 0 + cond_pd = scalars_pandas_df_index["int64_col"] > 0 + + bf_df = scalars_df_index[["int64_too", "int64_col", "float64_col"]] + pd_df = scalars_pandas_df_index[["int64_too", "int64_col", "float64_col"]] + bf_result = bf_df.mask(cond_bf, bf_df + 1).to_pandas() + pd_result = pd_df.mask(cond_pd, pd_df + 1) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + def test_where_series_multi_index(scalars_df_index, scalars_pandas_df_index): # Test when a dataframe has multi-index or multi-columns. columns = ["int64_col", "float64_col"] diff --git a/third_party/bigframes_vendored/pandas/core/frame.py b/third_party/bigframes_vendored/pandas/core/frame.py index a44d6b629f..66dbff1b6a 100644 --- a/third_party/bigframes_vendored/pandas/core/frame.py +++ b/third_party/bigframes_vendored/pandas/core/frame.py @@ -2048,6 +2048,98 @@ def where(self, cond, other): """ raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE) + def mask(self, cond, other): + """Replace values where the condition is True. 
+ + **Examples:** + + >>> import bigframes.pandas as bpd + >>> bpd.options.display.progress_bar = None + + >>> df = bpd.DataFrame({'a': [20, 10, 0], 'b': [0, 10, 20]}) + >>> df + a b + 0 20 0 + 1 10 10 + 2 0 20 + + [3 rows x 2 columns] + + You can mask the values in the dataframe based on a condition. The + values matching the condition would be replaced, and those not matching + would be kept. The default replacement value is ``NA``. For example, when + the condition is a dataframe: + + >>> df.mask(df > 0) + a b + 0 <NA> 0 + 1 <NA> <NA> + 2 0 <NA> + + [3 rows x 2 columns] + + You can specify a custom replacement value for matching values. + + >>> df.mask(df > 0, -1) + a b + 0 -1 0 + 1 -1 -1 + 2 0 -1 + + [3 rows x 2 columns] + + Besides dataframe, the condition can be a series too. For example: + + >>> df.mask(df['a'] > 10, -1) + a b + 0 -1 -1 + 1 10 10 + 2 0 20 + + [3 rows x 2 columns] + + As for the replacement, it can be a dataframe too. For example: + + >>> df.mask(df > 10, -df) + a b + 0 -20 0 + 1 10 10 + 2 0 -20 + + [3 rows x 2 columns] + + >>> df.mask(df['a'] > 10, -df) + a b + 0 -20 0 + 1 10 10 + 2 0 20 + + [3 rows x 2 columns] + + Please note, replacement doesn't support Series for now. In pandas, when + specifying a Series as replacement, the axis value should be specified + at the same time, which is not supported in bigframes DataFrame. + + Args: + cond (bool Series/DataFrame, array-like, or callable): + Where cond is False, keep the original value. Where True, replace + with corresponding value from other. If cond is callable, it is + computed on the Series/DataFrame and returns boolean + Series/DataFrame or array. The callable must not change input + Series/DataFrame (though pandas doesn’t check it). + other (scalar, DataFrame, or callable): + Entries where cond is True are replaced with corresponding value + from other. If other is callable, it is computed on the + DataFrame and returns scalar or DataFrame. 
The callable must not + change input DataFrame (though pandas doesn’t check it). If not + specified, entries will be filled with the corresponding NULL + value (np.nan for numpy dtypes, pd.NA for extension dtypes). + + Returns: + DataFrame: DataFrame after the replacement. + """ + raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE) + # ---------------------------------------------------------------------- # Sorting