diff --git a/bigframes/core/rewrite/timedeltas.py b/bigframes/core/rewrite/timedeltas.py index d740b28d7d..990aca1f18 100644 --- a/bigframes/core/rewrite/timedeltas.py +++ b/bigframes/core/rewrite/timedeltas.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import dataclasses import functools import typing @@ -27,6 +29,14 @@ class _TypedExpr: expr: ex.Expression dtype: dtypes.Dtype + @classmethod + def create_op_expr( + cls, op: typing.Union[ops.ScalarOp, ops.RowOp], *inputs: _TypedExpr + ) -> _TypedExpr: + expr = op.as_expr(*tuple(x.expr for x in inputs)) # type: ignore + dtype = op.output_type(*tuple(x.dtype for x in inputs)) + return cls(expr, dtype) + def rewrite_timedelta_expressions(root: nodes.BigFrameNode) -> nodes.BigFrameNode: """ @@ -38,12 +48,27 @@ def rewrite_timedelta_expressions(root: nodes.BigFrameNode) -> nodes.BigFrameNod (_rewrite_expressions(expr, root.schema).expr, column_id) for expr, column_id in root.assignments ) - root = nodes.ProjectionNode(root.child, updated_assignments) + return nodes.ProjectionNode(root.child, updated_assignments) + + if isinstance(root, nodes.FilterNode): + return nodes.FilterNode( + root.child, _rewrite_expressions(root.predicate, root.schema).expr + ) + + if isinstance(root, nodes.OrderByNode): + by = tuple(_rewrite_ordering_expr(x, root.schema) for x in root.by) + return nodes.OrderByNode(root.child, by) - # TODO(b/394354614): FilterByNode and OrderNode also contain expressions. Need to update them too. return root +def _rewrite_ordering_expr( + expr: nodes.OrderingExpression, schema: schema.ArraySchema +) -> nodes.OrderingExpression: + by = _rewrite_expressions(expr.scalar_expression, schema).expr + return nodes.OrderingExpression(by, expr.direction, expr.na_last) + + @functools.cache def _rewrite_expressions(expr: ex.Expression, schema: schema.ArraySchema) -> _TypedExpr: if isinstance(expr, ex.DerefOp): @@ -78,37 +103,23 @@ def _rewrite_op_expr( if isinstance(expr.op, ops.AddOp): return _rewrite_add_op(inputs[0], inputs[1]) - input_types = tuple(map(lambda x: x.dtype, inputs)) - return _TypedExpr(expr, expr.op.output_type(*input_types)) + return _TypedExpr.create_op_expr(expr.op, *inputs) def _rewrite_sub_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr: - result_op: ops.BinaryOp = ops.sub_op if dtypes.is_datetime_like(left.dtype) and dtypes.is_datetime_like(right.dtype): - result_op = ops.timestamp_diff_op + return _TypedExpr.create_op_expr(ops.timestamp_diff_op, left, right) - return _TypedExpr( - result_op.as_expr(left.expr, right.expr), - result_op.output_type(left.dtype, right.dtype), - ) + return _TypedExpr.create_op_expr(ops.sub_op, left, right) def _rewrite_add_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr: if dtypes.is_datetime_like(left.dtype) and right.dtype is dtypes.TIMEDELTA_DTYPE: - return _TypedExpr( - ops.timestamp_add_op.as_expr(left.expr, right.expr), - ops.timestamp_add_op.output_type(left.dtype, right.dtype), - ) + return _TypedExpr.create_op_expr(ops.timestamp_add_op, left, right) if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like(right.dtype): # Re-arrange operands such that timestamp is always on the left and timedelta is # always on the right. - return _TypedExpr( - ops.timestamp_add_op.as_expr(right.expr, left.expr), - ops.timestamp_add_op.output_type(right.dtype, left.dtype), - ) + return _TypedExpr.create_op_expr(ops.timestamp_add_op, right, left) - return _TypedExpr( - ops.add_op.as_expr(left.expr, right.expr), - ops.add_op.output_type(left.dtype, right.dtype), - ) + return _TypedExpr.create_op_expr(ops.add_op, left, right) diff --git a/bigframes/dtypes.py b/bigframes/dtypes.py index 9083e13bc8..e4db904210 100644 --- a/bigframes/dtypes.py +++ b/bigframes/dtypes.py @@ -358,7 +358,7 @@ def is_comparable(type_: ExpressionType) -> bool: def is_orderable(type_: ExpressionType) -> bool: # On BQ side, ARRAY, STRUCT, GEOGRAPHY, JSON are not orderable - return type_ in _ORDERABLE_SIMPLE_TYPES + return type_ in _ORDERABLE_SIMPLE_TYPES or type_ is TIMEDELTA_DTYPE _CLUSTERABLE_SIMPLE_TYPES = set( diff --git a/tests/system/small/operations/test_timedeltas.py b/tests/system/small/operations/test_timedeltas.py index 6c44a62686..fe779a8524 100644 --- a/tests/system/small/operations/test_timedeltas.py +++ b/tests/system/small/operations/test_timedeltas.py @@ -14,6 +14,7 @@ import datetime +import operator import numpy as np import pandas as pd @@ -28,12 +29,23 @@ def temporal_dfs(session): "datetime_col": [ pd.Timestamp("2025-02-01 01:00:01"), pd.Timestamp("2019-01-02 02:00:00"), + pd.Timestamp("1997-01-01 19:00:00"), ], "timestamp_col": [ pd.Timestamp("2023-01-01 01:00:01", tz="UTC"), pd.Timestamp("2024-01-02 02:00:00", tz="UTC"), + pd.Timestamp("2005-03-05 02:00:00", tz="UTC"), + ], + "timedelta_col_1": [ + pd.Timedelta(3, "s"), + pd.Timedelta(-4, "d"), + pd.Timedelta(5, "h"), + ], + "timedelta_col_2": [ + pd.Timedelta(2, "s"), + pd.Timedelta(-4, "d"), + pd.Timedelta(6, "h"), ], - "timedelta_col": [pd.Timedelta(3, "s"), pd.Timedelta(-4, "d")], } ) @@ -53,10 +65,10 @@ def test_timestamp_add__ts_series_plus_td_series(temporal_dfs, column, pd_dtype) bf_df, pd_df = temporal_dfs actual_result = ( - (bf_df[column] + bf_df["timedelta_col"]).to_pandas().astype(pd_dtype) + (bf_df[column] + bf_df["timedelta_col_1"]).to_pandas().astype(pd_dtype) ) - expected_result = pd_df[column] + pd_df["timedelta_col"] + expected_result = pd_df[column] + pd_df["timedelta_col_1"] pandas.testing.assert_series_equal( actual_result, expected_result, check_index_type=False ) @@ -94,10 +106,10 @@ def test_timestamp_add__td_series_plus_ts_series(temporal_dfs, column, pd_dtype) bf_df, pd_df = temporal_dfs actual_result = ( - (bf_df["timedelta_col"] + bf_df[column]).to_pandas().astype(pd_dtype) + (bf_df["timedelta_col_1"] + bf_df[column]).to_pandas().astype(pd_dtype) ) - expected_result = pd_df["timedelta_col"] + pd_df[column] + expected_result = pd_df["timedelta_col_1"] + pd_df[column] pandas.testing.assert_series_equal( actual_result, expected_result, check_index_type=False ) @@ -120,10 +132,10 @@ def test_timestamp_add__ts_literal_plus_td_series(temporal_dfs): timestamp = pd.Timestamp("2025-01-01", tz="UTC") actual_result = ( - (timestamp + bf_df["timedelta_col"]).to_pandas().astype("datetime64[ns, UTC]") + (timestamp + bf_df["timedelta_col_1"]).to_pandas().astype("datetime64[ns, UTC]") ) - expected_result = timestamp + pd_df["timedelta_col"] + expected_result = timestamp + pd_df["timedelta_col_1"] pandas.testing.assert_series_equal( actual_result, expected_result, check_index_type=False ) @@ -140,10 +152,10 @@ def test_timestamp_add_with_numpy_op(temporal_dfs, column, pd_dtype): bf_df, pd_df = temporal_dfs actual_result = ( - np.add(bf_df[column], bf_df["timedelta_col"]).to_pandas().astype(pd_dtype) + np.add(bf_df[column], bf_df["timedelta_col_1"]).to_pandas().astype(pd_dtype) ) - expected_result = np.add(pd_df[column], pd_df["timedelta_col"]) + expected_result = np.add(pd_df[column], pd_df["timedelta_col_1"]) pandas.testing.assert_series_equal( actual_result, expected_result, check_index_type=False ) @@ -164,3 +176,105 @@ def test_timestamp_add_dataframes(temporal_dfs): pandas.testing.assert_frame_equal( actual_result, expected_result, check_index_type=False ) + + +@pytest.mark.parametrize( + "compare_func", + [ + pytest.param(operator.gt, id="gt"), + pytest.param(operator.ge, id="ge"), + pytest.param(operator.eq, id="eq"), + pytest.param(operator.ne, id="ne"), + pytest.param(operator.lt, id="lt"), + pytest.param(operator.le, id="le"), + ], +) +def test_timedelta_series_comparison(temporal_dfs, compare_func): + bf_df, pd_df = temporal_dfs + + actual_result = compare_func( + bf_df["timedelta_col_1"], bf_df["timedelta_col_2"] + ).to_pandas() + + expected_result = compare_func( + pd_df["timedelta_col_1"], pd_df["timedelta_col_2"] + ).astype("boolean") + pandas.testing.assert_series_equal( + actual_result, expected_result, check_index_type=False + ) + + +@pytest.mark.parametrize( + "compare_func", + [ + pytest.param(operator.gt, id="gt"), + pytest.param(operator.ge, id="ge"), + pytest.param(operator.eq, id="eq"), + pytest.param(operator.ne, id="ne"), + pytest.param(operator.lt, id="lt"), + pytest.param(operator.le, id="le"), + ], +) +def test_timedelta_series_and_literal_comparison(temporal_dfs, compare_func): + bf_df, pd_df = temporal_dfs + literal = pd.Timedelta(3, "s") + + actual_result = compare_func(literal, bf_df["timedelta_col_2"]).to_pandas() + + expected_result = compare_func(literal, pd_df["timedelta_col_2"]).astype("boolean") + pandas.testing.assert_series_equal( + actual_result, expected_result, check_index_type=False + ) + + +def test_timedelta_filtering(session): + pd_series = pd.Series( + [ + pd.Timestamp("2025-01-01 01:00:00"), + pd.Timestamp("2025-01-01 02:00:00"), + pd.Timestamp("2025-01-01 03:00:00"), + ] + ) + bf_series = session.read_pandas(pd_series) + timestamp = pd.Timestamp("2025-01-01, 00:00:01") + + actual_result = ( + bf_series[((bf_series - timestamp) > pd.Timedelta(1, "h"))] + .to_pandas() + .astype(" pd.Timedelta(1, "h")] + pandas.testing.assert_series_equal( + actual_result, expected_result, check_index_type=False + ) + + +def test_timedelta_ordering(session): + pd_df = pd.DataFrame( + { + "col_1": [ + pd.Timestamp("2025-01-01 01:00:00"), + pd.Timestamp("2025-01-01 02:00:00"), + pd.Timestamp("2025-01-01 03:00:00"), + ], + "col_2": [ + pd.Timestamp("2025-01-01 01:00:02"), + pd.Timestamp("2025-01-01 02:00:01"), + pd.Timestamp("2025-01-01 02:59:59"), + ], + } + ) + bf_df = session.read_pandas(pd_df) + + actual_result = ( + (bf_df["col_2"] - bf_df["col_1"]) + .sort_values() + .to_pandas() + .astype("timedelta64[ns]") + ) + + expected_result = (pd_df["col_2"] - pd_df["col_1"]).sort_values() + pandas.testing.assert_series_equal( + actual_result, expected_result, check_index_type=False + )