From 46367b9ce2604a762a0c5ec9322b63f61f1b9051 Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Tue, 19 Mar 2024 23:25:39 +0000 Subject: [PATCH 1/9] feat: Add pivot_table for DataFrame. --- bigframes/dataframe.py | 50 +++++++++++++++++++ tests/system/small/test_dataframe.py | 28 +++++++++++ .../bigframes_vendored/pandas/core/frame.py | 26 ++++++++++ 3 files changed, 104 insertions(+) diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index e8328b6047..4d606ad28f 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -2068,6 +2068,56 @@ def pivot( ) -> DataFrame: return self._pivot(columns=columns, index=index, values=values) + def pivot_table( + self, + values: typing.Optional[ + typing.Union[blocks.Label, Sequence[blocks.Label]] + ] = None, + index: typing.Optional[ + typing.Union[blocks.Label, Sequence[blocks.Label]] + ] = None, + columns: typing.Union[blocks.Label, Sequence[blocks.Label]] = None, + aggfunc: str = "mean", + ) -> DataFrame: + if isinstance(index, blocks.Label): + index = [index] + if isinstance(columns, blocks.Label): + columns = [columns] + if isinstance(values, blocks.Label): + values = [values] + + keys = index + columns + agged = self.groupby(keys, dropna=True)[values].agg(aggfunc) + + if isinstance(agged, bigframes.series.Series): + agged = agged.to_frame() + + agged = agged.dropna(how="all") + + if len(values) == 1: + agged = agged.rename(columns={agged.columns[0]: values[0]}) + + agged = agged.reset_index() + + # Unlike pivot, pivot_table has values always ordered. + pivoted = agged.pivot( + columns=columns, + index=index, + values=sorted(values) if len(values) > 1 else None, + ).sort_index() + + # TODO: Remove the reordering step once the issue is resolved. + # The pivot_table method results in multi-index columns that are always ordered. + # However, the order of the pivoted result columns is not guaranteed to be sorted. + # Sort the multi-index columns + sorted_columns = pivoted.columns.sort_values() + pivoted = pivoted[sorted_columns] + + # This step ensures the column information is accurately maintained after sorting. + pivoted.columns = sorted_columns + + return pivoted + def stack(self, level: LevelsType = -1): if not isinstance(self.columns, pandas.MultiIndex): if level not in [0, -1, self.columns.name]: diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index be4211a2fc..e9beaf082e 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -2459,6 +2459,34 @@ def test_df_pivot_hockey(hockey_df, hockey_pandas_df, values, index, columns): pd.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) +@pytest.mark.parametrize( + ("values", "index", "columns", "aggfunc"), + [ + (["culmen_length_mm", "body_mass_g"], "species", "sex", "std"), + (["body_mass_g", "culmen_length_mm"], ["species", "island"], "sex", "sum"), + ("body_mass_g", "sex", ["island", "species"], "mean"), + ("culmen_depth_mm", "island", "species", "max"), + ], +) +def test_df_pivot_table( + penguins_df_default_index, + penguins_pandas_df_default_index, + values, + index, + columns, + aggfunc, +): + bf_result = penguins_df_default_index.pivot_table( + values=values, index=index, columns=columns, aggfunc=aggfunc + ).to_pandas() + pd_result = penguins_pandas_df_default_index.pivot_table( + values=values, index=index, columns=columns, aggfunc=aggfunc + ) + pd.testing.assert_frame_equal( + bf_result, pd_result, check_dtype=False, check_column_type=False + ) + + def test_ipython_key_completions_with_drop(scalars_dfs): scalars_df, scalars_pandas_df = scalars_dfs col_names = "string_col" diff --git a/third_party/bigframes_vendored/pandas/core/frame.py b/third_party/bigframes_vendored/pandas/core/frame.py index 7793b31a21..d63ed7d93e 100644 --- a/third_party/bigframes_vendored/pandas/core/frame.py +++ b/third_party/bigframes_vendored/pandas/core/frame.py @@ -4632,6 +4632,32 @@ def pivot(self, *, columns, index=None, values=None): """ raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE) + def pivot_table(self, values=None, index=None, columns=None, aggfunc="mean"): + """ + Create a spreadsheet-style pivot table as a DataFrame. + + The levels in the pivot table will be stored in MultiIndex objects (hierarchical indexes) + on the index and columns of the result DataFrame. + + Args: + values (str, object or a list of the previous, optional): + Column(s) to use for populating new frame's values. If not + specified, all remaining columns will be used and the result will + have hierarchically indexed columns. + + index (str or object or a list of str, optional): + Column to use to make new frame's index. If not given, uses existing index. + + columns (str or object or a list of str): + Column to use to make new frame's columns. + aggfunc (str, default “mean”): + Aggregation function name to compute summary statistics (e.g., 'sum', 'mean'). + + Returns: + DataFrame: An Excel style pivot table. + """ + raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE) + def stack(self, level=-1): """ Stack the prescribed level(s) from columns to index. From da4155f2f63b58e9809786bd1e64689e1b84aee3 Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Tue, 19 Mar 2024 23:30:39 +0000 Subject: [PATCH 2/9] Update logic --- bigframes/dataframe.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 4d606ad28f..840ad78bd0 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -2081,10 +2081,20 @@ def pivot_table( ) -> DataFrame: if isinstance(index, blocks.Label): index = [index] + else: + index = list(index) + if isinstance(columns, blocks.Label): columns = [columns] + else: + columns = list(columns) + if isinstance(values, blocks.Label): values = [values] + else: + values = list(values) + + values.sort() keys = index + columns agged = self.groupby(keys, dropna=True)[values].agg(aggfunc) @@ -2103,7 +2113,7 @@ def pivot_table( pivoted = agged.pivot( columns=columns, index=index, - values=sorted(values) if len(values) > 1 else None, + values=values if len(values) > 1 else None, ).sort_index() # TODO: Remove the reordering step once the issue is resolved. From dabf5ec53d163c872dff923a8957fc07f1da07b5 Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Tue, 19 Mar 2024 23:32:06 +0000 Subject: [PATCH 3/9] Update comments --- bigframes/dataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 840ad78bd0..a76202993a 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -2094,6 +2094,7 @@ def pivot_table( else: values = list(values) + # Unlike pivot, pivot_table has values always ordered. values.sort() keys = index + columns @@ -2109,7 +2110,6 @@ def pivot_table( agged = agged.reset_index() - # Unlike pivot, pivot_table has values always ordered. pivoted = agged.pivot( columns=columns, index=index, From c6c1be5dc8bcddaf300b708725b0b272b06ca490 Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Tue, 19 Mar 2024 23:35:43 +0000 Subject: [PATCH 4/9] Remove code unused after merge. --- bigframes/dataframe.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 08e4be4a33..847faa5c59 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -2123,9 +2123,6 @@ def pivot_table( sorted_columns = pivoted.columns.sort_values() pivoted = pivoted[sorted_columns] - # This step ensures the column information is accurately maintained after sorting. - pivoted.columns = sorted_columns - return pivoted def stack(self, level: LevelsType = -1): From 5309e2d98dc5c428a115cc578c28d3300b691923 Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Tue, 19 Mar 2024 23:37:31 +0000 Subject: [PATCH 5/9] Code update. --- bigframes/dataframe.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 847faa5c59..b1af1d13a7 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -2119,11 +2119,8 @@ def pivot_table( # TODO: Remove the reordering step once the issue is resolved. # The pivot_table method results in multi-index columns that are always ordered. # However, the order of the pivoted result columns is not guaranteed to be sorted. - # Sort the multi-index columns - sorted_columns = pivoted.columns.sort_values() - pivoted = pivoted[sorted_columns] - - return pivoted + # Sort and reorder. + return pivoted[pivoted.columns.sort_values()] def stack(self, level: LevelsType = -1): if not isinstance(self.columns, pandas.MultiIndex): From 246a942fd86c06e2de300ecd1cc8040b1b1d3e4e Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Thu, 21 Mar 2024 00:19:48 +0000 Subject: [PATCH 6/9] Update code example. --- .../bigframes_vendored/pandas/core/frame.py | 57 ++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/third_party/bigframes_vendored/pandas/core/frame.py b/third_party/bigframes_vendored/pandas/core/frame.py index c6626d2a9d..7caa2bae41 100644 --- a/third_party/bigframes_vendored/pandas/core/frame.py +++ b/third_party/bigframes_vendored/pandas/core/frame.py @@ -4639,6 +4639,61 @@ def pivot_table(self, values=None, index=None, columns=None, aggfunc="mean"): The levels in the pivot table will be stored in MultiIndex objects (hierarchical indexes) on the index and columns of the result DataFrame. + **Examples:** + + >>> import bigframes.pandas as bpd + >>> bpd.options.display.progress_bar = None + + >>> df = bpd.DataFrame({ + ... 'Product': ['Product A', 'Product B', 'Product A', 'Product B', 'Product A', 'Product B'], + ... 'Region': ['East', 'West', 'East', 'West', 'West', 'East'], + ... 'Sales': [100, 200, 150, 100, 200, 150], + ... 'Rating': [3, 5, 4, 3, 3, 5] + ... }) + >>> df + Product Region Sales Rating + 0 Product A East 100 3 + 1 Product B West 200 5 + 2 Product A East 150 4 + 3 Product B West 100 3 + 4 Product A West 200 3 + 5 Product B East 150 5 + + [6 rows x 4 columns] + + Using `pivot_table` with default aggfunc "mean": + + >>> pivot_table = df.pivot_table( + ... values=['Sales', 'Rating'], + ... index='Product', + ... columns='Region' + ... ) + >>> pivot_table + Rating Sales + Region East West East West + Product + Product A 3.5 3.0 125.0 200.0 + Product B 5.0 4.0 150.0 150.0 + + [2 rows x 4 columns] + + Using `pivot_table` with specified aggfunc "max": + + >>> pivot_table = df.pivot_table( + ... values=['Sales', 'Rating'], + ... index='Product', + ... columns='Region', + ... aggfunc="max" + ... ) + >>> pivot_table + Rating Sales + Region East West East West + Product + Product A 4 3 150 200 + Product B 5 5 150 200 + + [2 rows x 4 columns] + Args: values (str, object or a list of the previous, optional): Column(s) to use for populating new frame's values. If not @@ -4650,7 +4705,7 @@ def pivot_table(self, values=None, index=None, columns=None, aggfunc="mean"): columns (str or object or a list of str): Column to use to make new frame's columns. - aggfunc (str, default “mean”): + aggfunc (str, default "mean"): Aggregation function name to compute summary statistics (e.g., 'sum', 'mean'). Returns: From 49c171838780085de9944ed955a383658f10a8b7 Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Tue, 9 Apr 2024 17:43:59 +0000 Subject: [PATCH 7/9] Update for Tuple type. --- bigframes/dataframe.py | 6 +++--- tests/system/small/test_dataframe.py | 4 ++-- third_party/bigframes_vendored/pandas/core/frame.py | 1 + 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 6cba027c7e..cb19a9fb90 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -2085,17 +2085,17 @@ def pivot_table( columns: typing.Union[blocks.Label, Sequence[blocks.Label]] = None, aggfunc: str = "mean", ) -> DataFrame: - if isinstance(index, blocks.Label): + if isinstance(index, blocks.Label) and index in self.columns: index = [index] else: index = list(index) - if isinstance(columns, blocks.Label): + if isinstance(columns, blocks.Label) and columns in self.columns: columns = [columns] else: columns = list(columns) - if isinstance(values, blocks.Label): + if isinstance(values, blocks.Label) and values in self.columns: values = [values] else: values = list(values) diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index 962b15f2bd..b8bddf1b09 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -2557,8 +2557,8 @@ def test_df_pivot_hockey(hockey_df, hockey_pandas_df, values, index, columns): @pytest.mark.parametrize( ("values", "index", "columns", "aggfunc"), [ - (["culmen_length_mm", "body_mass_g"], "species", "sex", "std"), - (["body_mass_g", "culmen_length_mm"], ["species", "island"], "sex", "sum"), + (("culmen_length_mm", "body_mass_g"), "species", "sex", "std"), + (["body_mass_g", "culmen_length_mm"], ("species", "island"), "sex", "sum"), ("body_mass_g", "sex", ["island", "species"], "mean"), ("culmen_depth_mm", "island", "species", "max"), ], diff --git a/third_party/bigframes_vendored/pandas/core/frame.py b/third_party/bigframes_vendored/pandas/core/frame.py index a35ed66f56..25080bd6fe 100644 --- a/third_party/bigframes_vendored/pandas/core/frame.py +++ b/third_party/bigframes_vendored/pandas/core/frame.py @@ -4732,6 +4732,7 @@ def pivot_table(self, values=None, index=None, columns=None, aggfunc="mean"): columns (str or object or a list of str): Column to use to make new frame's columns. + aggfunc (str, default "mean"): Aggregation function name to compute summary statistics (e.g., 'sum', 'mean'). From 6057c4083e8bcf5e7d218d46c3ac692dc3078a39 Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Tue, 9 Apr 2024 18:14:32 +0000 Subject: [PATCH 8/9] Update code logic --- bigframes/dataframe.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index ecf89d2ce0..6eca5243a3 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -2143,20 +2143,20 @@ def pivot_table( columns: typing.Union[blocks.Label, Sequence[blocks.Label]] = None, aggfunc: str = "mean", ) -> DataFrame: - if isinstance(index, blocks.Label) and index in self.columns: - index = [index] - else: + if isinstance(index, Iterable) and not (isinstance(index, blocks.Label) and index in self.columns): index = list(index) - - if isinstance(columns, blocks.Label) and columns in self.columns: - columns = [columns] else: - columns = list(columns) + index = [index] - if isinstance(values, blocks.Label) and values in self.columns: - values = [values] + if isinstance(columns, Iterable) and not (isinstance(columns, blocks.Label) and columns in self.columns): + columns = list(columns) else: + columns = [columns] + + if isinstance(values, Iterable) and not (isinstance(values, blocks.Label) and values in self.columns): values = list(values) + else: + values = [values] # Unlike pivot, pivot_table has values always ordered. values.sort() From 777a14bfb8b542e89d65d97a131b1db8280583bf Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Tue, 9 Apr 2024 18:14:56 +0000 Subject: [PATCH 9/9] Update format --- bigframes/dataframe.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 6eca5243a3..32f5a36f79 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -2143,17 +2143,23 @@ def pivot_table( columns: typing.Union[blocks.Label, Sequence[blocks.Label]] = None, aggfunc: str = "mean", ) -> DataFrame: - if isinstance(index, Iterable) and not (isinstance(index, blocks.Label) and index in self.columns): + if isinstance(index, Iterable) and not ( + isinstance(index, blocks.Label) and index in self.columns + ): index = list(index) else: index = [index] - if isinstance(columns, Iterable) and not (isinstance(columns, blocks.Label) and columns in self.columns): + if isinstance(columns, Iterable) and not ( + isinstance(columns, blocks.Label) and columns in self.columns + ): columns = list(columns) else: columns = [columns] - if isinstance(values, Iterable) and not (isinstance(values, blocks.Label) and values in self.columns): + if isinstance(values, Iterable) and not ( + isinstance(values, blocks.Label) and values in self.columns + ): values = list(values) else: values = [values]