diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 614a2fb919..efbe56abf7 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -2241,6 +2241,63 @@ def itertuples( for item in df.itertuples(index=index, name=name): yield item + def where(self, cond, other=None): + if isinstance(other, bigframes.series.Series): + raise ValueError("Seires is not a supported replacement type!") + + if self.columns.nlevels > 1 or self.index.nlevels > 1: + raise NotImplementedError( + "The dataframe.where() method does not support multi-index and/or multi-column." + ) + + aligned_block, (_, _) = self._block.join(cond._block, how="left") + # No left join is needed when 'other' is None or constant. + if isinstance(other, bigframes.dataframe.DataFrame): + aligned_block, (_, _) = aligned_block.join(other._block, how="left") + self_len = len(self._block.value_columns) + cond_len = len(cond._block.value_columns) + + ids = aligned_block.value_columns[:self_len] + labels = aligned_block.column_labels[:self_len] + self_col = {x: ex.deref(y) for x, y in zip(labels, ids)} + + if isinstance(cond, bigframes.series.Series) and cond.name in self_col: + # This is when 'cond' is a valid series. + y = aligned_block.value_columns[self_len] + cond_col = {x: ex.deref(y) for x in self_col.keys()} + else: + # This is when 'cond' is a dataframe. + ids = aligned_block.value_columns[self_len : self_len + cond_len] + labels = aligned_block.column_labels[self_len : self_len + cond_len] + cond_col = {x: ex.deref(y) for x, y in zip(labels, ids)} + + if isinstance(other, DataFrame): + other_len = len(self._block.value_columns) + ids = aligned_block.value_columns[-other_len:] + labels = aligned_block.column_labels[-other_len:] + other_col = {x: ex.deref(y) for x, y in zip(labels, ids)} + else: + # This is when 'other' is None or constant. + labels = aligned_block.column_labels[:self_len] + other_col = {x: ex.const(other) for x in labels} # type: ignore + + result_series = {} + for x, self_id in self_col.items(): + cond_id = cond_col[x] if x in cond_col else ex.const(False) + other_id = other_col[x] if x in other_col else ex.const(None) + result_block, result_id = aligned_block.project_expr( + ops.where_op.as_expr(self_id, cond_id, other_id) + ) + series = bigframes.series.Series( + result_block.select_column(result_id).with_column_labels([x]) + ) + result_series[x] = series + + result = DataFrame(result_series) + result.columns.name = self.columns.name + result.columns.names = self.columns.names + return result + def dropna( self, *, diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index bbcec90ea8..bae71b33be 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -322,6 +322,114 @@ def test_insert(scalars_dfs, loc, column, value, allow_duplicates): pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df, check_dtype=False) +def test_where_series_cond(scalars_df_index, scalars_pandas_df_index): + # Condition is dataframe, other is None (as default). + cond_bf = scalars_df_index["int64_col"] > 0 + cond_pd = scalars_pandas_df_index["int64_col"] > 0 + bf_result = scalars_df_index.where(cond_bf).to_pandas() + pd_result = scalars_pandas_df_index.where(cond_pd) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_where_series_multi_index(scalars_df_index, scalars_pandas_df_index): + # Test when a dataframe has multi-index or multi-columns. + columns = ["int64_col", "float64_col"] + dataframe_bf = scalars_df_index[columns] + + dataframe_bf.columns = pd.MultiIndex.from_tuples( + [("str1", 1), ("str2", 2)], names=["STR", "INT"] + ) + cond_bf = dataframe_bf["str1"] > 0 + + with pytest.raises(NotImplementedError) as context: + dataframe_bf.where(cond_bf).to_pandas() + assert ( + str(context.value) + == "The dataframe.where() method does not support multi-index and/or multi-column." + ) + + +def test_where_series_cond_const_other(scalars_df_index, scalars_pandas_df_index): + # Condition is a series, other is a constant. + columns = ["int64_col", "float64_col"] + dataframe_bf = scalars_df_index[columns] + dataframe_pd = scalars_pandas_df_index[columns] + dataframe_bf.columns.name = "test_name" + dataframe_pd.columns.name = "test_name" + + cond_bf = dataframe_bf["int64_col"] > 0 + cond_pd = dataframe_pd["int64_col"] > 0 + other = 0 + + bf_result = dataframe_bf.where(cond_bf, other).to_pandas() + pd_result = dataframe_pd.where(cond_pd, other) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_where_series_cond_dataframe_other(scalars_df_index, scalars_pandas_df_index): + # Condition is a series, other is a dataframe. + columns = ["int64_col", "float64_col"] + dataframe_bf = scalars_df_index[columns] + dataframe_pd = scalars_pandas_df_index[columns] + + cond_bf = dataframe_bf["int64_col"] > 0 + cond_pd = dataframe_pd["int64_col"] > 0 + other_bf = -dataframe_bf + other_pd = -dataframe_pd + + bf_result = dataframe_bf.where(cond_bf, other_bf).to_pandas() + pd_result = dataframe_pd.where(cond_pd, other_pd) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_where_dataframe_cond(scalars_df_index, scalars_pandas_df_index): + # Condition is a dataframe, other is None. + columns = ["int64_col", "float64_col"] + dataframe_bf = scalars_df_index[columns] + dataframe_pd = scalars_pandas_df_index[columns] + + cond_bf = dataframe_bf > 0 + cond_pd = dataframe_pd > 0 + + bf_result = dataframe_bf.where(cond_bf, None).to_pandas() + pd_result = dataframe_pd.where(cond_pd, None) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_where_dataframe_cond_const_other(scalars_df_index, scalars_pandas_df_index): + # Condition is a dataframe, other is a constant. + columns = ["int64_col", "float64_col"] + dataframe_bf = scalars_df_index[columns] + dataframe_pd = scalars_pandas_df_index[columns] + + cond_bf = dataframe_bf > 0 + cond_pd = dataframe_pd > 0 + other_bf = 10 + other_pd = 10 + + bf_result = dataframe_bf.where(cond_bf, other_bf).to_pandas() + pd_result = dataframe_pd.where(cond_pd, other_pd) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_where_dataframe_cond_dataframe_other( + scalars_df_index, scalars_pandas_df_index +): + # Condition is a dataframe, other is a dataframe. + columns = ["int64_col", "float64_col"] + dataframe_bf = scalars_df_index[columns] + dataframe_pd = scalars_pandas_df_index[columns] + + cond_bf = dataframe_bf > 0 + cond_pd = dataframe_pd > 0 + other_bf = dataframe_bf * 2 + other_pd = dataframe_pd * 2 + + bf_result = dataframe_bf.where(cond_bf, other_bf).to_pandas() + pd_result = dataframe_pd.where(cond_pd, other_pd) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + def test_drop_column(scalars_dfs): scalars_df, scalars_pandas_df = scalars_dfs col_name = "int64_col" diff --git a/third_party/bigframes_vendored/pandas/core/frame.py b/third_party/bigframes_vendored/pandas/core/frame.py index 70da1a5c4c..053ed7b94c 100644 --- a/third_party/bigframes_vendored/pandas/core/frame.py +++ b/third_party/bigframes_vendored/pandas/core/frame.py @@ -1956,6 +1956,98 @@ def items(self): """ raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE) + def where(self, cond, other): + """Replace values where the condition is False. + + **Examples:** + + >>> import bigframes.pandas as bpd + >>> bpd.options.display.progress_bar = None + + >>> df = bpd.DataFrame({'a': [20, 10, 0], 'b': [0, 10, 20]}) + >>> df + a b + 0 20 0 + 1 10 10 + 2 0 20 + + [3 rows x 2 columns] + + You can filter the values in the dataframe based on a condition. The + values matching the condition would be kept, and not matching would be + replaced. The default replacement value is ``NA``. For example, when the + condition is a dataframe: + + >>> df.where(df > 0) + a b + 0 20 + 1 10 10 + 2 20 + + [3 rows x 2 columns] + + You can specify a custom replacement value for non-matching values. + + >>> df.where(df > 0, -1) + a b + 0 20 -1 + 1 10 10 + 2 -1 20 + + [3 rows x 2 columns] + + Besides dataframe, the condition can be a series too. For example: + + >>> df.where(df['a'] > 10, -1) + a b + 0 20 0 + 1 -1 -1 + 2 -1 -1 + + [3 rows x 2 columns] + + As for the replacement, it can be a dataframe too. For example: + + >>> df.where(df > 10, -df) + a b + 0 20 0 + 1 -10 -10 + 2 0 20 + + [3 rows x 2 columns] + + >>> df.where(df['a'] > 10, -df) + a b + 0 20 0 + 1 -10 -10 + 2 0 -20 + + [3 rows x 2 columns] + + Please note, replacement doesn't support Series for now. In pandas, when + specifying a Series as replacement, the axis value should be specified + at the same time, which is not supported in bigframes DataFrame. + + Args: + cond (bool Series/DataFrame, array-like, or callable): + Where cond is True, keep the original value. Where False, replace + with corresponding value from other. If cond is callable, it is + computed on the Series/DataFrame and returns boolean + Series/DataFrame or array. The callable must not change input + Series/DataFrame (though pandas doesn’t check it). + other (scalar, DataFrame, or callable): + Entries where cond is False are replaced with corresponding value + from other. If other is callable, it is computed on the + DataFrame and returns scalar or DataFrame. The callable must not + change input DataFrame (though pandas doesn’t check it). If not + specified, entries will be filled with the corresponding NULL + value (np.nan for numpy dtypes, pd.NA for extension dtypes). + + Returns: + DataFrame: DataFrame after the replacement. + """ + raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE) + # ---------------------------------------------------------------------- # Sorting