diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py
index 7b282783bd..32f5a36f79 100644
--- a/bigframes/dataframe.py
+++ b/bigframes/dataframe.py
@@ -2132,6 +2132,77 @@ def pivot(
     ) -> DataFrame:
         return self._pivot(columns=columns, index=index, values=values)
 
+    def pivot_table(
+        self,
+        values: typing.Optional[
+            typing.Union[blocks.Label, Sequence[blocks.Label]]
+        ] = None,
+        index: typing.Optional[
+            typing.Union[blocks.Label, Sequence[blocks.Label]]
+        ] = None,
+        columns: typing.Optional[
+            typing.Union[blocks.Label, Sequence[blocks.Label]]
+        ] = None,
+        aggfunc: str = "mean",
+    ) -> DataFrame:
+        def to_label_list(arg) -> list:
+            # Wrap a single column label in a list; copy a sequence of labels.
+            # A str is iterable but always names exactly one column, so it must
+            # never be split into characters. Any other hashable iterable
+            # (e.g. a tuple) is one label only if it is an existing column of
+            # this frame; otherwise it is a sequence of labels.
+            if isinstance(arg, str):
+                return [arg]
+            if isinstance(arg, Iterable) and not (
+                isinstance(arg, blocks.Label) and arg in self.columns
+            ):
+                return list(arg)
+            return [arg]
+
+        # Both key sets are needed for the groupby below. Fail up front with a
+        # clear message instead of handing ``[None]`` to groupby.
+        if index is None or columns is None:
+            raise NotImplementedError(
+                "pivot_table requires both 'index' and 'columns' to be provided."
+            )
+
+        index = to_label_list(index)
+        columns = to_label_list(columns)
+        keys = index + columns
+
+        if values is None:
+            # Default: aggregate every column that is not a grouping key.
+            values = [label for label in self.columns if label not in keys]
+        else:
+            values = to_label_list(values)
+
+        # Unlike pivot, pivot_table has values always ordered.
+        values.sort()
+
+        agged = self.groupby(keys, dropna=True)[values].agg(aggfunc)
+
+        if isinstance(agged, bigframes.series.Series):
+            agged = agged.to_frame()
+
+        agged = agged.dropna(how="all")
+
+        if len(values) == 1:
+            agged = agged.rename(columns={agged.columns[0]: values[0]})
+
+        agged = agged.reset_index()
+
+        pivoted = agged.pivot(
+            columns=columns,
+            index=index,
+            values=values if len(values) > 1 else None,
+        ).sort_index()
+
+        # TODO: Remove the reordering step once the issue is resolved.
+        # The pivot_table method results in multi-index columns that are always ordered.
+        # However, the order of the pivoted result columns is not guaranteed to be sorted.
+        # Sort and reorder.
+        return pivoted[pivoted.columns.sort_values()]
+
     def stack(self, level: LevelsType = -1):
         if not isinstance(self.columns, pandas.MultiIndex):
             if level not in [0, -1, self.columns.name]:
diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py
index 0811defbc1..ba205078ed 100644
--- a/tests/system/small/test_dataframe.py
+++ b/tests/system/small/test_dataframe.py
@@ -2606,6 +2606,39 @@ def test_df_pivot_hockey(hockey_df, hockey_pandas_df, values, index, columns):
     pd.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
 
 
+@pytest.mark.parametrize(
+    ("values", "index", "columns", "aggfunc"),
+    [
+        (("culmen_length_mm", "body_mass_g"), "species", "sex", "std"),
+        (["body_mass_g", "culmen_length_mm"], ("species", "island"), "sex", "sum"),
+        ("body_mass_g", "sex", ["island", "species"], "mean"),
+        ("culmen_depth_mm", "island", "species", "max"),
+    ],
+)
+def test_df_pivot_table(
+    penguins_df_default_index,
+    penguins_pandas_df_default_index,
+    values,
+    index,
+    columns,
+    aggfunc,
+):
+    bf_result = penguins_df_default_index.pivot_table(
+        values=values, index=index, columns=columns, aggfunc=aggfunc
+    ).to_pandas()
+    pd_result = penguins_pandas_df_default_index.pivot_table(
+        values=values, index=index, columns=columns, aggfunc=aggfunc
+    )
+    pd.testing.assert_frame_equal(
+        bf_result, pd_result, check_dtype=False, check_column_type=False
+    )
+
+
+def test_df_pivot_table_requires_index_and_columns(penguins_df_default_index):
+    with pytest.raises(NotImplementedError):
+        penguins_df_default_index.pivot_table(values="body_mass_g", columns="sex")
+
+
 def test_ipython_key_completions_with_drop(scalars_dfs):
     scalars_df, scalars_pandas_df = scalars_dfs
     col_names = "string_col"
