From 98a54e81ee476f3b74a7396c3965baaee91d8f71 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Tue, 11 Feb 2025 20:25:40 +0000 Subject: [PATCH 1/6] [WIP] support timedelta ordering and filtering --- bigframes/core/rewrite/timedeltas.py | 76 ++++++++++++++++------------ bigframes/dtypes.py | 2 +- 2 files changed, 45 insertions(+), 33 deletions(-) diff --git a/bigframes/core/rewrite/timedeltas.py b/bigframes/core/rewrite/timedeltas.py index d740b28d7d..4d47c1cc94 100644 --- a/bigframes/core/rewrite/timedeltas.py +++ b/bigframes/core/rewrite/timedeltas.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import dataclasses import functools import typing @@ -27,6 +29,14 @@ class _TypedExpr: expr: ex.Expression dtype: dtypes.Dtype + @classmethod + def create_op_expr( + cls, op: typing.Union[ops.ScalarOp, ops.RowOp], *inputs: _TypedExpr + ) -> _TypedExpr: + expr = op.as_expr(*tuple(x.expr for x in inputs)) # type: ignore + dtype = op.output_type(*tuple(x.dtype for x in inputs)) + return cls(expr, dtype) + def rewrite_timedelta_expressions(root: nodes.BigFrameNode) -> nodes.BigFrameNode: """ @@ -38,12 +48,27 @@ def rewrite_timedelta_expressions(root: nodes.BigFrameNode) -> nodes.BigFrameNod (_rewrite_expressions(expr, root.schema).expr, column_id) for expr, column_id in root.assignments ) - root = nodes.ProjectionNode(root.child, updated_assignments) + return nodes.ProjectionNode(root.child, updated_assignments) + + if isinstance(root, nodes.FilterNode): + return nodes.FilterNode( + root.child, _rewrite_expressions(root.predicate, root.schema).expr + ) + + if isinstance(root, nodes.OrderByNode): + by = tuple(_rewrite_ordering_expr(x, root.schema) for x in root.by) + return nodes.OrderByNode(root.child, by) - # TODO(b/394354614): FilterByNode and OrderNode also contain expressions. Need to update them too. 
return root +def _rewrite_ordering_expr( + expr: nodes.OrderingExpression, schema: schema.ArraySchema +) -> nodes.OrderingExpression: + by = _rewrite_expressions(expr.scalar_expression, schema).expr + return nodes.OrderingExpression(by, expr.direction, expr.na_last) + + @functools.cache def _rewrite_expressions(expr: ex.Expression, schema: schema.ArraySchema) -> _TypedExpr: if isinstance(expr, ex.DerefOp): @@ -56,7 +81,7 @@ def _rewrite_expressions(expr: ex.Expression, schema: schema.ArraySchema) -> _Ty updated_inputs = tuple( map(lambda x: _rewrite_expressions(x, schema), expr.inputs) ) - return _rewrite_op_expr(expr, updated_inputs) + return _rewrite_op_expr(expr.op, updated_inputs) raise AssertionError(f"Unexpected expression type: {type(expr)}") @@ -69,46 +94,33 @@ def _rewrite_scalar_constant_expr(expr: ex.ScalarConstantExpression) -> _TypedEx return _TypedExpr(expr, expr.dtype) -def _rewrite_op_expr( - expr: ex.OpExpression, inputs: typing.Tuple[_TypedExpr, ...] -) -> _TypedExpr: - if isinstance(expr.op, ops.SubOp): - return _rewrite_sub_op(inputs[0], inputs[1]) - - if isinstance(expr.op, ops.AddOp): - return _rewrite_add_op(inputs[0], inputs[1]) +@functools.singledispatch +def _rewrite_op_expr(op: ops.ScalarOp, inputs: typing.Tuple[_TypedExpr, ...]) -> _TypedExpr: + return _TypedExpr.create_op_expr(op, *inputs) - input_types = tuple(map(lambda x: x.dtype, inputs)) - return _TypedExpr(expr, expr.op.output_type(*input_types)) +@_rewrite_op_expr.register +def _(op: ops.SubOp, inputs: typing.Tuple[_TypedExpr, ...]) -> _TypedExpr: + left = inputs[0] + right = inputs[1] -def _rewrite_sub_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr: result_op: ops.BinaryOp = ops.sub_op if dtypes.is_datetime_like(left.dtype) and dtypes.is_datetime_like(right.dtype): result_op = ops.timestamp_diff_op - return _TypedExpr( - result_op.as_expr(left.expr, right.expr), - result_op.output_type(left.dtype, right.dtype), - ) - + return _TypedExpr.create_op_expr(result_op, left, 
right) -def _rewrite_add_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr: +@_rewrite_op_expr.register +def _(op: ops.AddOp, inputs: typing.Tuple[_TypedExpr, ...]) -> _TypedExpr: + left = inputs[0] + right = inputs[1] + if dtypes.is_datetime_like(left.dtype) and right.dtype is dtypes.TIMEDELTA_DTYPE: - return _TypedExpr( - ops.timestamp_add_op.as_expr(left.expr, right.expr), - ops.timestamp_add_op.output_type(left.dtype, right.dtype), - ) + return _TypedExpr.create_op_expr(ops.timestamp_add_op, left, right) if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like(right.dtype): # Re-arrange operands such that timestamp is always on the left and timedelta is # always on the right. - return _TypedExpr( - ops.timestamp_add_op.as_expr(right.expr, left.expr), - ops.timestamp_add_op.output_type(right.dtype, left.dtype), - ) + return _TypedExpr.create_op_expr(ops.timestamp_add_op, right, left) - return _TypedExpr( - ops.add_op.as_expr(left.expr, right.expr), - ops.add_op.output_type(left.dtype, right.dtype), - ) + return _TypedExpr.create_op_expr(ops.add_op, left, right) diff --git a/bigframes/dtypes.py b/bigframes/dtypes.py index eed45e1dde..467f0a1996 100644 --- a/bigframes/dtypes.py +++ b/bigframes/dtypes.py @@ -357,7 +357,7 @@ def is_comparable(type_: ExpressionType) -> bool: def is_orderable(type_: ExpressionType) -> bool: # On BQ side, ARRAY, STRUCT, GEOGRAPHY, JSON are not orderable - return type_ in _ORDERABLE_SIMPLE_TYPES + return type_ in _ORDERABLE_SIMPLE_TYPES or type_ is TIMEDELTA_DTYPE _CLUSTERABLE_SIMPLE_TYPES = set( From 3e4e9e132917a03b0acc04202043f566a020e125 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Tue, 11 Feb 2025 21:08:34 +0000 Subject: [PATCH 2/6] chore: support comparison, ordering, and filtering for timedeltas --- bigframes/core/rewrite/timedeltas.py | 7 +- .../small/operations/test_timedeltas.py | 124 ++++++++++++++++-- 2 files changed, 120 insertions(+), 11 deletions(-) diff --git a/bigframes/core/rewrite/timedeltas.py 
b/bigframes/core/rewrite/timedeltas.py index 4d47c1cc94..ca92221b40 100644 --- a/bigframes/core/rewrite/timedeltas.py +++ b/bigframes/core/rewrite/timedeltas.py @@ -95,7 +95,9 @@ def _rewrite_scalar_constant_expr(expr: ex.ScalarConstantExpression) -> _TypedEx @functools.singledispatch -def _rewrite_op_expr(op: ops.ScalarOp, inputs: typing.Tuple[_TypedExpr, ...]) -> _TypedExpr: +def _rewrite_op_expr( + op: ops.ScalarOp, inputs: typing.Tuple[_TypedExpr, ...] +) -> _TypedExpr: return _TypedExpr.create_op_expr(op, *inputs) @@ -110,11 +112,12 @@ def _(op: ops.SubOp, inputs: typing.Tuple[_TypedExpr, ...]) -> _TypedExpr: return _TypedExpr.create_op_expr(result_op, left, right) + @_rewrite_op_expr.register def _(op: ops.AddOp, inputs: typing.Tuple[_TypedExpr, ...]) -> _TypedExpr: left = inputs[0] right = inputs[1] - + if dtypes.is_datetime_like(left.dtype) and right.dtype is dtypes.TIMEDELTA_DTYPE: return _TypedExpr.create_op_expr(ops.timestamp_add_op, left, right) diff --git a/tests/system/small/operations/test_timedeltas.py b/tests/system/small/operations/test_timedeltas.py index 6c44a62686..375b0e9d51 100644 --- a/tests/system/small/operations/test_timedeltas.py +++ b/tests/system/small/operations/test_timedeltas.py @@ -28,12 +28,23 @@ def temporal_dfs(session): "datetime_col": [ pd.Timestamp("2025-02-01 01:00:01"), pd.Timestamp("2019-01-02 02:00:00"), + pd.Timestamp("1997-01-01 19:00:00"), ], "timestamp_col": [ pd.Timestamp("2023-01-01 01:00:01", tz="UTC"), pd.Timestamp("2024-01-02 02:00:00", tz="UTC"), + pd.Timestamp("2005-03-05 02:00:00", tz="UTC"), + ], + "timedelta_col_1": [ + pd.Timedelta(3, "s"), + pd.Timedelta(-4, "d"), + pd.Timedelta(5, "h"), + ], + "timedelta_col_2": [ + pd.Timedelta(2, "s"), + pd.Timedelta(-4, "d"), + pd.Timedelta(6, "h"), ], - "timedelta_col": [pd.Timedelta(3, "s"), pd.Timedelta(-4, "d")], } ) @@ -53,10 +64,10 @@ def test_timestamp_add__ts_series_plus_td_series(temporal_dfs, column, pd_dtype) bf_df, pd_df = temporal_dfs actual_result = ( - 
(bf_df[column] + bf_df["timedelta_col"]).to_pandas().astype(pd_dtype) + (bf_df[column] + bf_df["timedelta_col_1"]).to_pandas().astype(pd_dtype) ) - expected_result = pd_df[column] + pd_df["timedelta_col"] + expected_result = pd_df[column] + pd_df["timedelta_col_1"] pandas.testing.assert_series_equal( actual_result, expected_result, check_index_type=False ) @@ -94,10 +105,10 @@ def test_timestamp_add__td_series_plus_ts_series(temporal_dfs, column, pd_dtype) bf_df, pd_df = temporal_dfs actual_result = ( - (bf_df["timedelta_col"] + bf_df[column]).to_pandas().astype(pd_dtype) + (bf_df["timedelta_col_1"] + bf_df[column]).to_pandas().astype(pd_dtype) ) - expected_result = pd_df["timedelta_col"] + pd_df[column] + expected_result = pd_df["timedelta_col_1"] + pd_df[column] pandas.testing.assert_series_equal( actual_result, expected_result, check_index_type=False ) @@ -120,10 +131,10 @@ def test_timestamp_add__ts_literal_plus_td_series(temporal_dfs): timestamp = pd.Timestamp("2025-01-01", tz="UTC") actual_result = ( - (timestamp + bf_df["timedelta_col"]).to_pandas().astype("datetime64[ns, UTC]") + (timestamp + bf_df["timedelta_col_1"]).to_pandas().astype("datetime64[ns, UTC]") ) - expected_result = timestamp + pd_df["timedelta_col"] + expected_result = timestamp + pd_df["timedelta_col_1"] pandas.testing.assert_series_equal( actual_result, expected_result, check_index_type=False ) @@ -140,10 +151,10 @@ def test_timestamp_add_with_numpy_op(temporal_dfs, column, pd_dtype): bf_df, pd_df = temporal_dfs actual_result = ( - np.add(bf_df[column], bf_df["timedelta_col"]).to_pandas().astype(pd_dtype) + np.add(bf_df[column], bf_df["timedelta_col_1"]).to_pandas().astype(pd_dtype) ) - expected_result = np.add(pd_df[column], pd_df["timedelta_col"]) + expected_result = np.add(pd_df[column], pd_df["timedelta_col_1"]) pandas.testing.assert_series_equal( actual_result, expected_result, check_index_type=False ) @@ -164,3 +175,98 @@ def test_timestamp_add_dataframes(temporal_dfs): 
pandas.testing.assert_frame_equal( actual_result, expected_result, check_index_type=False ) + + +@pytest.mark.parametrize( + "compare_func", + [ + pytest.param(lambda x, y: x > y, id="gt"), + pytest.param(lambda x, y: x >= y, id="ge"), + pytest.param(lambda x, y: x == y, id="eq"), + pytest.param(lambda x, y: x < y, id="lt"), + pytest.param(lambda x, y: x <= y, id="le"), + ], +) +def test_timedelta_series_comparison(temporal_dfs, compare_func): + bf_df, pd_df = temporal_dfs + + actual_result = compare_func( + bf_df["timedelta_col_1"], bf_df["timedelta_col_2"] + ).to_pandas() + + expected_result = compare_func( + pd_df["timedelta_col_1"], pd_df["timedelta_col_2"] + ).astype("boolean") + pandas.testing.assert_series_equal( + actual_result, expected_result, check_index_type=False + ) + + +@pytest.mark.parametrize( + "compare_func", + [ + pytest.param(lambda x, y: x > y, id="gt"), + pytest.param(lambda x, y: x >= y, id="ge"), + pytest.param(lambda x, y: x == y, id="eq"), + pytest.param(lambda x, y: x < y, id="lt"), + pytest.param(lambda x, y: x <= y, id="le"), + ], +) +def test_timedelta_series_and_literal_comparison(temporal_dfs, compare_func): + bf_df, pd_df = temporal_dfs + literal = pd.Timedelta(3, "s") + + actual_result = compare_func(literal, bf_df["timedelta_col_2"]).to_pandas() + + expected_result = compare_func(literal, pd_df["timedelta_col_2"]).astype("boolean") + pandas.testing.assert_series_equal( + actual_result, expected_result, check_index_type=False + ) + + +def test_timedelta_filtering(session): + pd_series = pd.Series( + [ + pd.Timestamp("2025-01-01 01:00:00"), + pd.Timestamp("2025-01-01 02:00:00"), + pd.Timestamp("2025-01-01 03:00:00"), + ] + ) + bf_series = session.read_pandas(pd_series) + timestamp = pd.Timestamp("2025-01-01, 00:00:01") + + actual_result = ( + bf_series[((bf_series - timestamp) > pd.Timedelta(1, "h"))] + .to_pandas() + .astype("<M8[ns]") + ) + + expected_result = pd_series[(pd_series - timestamp) > pd.Timedelta(1, "h")] + pandas.testing.assert_series_equal( + actual_result, expected_result, 
check_index_type=False + ) + + +def test_timedelta_ordering(session): + pd_df = pd.DataFrame( + { + "col_1": [ + pd.Timestamp("2025-01-01 01:00:00"), + pd.Timestamp("2025-01-01 02:00:00"), + pd.Timestamp("2025-01-01 03:00:00"), + ], + "col_2": [ + pd.Timestamp("2025-01-01 01:00:02"), + pd.Timestamp("2025-01-01 02:00:01"), + pd.Timestamp("2025-01-01 02:59:59"), + ], + } + ) + bf_df = session.read_pandas(pd_df) + + actual_result = (bf_df['col_2'] - bf_df['col_1']).sort_values().to_pandas() + + expected_result = (pd_df['col_2'] - pd_df['col_1']).sort_values().astype('duration[us][pyarrow]') + pandas.testing.assert_series_equal( + actual_result, expected_result, check_index_type=False + ) From 51c61531b78deba2b41d1e20afe83b331aedd0c0 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Tue, 11 Feb 2025 21:15:38 +0000 Subject: [PATCH 3/6] fix format --- tests/system/small/operations/test_timedeltas.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/system/small/operations/test_timedeltas.py b/tests/system/small/operations/test_timedeltas.py index 375b0e9d51..97e39ab7a3 100644 --- a/tests/system/small/operations/test_timedeltas.py +++ b/tests/system/small/operations/test_timedeltas.py @@ -264,9 +264,11 @@ def test_timedelta_ordering(session): ) bf_df = session.read_pandas(pd_df) - actual_result = (bf_df['col_2'] - bf_df['col_1']).sort_values().to_pandas() + actual_result = (bf_df["col_2"] - bf_df["col_1"]).sort_values().to_pandas() - expected_result = (pd_df['col_2'] - pd_df['col_1']).sort_values().astype('duration[us][pyarrow]') + expected_result = ( + (pd_df["col_2"] - pd_df["col_1"]).sort_values().astype("duration[us][pyarrow]") + ) pandas.testing.assert_series_equal( actual_result, expected_result, check_index_type=False ) From 07783955ed390669053f0bf467136e021e31006d Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Tue, 11 Feb 2025 21:38:16 +0000 Subject: [PATCH 4/6] some cleanups --- bigframes/core/rewrite/timedeltas.py | 28 
++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/bigframes/core/rewrite/timedeltas.py b/bigframes/core/rewrite/timedeltas.py index ca92221b40..990aca1f18 100644 --- a/bigframes/core/rewrite/timedeltas.py +++ b/bigframes/core/rewrite/timedeltas.py @@ -81,7 +81,7 @@ def _rewrite_expressions(expr: ex.Expression, schema: schema.ArraySchema) -> _Ty updated_inputs = tuple( map(lambda x: _rewrite_expressions(x, schema), expr.inputs) ) - return _rewrite_op_expr(expr.op, updated_inputs) + return _rewrite_op_expr(expr, updated_inputs) raise AssertionError(f"Unexpected expression type: {type(expr)}") @@ -94,30 +94,26 @@ def _rewrite_scalar_constant_expr(expr: ex.ScalarConstantExpression) -> _TypedEx return _TypedExpr(expr, expr.dtype) -@functools.singledispatch def _rewrite_op_expr( - op: ops.ScalarOp, inputs: typing.Tuple[_TypedExpr, ...] + expr: ex.OpExpression, inputs: typing.Tuple[_TypedExpr, ...] ) -> _TypedExpr: - return _TypedExpr.create_op_expr(op, *inputs) + if isinstance(expr.op, ops.SubOp): + return _rewrite_sub_op(inputs[0], inputs[1]) + if isinstance(expr.op, ops.AddOp): + return _rewrite_add_op(inputs[0], inputs[1]) -@_rewrite_op_expr.register -def _(op: ops.SubOp, inputs: typing.Tuple[_TypedExpr, ...]) -> _TypedExpr: - left = inputs[0] - right = inputs[1] + return _TypedExpr.create_op_expr(expr.op, *inputs) - result_op: ops.BinaryOp = ops.sub_op - if dtypes.is_datetime_like(left.dtype) and dtypes.is_datetime_like(right.dtype): - result_op = ops.timestamp_diff_op - return _TypedExpr.create_op_expr(result_op, left, right) +def _rewrite_sub_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr: + if dtypes.is_datetime_like(left.dtype) and dtypes.is_datetime_like(right.dtype): + return _TypedExpr.create_op_expr(ops.timestamp_diff_op, left, right) + return _TypedExpr.create_op_expr(ops.sub_op, left, right) -@_rewrite_op_expr.register -def _(op: ops.AddOp, inputs: typing.Tuple[_TypedExpr, ...]) -> _TypedExpr: - left = 
inputs[0] - right = inputs[1] +def _rewrite_add_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr: if dtypes.is_datetime_like(left.dtype) and right.dtype is dtypes.TIMEDELTA_DTYPE: return _TypedExpr.create_op_expr(ops.timestamp_add_op, left, right) From 1533152052eefa063174e1f325126cd4de234cd9 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Tue, 11 Feb 2025 21:51:58 +0000 Subject: [PATCH 5/6] use operator package for testing --- .../small/operations/test_timedeltas.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/tests/system/small/operations/test_timedeltas.py b/tests/system/small/operations/test_timedeltas.py index 97e39ab7a3..18be9d383e 100644 --- a/tests/system/small/operations/test_timedeltas.py +++ b/tests/system/small/operations/test_timedeltas.py @@ -14,6 +14,7 @@ import datetime +import operator import numpy as np import pandas as pd @@ -180,11 +181,12 @@ def test_timestamp_add_dataframes(temporal_dfs): @pytest.mark.parametrize( "compare_func", [ - pytest.param(lambda x, y: x > y, id="gt"), - pytest.param(lambda x, y: x >= y, id="ge"), - pytest.param(lambda x, y: x == y, id="eq"), - pytest.param(lambda x, y: x < y, id="lt"), - pytest.param(lambda x, y: x <= y, id="le"), + pytest.param(operator.gt, id="gt"), + pytest.param(operator.ge, id="ge"), + pytest.param(operator.eq, id="eq"), + pytest.param(operator.ne, id="ne"), + pytest.param(operator.lt, id="lt"), + pytest.param(operator.le, id="le"), ], ) def test_timedelta_series_comparison(temporal_dfs, compare_func): @@ -205,11 +207,12 @@ def test_timedelta_series_comparison(temporal_dfs, compare_func): @pytest.mark.parametrize( "compare_func", [ - pytest.param(lambda x, y: x > y, id="gt"), - pytest.param(lambda x, y: x >= y, id="ge"), - pytest.param(lambda x, y: x == y, id="eq"), - pytest.param(lambda x, y: x < y, id="lt"), - pytest.param(lambda x, y: x <= y, id="le"), + pytest.param(operator.gt, id="gt"), + pytest.param(operator.ge, id="ge"), + 
pytest.param(operator.eq, id="eq"), + pytest.param(operator.ne, id="ne"), + pytest.param(operator.lt, id="lt"), + pytest.param(operator.le, id="le"), ], ) def test_timedelta_series_and_literal_comparison(temporal_dfs, compare_func): From 979c7cba3f4a8c7cb3733729cb8dff7b5dd748fc Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Tue, 11 Feb 2025 22:45:33 +0000 Subject: [PATCH 6/6] fix test error --- tests/system/small/operations/test_timedeltas.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/system/small/operations/test_timedeltas.py b/tests/system/small/operations/test_timedeltas.py index 18be9d383e..fe779a8524 100644 --- a/tests/system/small/operations/test_timedeltas.py +++ b/tests/system/small/operations/test_timedeltas.py @@ -267,11 +267,14 @@ def test_timedelta_ordering(session): ) bf_df = session.read_pandas(pd_df) - actual_result = (bf_df["col_2"] - bf_df["col_1"]).sort_values().to_pandas() - - expected_result = ( - (pd_df["col_2"] - pd_df["col_1"]).sort_values().astype("duration[us][pyarrow]") + actual_result = ( + (bf_df["col_2"] - bf_df["col_1"]) + .sort_values() + .to_pandas() + .astype("timedelta64[ns]") ) + + expected_result = (pd_df["col_2"] - pd_df["col_1"]).sort_values() pandas.testing.assert_series_equal( actual_result, expected_result, check_index_type=False )