Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings
This repository was archived by the owner on May 7, 2026. It is now read-only.

Commit c4cb39d

Browse files
feat: Add agg/aggregate methods to windows (#2288)
1 parent 2dcf6ae commit c4cb39d
Copy full SHA for c4cb39d

3 files changed

+248 -40 (Lines changed: 248 additions & 40 deletions)

File tree

Expand file tree / Collapse file tree
Open diff view settings
Filter options
Expand file tree / Collapse file tree
Open diff view settings
Collapse file

‎bigframes/core/window/rolling.py‎

Copy file name to clipboard | Expand all lines: bigframes/core/window/rolling.py
+130 -40 (Lines changed: 130 additions & 40 deletions)
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,24 @@
1515
from __future__ import annotations
1616

1717
import datetime
18-
import typing
18+
from typing import Literal, Mapping, Sequence, TYPE_CHECKING, Union
1919

2020
import bigframes_vendored.pandas.core.window.rolling as vendored_pandas_rolling
2121
import numpy
2222
import pandas
2323

2424
from bigframes import dtypes
25+
from bigframes.core import agg_expressions
2526
from bigframes.core import expression as ex
26-
from bigframes.core import log_adapter, ordering, window_spec
27+
from bigframes.core import log_adapter, ordering, utils, window_spec
2728
import bigframes.core.blocks as blocks
2829
from bigframes.core.window import ordering as window_ordering
2930
import bigframes.operations.aggregations as agg_ops
3031

32+
if TYPE_CHECKING:
33+
import bigframes.dataframe as df
34+
import bigframes.series as series
35+
3136

3237
@log_adapter.class_logger
3338
class Window(vendored_pandas_rolling.Window):
@@ -37,7 +42,7 @@ def __init__(
3742
self,
3843
block: blocks.Block,
3944
window_spec: window_spec.WindowSpec,
40-
value_column_ids: typing.Sequence[str],
45+
value_column_ids: Sequence[str],
4146
drop_null_groups: bool = True,
4247
is_series: bool = False,
4348
skip_agg_column_id: str | None = None,
@@ -52,55 +57,106 @@ def __init__(
5257
self._skip_agg_column_id = skip_agg_column_id
5358

5459
def count(self):
55-
return self._apply_aggregate(agg_ops.count_op)
60+
return self._apply_aggregate_op(agg_ops.count_op)
5661

5762
def sum(self):
58-
return self._apply_aggregate(agg_ops.sum_op)
63+
return self._apply_aggregate_op(agg_ops.sum_op)
5964

6065
def mean(self):
61-
return self._apply_aggregate(agg_ops.mean_op)
66+
return self._apply_aggregate_op(agg_ops.mean_op)
6267

6368
def var(self):
64-
return self._apply_aggregate(agg_ops.var_op)
69+
return self._apply_aggregate_op(agg_ops.var_op)
6570

6671
def std(self):
67-
return self._apply_aggregate(agg_ops.std_op)
72+
return self._apply_aggregate_op(agg_ops.std_op)
6873

6974
def max(self):
70-
return self._apply_aggregate(agg_ops.max_op)
75+
return self._apply_aggregate_op(agg_ops.max_op)
7176

7277
def min(self):
73-
return self._apply_aggregate(agg_ops.min_op)
78+
return self._apply_aggregate_op(agg_ops.min_op)
7479

75-
def _apply_aggregate(
76-
self,
77-
op: agg_ops.UnaryAggregateOp,
78-
):
79-
agg_block = self._aggregate_block(op)
80+
def agg(self, func) -> Union[df.DataFrame, series.Series]:
    """Apply one or more aggregations over this window.

    The shape of ``func`` selects the helper that does the work:
    dict-like specs go to ``_agg_dict``, list-like specs go to
    ``_agg_list``, and anything else is handled by ``_agg_func`` as a
    single aggregation.

    Args:
        func: A single aggregation, a list-like of aggregations, or a
            dict-like keyed by column label.

    Returns:
        The result built by the selected helper.
    """
    # Dict-likes are tested before list-likes, preserving the branch order.
    if utils.is_dict_like(func):
        return self._agg_dict(func)
    if utils.is_list_like(func):
        return self._agg_list(func)
    return self._agg_func(func)
8087

81-
if self._is_series:
82-
from bigframes.series import Series
88+
aggregate = agg
89+
90+
def _agg_func(self, func) -> Union[df.DataFrame, series.Series]:
    """Apply a single aggregation to every aggregated column.

    Args:
        func: Anything ``agg_ops.lookup_agg_func`` can resolve to an
            aggregate op.

    Returns:
        Whatever ``_apply_aggs`` builds for the aggregated columns. That
        helper can return a Series as well as a DataFrame, hence the
        ``Union`` annotation.
    """
    ids, labels = self._aggregated_columns()
    # The resolved op is the same for every column, so look it up once
    # instead of once per column.
    op = agg_ops.lookup_agg_func(func)[0]
    # ``agg`` below is the module-level expression builder; method bodies do
    # not see class-scope names, so this is not the ``agg`` method.
    aggregations = [agg(col_id, op) for col_id in ids]
    return self._apply_aggs(aggregations, labels)
94+
95+
def _agg_dict(self, func: Mapping) -> df.DataFrame:
96+
aggregations: list[agg_expressions.Aggregation] = []
97+
column_labels = []
98+
function_labels = []
8399

84-
return Series(agg_block)
100+
want_aggfunc_level = any(utils.is_list_like(aggs) for aggs in func.values())
101+
102+
for label, funcs_for_id in func.items():
103+
col_id = self._block.label_to_col_id[label][-1] # get last matching column
104+
func_list = (
105+
funcs_for_id if utils.is_list_like(funcs_for_id) else [funcs_for_id]
106+
)
107+
for f in func_list:
108+
f_op, f_label = agg_ops.lookup_agg_func(f)
109+
aggregations.append(agg(col_id, f_op))
110+
column_labels.append(label)
111+
function_labels.append(f_label)
112+
if want_aggfunc_level:
113+
result_labels: pandas.Index = utils.combine_indices(
114+
pandas.Index(column_labels),
115+
pandas.Index(function_labels),
116+
)
85117
else:
86-
from bigframes.dataframe import DataFrame
118+
result_labels = pandas.Index(column_labels)
87119

88-
# Preserve column order.
89-
column_labels = [
90-
self._block.col_id_to_label[col_id] for col_id in self._value_column_ids
91-
]
92-
return DataFrame(agg_block)._reindex_columns(column_labels)
120+
return self._apply_aggs(aggregations, result_labels)
93121

94-
def _aggregate_block(self, op: agg_ops.UnaryAggregateOp) -> blocks.Block:
95-
agg_col_ids = [
96-
col_id
97-
for col_id in self._value_column_ids
98-
if col_id != self._skip_agg_column_id
122+
def _agg_list(self, func: Sequence) -> df.DataFrame:
123+
ids, labels = self._aggregated_columns()
124+
aggregations = [
125+
agg(col_id, agg_ops.lookup_agg_func(f)[0]) for col_id in ids for f in func
99126
]
100-
block, result_ids = self._block.multi_apply_window_op(
101-
agg_col_ids,
102-
op,
103-
self._window_spec,
127+
128+
if self._is_series:
129+
# if series, no need to rebuild
130+
result_cols_idx = pandas.Index(
131+
[agg_ops.lookup_agg_func(f)[1] for f in func]
132+
)
133+
else:
134+
if self._block.column_labels.nlevels > 1:
135+
# Restructure MultiIndex for proper format: (idx1, idx2, func)
136+
# rather than ((idx1, idx2), func).
137+
column_labels = [
138+
tuple(label) + (agg_ops.lookup_agg_func(f)[1],)
139+
for label in labels.to_frame(index=False).to_numpy()
140+
for f in func
141+
]
142+
else: # Single-level index
143+
column_labels = [
144+
(label, agg_ops.lookup_agg_func(f)[1])
145+
for label in labels
146+
for f in func
147+
]
148+
result_cols_idx = pandas.MultiIndex.from_tuples(
149+
column_labels, names=[*self._block.column_labels.names, None]
150+
)
151+
return self._apply_aggs(aggregations, result_cols_idx)
152+
153+
def _apply_aggs(
154+
self, exprs: Sequence[agg_expressions.Aggregation], labels: pandas.Index
155+
):
156+
block, ids = self._block.apply_analytic(
157+
agg_exprs=exprs,
158+
window=self._window_spec,
159+
result_labels=labels,
104160
skip_null_groups=self._drop_null_groups,
105161
)
106162

@@ -115,24 +171,50 @@ def _aggregate_block(self, op: agg_ops.UnaryAggregateOp) -> blocks.Block:
115171
)
116172
block = block.set_index(col_ids=index_ids)
117173

118-
labels = [self._block.col_id_to_label[col] for col in agg_col_ids]
119174
if self._skip_agg_column_id is not None:
120-
result_ids = [self._skip_agg_column_id, *result_ids]
121-
labels.insert(0, self._block.col_id_to_label[self._skip_agg_column_id])
175+
block = block.select_columns([self._skip_agg_column_id, *ids])
176+
else:
177+
block = block.select_columns(ids).with_column_labels(labels)
178+
179+
if self._is_series and (len(block.value_columns) == 1):
180+
import bigframes.series as series
181+
182+
return series.Series(block)
183+
else:
184+
import bigframes.dataframe as df
185+
186+
return df.DataFrame(block)
187+
188+
def _apply_aggregate_op(
    self,
    op: agg_ops.UnaryAggregateOp,
):
    """Run one aggregate op over every aggregated column of the window.

    Args:
        op: The aggregate op applied to each column returned by
            ``_aggregated_columns``.

    Returns:
        The result of ``_apply_aggs`` for the per-column aggregations.
    """
    column_ids, column_labels = self._aggregated_columns()
    # One aggregation expression per column, in column order.
    per_column = [agg(column_id, op) for column_id in column_ids]
    return self._apply_aggs(per_column, column_labels)
122195

123-
return block.select_columns(result_ids).with_column_labels(labels)
196+
def _aggregated_columns(self) -> tuple[Sequence[str], pandas.Index]:
    """Return the ids and labels of the columns that get aggregated.

    Every id in ``self._value_column_ids`` is kept, in order, except the
    one equal to ``self._skip_agg_column_id``.

    Returns:
        A pair ``(ids, labels)`` where ``labels`` is a ``pandas.Index`` of
        the block's labels for ``ids``, in the same order.
    """
    skipped_id = self._skip_agg_column_id
    kept_ids = []
    for column_id in self._value_column_ids:
        if column_id == skipped_id:
            continue
        kept_ids.append(column_id)

    kept_labels = [self._block.col_id_to_label[column_id] for column_id in kept_ids]
    return kept_ids, pandas.Index(kept_labels)
124206

125207

126208
def create_range_window(
127209
block: blocks.Block,
128210
window: pandas.Timedelta | numpy.timedelta64 | datetime.timedelta | str,
129211
*,
130-
value_column_ids: typing.Sequence[str] = tuple(),
212+
value_column_ids: Sequence[str] = tuple(),
131213
min_periods: int | None,
132214
on: str | None = None,
133-
closed: typing.Literal["right", "left", "both", "neither"],
215+
closed: Literal["right", "left", "both", "neither"],
134216
is_series: bool,
135-
grouping_keys: typing.Sequence[str] = tuple(),
217+
grouping_keys: Sequence[str] = tuple(),
136218
drop_null_groups: bool = True,
137219
) -> Window:
138220

@@ -184,3 +266,11 @@ def create_range_window(
184266
skip_agg_column_id=None if on is None else rolling_key_col_id,
185267
drop_null_groups=drop_null_groups,
186268
)
269+
270+
271+
def agg(input: str, op: agg_ops.AggregateOp) -> agg_expressions.Aggregation:
    """Build the aggregation expression that applies ``op`` to a column.

    Args:
        input: Id of the column to aggregate. Not used for nullary ops.
            (The name shadows the ``input`` builtin; it is kept so existing
            keyword callers keep working.)
        op: The aggregate op to wrap.

    Returns:
        A ``UnaryAggregation`` over ``ex.deref(input)`` for unary ops, or a
        ``NullaryAggregation`` for nullary ops.

    Raises:
        TypeError: If ``op`` is neither a unary nor a nullary aggregate op.
    """
    if isinstance(op, agg_ops.UnaryAggregateOp):
        return agg_expressions.UnaryAggregation(op, ex.deref(input))
    if isinstance(op, agg_ops.NullaryAggregateOp):
        return agg_expressions.NullaryAggregation(op)
    # Raise explicitly instead of asserting: asserts are stripped under
    # ``python -O``, which would let an unsupported op be wrapped silently as
    # a nullary aggregation.
    raise TypeError(f"Unsupported aggregate op for window aggregation: {op!r}")
Collapse file

‎tests/system/small/test_window.py‎

Copy file name to clipboard | Expand all lines: tests/system/small/test_window.py
+69 (Lines changed: 69 additions & 0 deletions)
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,75 @@ def test_dataframe_window_agg_ops(scalars_dfs, windowing, agg_op):
228228
pd.testing.assert_frame_equal(pd_result, bf_result, check_dtype=False)
229229

230230

231+
@pytest.mark.parametrize(
    "windowing",
    [
        pytest.param(lambda x: x.expanding(), id="expanding"),
        pytest.param(lambda x: x.rolling(3, min_periods=3), id="rolling"),
        pytest.param(
            lambda x: x.groupby(level=0).rolling(3, min_periods=3), id="rollinggroupby"
        ),
        pytest.param(
            lambda x: x.groupby("int64_too").expanding(min_periods=2),
            id="expandinggroupby",
        ),
    ],
)
@pytest.mark.parametrize(
    "func",
    [
        pytest.param("sum", id="sum_by_name"),
        pytest.param(np.sum, id="sum_by_np"),
        pytest.param([np.sum, np.mean], id="list_of_funcs"),
        pytest.param(
            {"int64_col": np.sum, "float64_col": "mean"}, id="dict_of_single_funcs"
        ),
        pytest.param(
            {"int64_col": np.sum, "float64_col": ["mean", np.max]},
            id="dict_of_lists_and_single_funcs",
        ),
    ],
)
def test_dataframe_window_agg_func(scalars_dfs, windowing, func):
    """Window ``agg`` on a DataFrame matches pandas for each ``func`` form.

    Covers a function name, a NumPy callable, a list of callables, and dicts
    mapping column labels to one function or a list of functions, across
    plain and grouped rolling/expanding windows.
    """
    bf_df, pd_df = scalars_dfs
    target_columns = ["int64_too", "float64_col", "bool_col", "int64_col"]
    index_column = "bool_col"
    # Apply the same column selection and index to both frames so the
    # results are directly comparable.
    bf_df = bf_df[target_columns].set_index(index_column)
    pd_df = pd_df[target_columns].set_index(index_column)

    bf_result = windowing(bf_df).agg(func).to_pandas()

    pd_result = windowing(pd_df).agg(func)

    pd.testing.assert_frame_equal(pd_result, bf_result, check_dtype=False)
272+
273+
274+
def test_series_window_agg_single_func(scalars_dfs):
    """Expanding ``agg("sum")`` on a Series matches the pandas result."""
    bigframes_frame, pandas_frame = scalars_dfs
    index_name = "bool_col"

    bigframes_values = bigframes_frame.set_index(index_name).int64_too
    pandas_values = pandas_frame.set_index(index_name).int64_too

    actual = bigframes_values.expanding().agg("sum").to_pandas()
    expected = pandas_values.expanding().agg("sum")

    # A single function keeps the result a Series, so compare as Series.
    pd.testing.assert_series_equal(expected, actual, check_dtype=False)
285+
286+
287+
def test_series_window_agg_multi_func(scalars_dfs):
    """Expanding ``agg`` with a list of functions on a Series matches pandas."""
    bigframes_frame, pandas_frame = scalars_dfs
    index_name = "bool_col"

    bigframes_values = bigframes_frame.set_index(index_name).int64_too
    pandas_values = pandas_frame.set_index(index_name).int64_too

    actual = bigframes_values.expanding().agg(["sum", np.mean]).to_pandas()
    expected = pandas_values.expanding().agg(["sum", np.mean])

    # Several functions produce one column per function, so compare as frames.
    pd.testing.assert_frame_equal(expected, actual, check_dtype=False)
298+
299+
231300
@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
232301
@pytest.mark.parametrize(
233302
"window", # skipped numpy timedelta because Pandas does not support it.
Collapse file

‎third_party/bigframes_vendored/pandas/core/window/rolling.py‎

Copy file name to clipboard | Expand all lines: third_party/bigframes_vendored/pandas/core/window/rolling.py
+49 (Lines changed: 49 additions & 0 deletions)
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,52 @@ def max(self):
3737
def min(self):
3838
"""Calculate the weighted window minimum."""
3939
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
40+
41+
def agg(self, func):
    """
    Aggregate using one or more operations over the window.

    **Examples:**

        >>> import bigframes.pandas as bpd

        >>> df = bpd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]})
        >>> df
           A  B  C
        0  1  4  7
        1  2  5  8
        2  3  6  9
        <BLANKLINE>
        [3 rows x 3 columns]

        >>> df.rolling(2).sum()
              A     B     C
        0  <NA>  <NA>  <NA>
        1     3     9    15
        2     5    11    17
        <BLANKLINE>
        [3 rows x 3 columns]

        >>> df.rolling(2).agg({"A": "sum", "B": "min"})
              A     B
        0  <NA>  <NA>
        1     3     4
        2     5     5
        <BLANKLINE>
        [3 rows x 2 columns]

    Args:
        func (function, str, list or dict):
            Function to use for aggregating the data.

            Accepted combinations are:

            - function, e.g. ``numpy.sum``
            - string function name
            - list of functions and/or function names, e.g. ``['sum', 'mean']``
            - dict of column labels -> functions, function names or list of such.

    Returns:
        Series or DataFrame:
            The aggregated result.
    """
    raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)

0 commit comments

Comments
0 (0)
Morty Proxy: this is a proxified and sanitized view of the page; visit the original site.