diff --git a/third_party/bigframes_vendored/pandas/core/frame.py b/third_party/bigframes_vendored/pandas/core/frame.py
index ed615000c1..1fc80449d1 100644
--- a/third_party/bigframes_vendored/pandas/core/frame.py
+++ b/third_party/bigframes_vendored/pandas/core/frame.py
@@ -4711,6 +4711,88 @@ def pivot(self, *, columns, index=None, values=None):
         """
         raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
 
+    def pivot_table(self, values=None, index=None, columns=None, aggfunc="mean"):
+        """
+        Create a spreadsheet-style pivot table as a DataFrame.
+
+        The levels in the pivot table will be stored in MultiIndex objects (hierarchical indexes)
+        on the index and columns of the result DataFrame.
+
+        **Examples:**
+
+            >>> import bigframes.pandas as bpd
+            >>> bpd.options.display.progress_bar = None
+
+            >>> df = bpd.DataFrame({
+            ...     'Product': ['Product A', 'Product B', 'Product A', 'Product B', 'Product A', 'Product B'],
+            ...     'Region': ['East', 'West', 'East', 'West', 'West', 'East'],
+            ...     'Sales': [100, 200, 150, 100, 200, 150],
+            ...     'Rating': [3, 5, 4, 3, 3, 5]
+            ... })
+            >>> df
+                 Product Region  Sales  Rating
+            0  Product A   East    100       3
+            1  Product B   West    200       5
+            2  Product A   East    150       4
+            3  Product B   West    100       3
+            4  Product A   West    200       3
+            5  Product B   East    150       5
+
+            [6 rows x 4 columns]
+
+        Using `pivot_table` with default aggfunc "mean":
+
+            >>> pivot_table = df.pivot_table(
+            ...     values=['Sales', 'Rating'],
+            ...     index='Product',
+            ...     columns='Region'
+            ... )
+            >>> pivot_table
+                       Rating        Sales
+            Region       East  West   East   West
+            Product
+            Product A     3.5   3.0  125.0  200.0
+            Product B     5.0   4.0  150.0  150.0
+
+            [2 rows x 4 columns]
+
+        Using `pivot_table` with specified aggfunc "max":
+
+            >>> pivot_table = df.pivot_table(
+            ...     values=['Sales', 'Rating'],
+            ...     index='Product',
+            ...     columns='Region',
+            ...     aggfunc="max"
+            ... )
+            >>> pivot_table
+                       Rating        Sales
+            Region       East  West   East  West
+            Product
+            Product A       4     3    150   200
+            Product B       5     5    150   200
+
+            [2 rows x 4 columns]
+
+        Args:
+            values (str, object or a list of the previous, optional):
+                Column(s) to use for populating new frame's values. If not
+                specified, all remaining columns will be used and the result will
+                have hierarchically indexed columns.
+
+            index (str or object or a list of str):
+                Column(s) to use to make new frame's index. Must be provided.
+
+            columns (str or object or a list of str):
+                Column to use to make new frame's columns.
+
+            aggfunc (str, default "mean"):
+                Aggregation function name to compute summary statistics (e.g., 'sum', 'mean').
+
+        Returns:
+            DataFrame: An Excel style pivot table.
+        """
+        raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
+
     def stack(self, level=-1):
         """
         Stack the prescribed level(s) from columns to index.