diff --git a/bigframes/core/expression.py b/bigframes/core/expression.py index bbd23b689c..4779e92cde 100644 --- a/bigframes/core/expression.py +++ b/bigframes/core/expression.py @@ -25,7 +25,9 @@ import bigframes.operations.aggregations as agg_ops -def const(value: typing.Hashable, dtype: dtypes.ExpressionType = None) -> Expression: +def const( + value: typing.Hashable, dtype: dtypes.ExpressionType = None +) -> ScalarConstantExpression: return ScalarConstantExpression(value, dtype or dtypes.infer_literal_type(value)) @@ -141,6 +143,9 @@ class ScalarConstantExpression(Expression): def is_const(self) -> bool: return True + def rename(self, name_mapping: Mapping[str, str]) -> ScalarConstantExpression: + return self + def output_type( self, input_types: dict[str, bigframes.dtypes.Dtype] ) -> dtypes.ExpressionType: @@ -167,7 +172,7 @@ class UnboundVariableExpression(Expression): def unbound_variables(self) -> typing.Tuple[str, ...]: return (self.id,) - def rename(self, name_mapping: Mapping[str, str]) -> Expression: + def rename(self, name_mapping: Mapping[str, str]) -> UnboundVariableExpression: if self.id in name_mapping: return UnboundVariableExpression(name_mapping[self.id]) else: diff --git a/bigframes/operations/base.py b/bigframes/operations/base.py index 68f46baded..f9a6a87b7a 100644 --- a/bigframes/operations/base.py +++ b/bigframes/operations/base.py @@ -15,7 +15,7 @@ from __future__ import annotations import typing -from typing import List, Sequence +from typing import List, Sequence, Union import bigframes_vendored.constants as constants import bigframes_vendored.pandas.pandas._typing as vendored_pandas_typing @@ -180,9 +180,10 @@ def _apply_binary_op( (self_col, other_col, block) = self._align(other_series, how=alignment) name = self._name + # Drop name if both objects have name attr, but they don't match if ( hasattr(other, "name") - and other.name != self._name + and other_series.name != self._name and alignment == "outer" ): name = None @@ -208,22 +209,41 @@ def _apply_nary_op( ignore_self=False, ): """Applies an n-ary operator to the series and others.""" - values, block = self._align_n(others, ignore_self=ignore_self) - block, result_id = block.apply_nary_op( - values, - op, - self._name, + values, block = self._align_n( + others, ignore_self=ignore_self, cast_scalars=False ) + block, result_id = block.project_expr(op.as_expr(*values)) return series.Series(block.select_column(result_id)) def _apply_binary_aggregation( self, other: series.Series, stat: agg_ops.BinaryAggregateOp ) -> float: (left, right, block) = self._align(other, how="outer") + assert isinstance(left, ex.UnboundVariableExpression) + assert isinstance(right, ex.UnboundVariableExpression) + return block.get_binary_stat(left.id, right.id, stat) + + AlignedExprT = Union[ex.ScalarConstantExpression, ex.UnboundVariableExpression] - return block.get_binary_stat(left, right, stat) + @typing.overload + def _align( + self, other: series.Series, how="outer" + ) -> tuple[ + ex.UnboundVariableExpression, + ex.UnboundVariableExpression, + blocks.Block, + ]: + ... - def _align(self, other: series.Series, how="outer") -> tuple[str, str, blocks.Block]: # type: ignore + @typing.overload + def _align( + self, other: typing.Union[series.Series, scalars.Scalar], how="outer" + ) -> tuple[ex.UnboundVariableExpression, AlignedExprT, blocks.Block,]: + ... + + def _align( + self, other: typing.Union[series.Series, scalars.Scalar], how="outer" + ) -> tuple[ex.UnboundVariableExpression, AlignedExprT, blocks.Block,]: """Aligns the series value with another scalar or series object. Returns new left column id, right column id and joined tabled expression.""" values, block = self._align_n( [ @@ -231,18 +251,36 @@ def _align(self, other: series.Series, how="outer") -> tuple[str, str, blocks.Bl ], how, ) - return (values[0], values[1], block) + return (typing.cast(ex.UnboundVariableExpression, values[0]), values[1], block) + + def _align3(self, other1: series.Series | scalars.Scalar, other2: series.Series | scalars.Scalar, how="left") -> tuple[ex.UnboundVariableExpression, AlignedExprT, AlignedExprT, blocks.Block]: # type: ignore + """Aligns the series value with 2 other scalars or series objects. Returns new values and joined tabled expression.""" + values, index = self._align_n([other1, other2], how) + return ( + typing.cast(ex.UnboundVariableExpression, values[0]), + values[1], + values[2], + index, + ) def _align_n( self, others: typing.Sequence[typing.Union[series.Series, scalars.Scalar]], how="outer", ignore_self=False, - ) -> tuple[typing.Sequence[str], blocks.Block]: + cast_scalars: bool = True, + ) -> tuple[ + typing.Sequence[ + Union[ex.ScalarConstantExpression, ex.UnboundVariableExpression] + ], + blocks.Block, + ]: if ignore_self: - value_ids: List[str] = [] + value_ids: List[ + Union[ex.ScalarConstantExpression, ex.UnboundVariableExpression] + ] = [] else: - value_ids = [self._value_column] + value_ids = [ex.free_var(self._value_column)] block = self._block for other in others: @@ -252,14 +290,16 @@ def _align_n( get_column_right, ) = block.join(other._block, how=how) value_ids = [ - *[get_column_left[value] for value in value_ids], - get_column_right[other._value_column], + *[value.rename(get_column_left) for value in value_ids], + ex.free_var(get_column_right[other._value_column]), ] else: # Will throw if can't interpret as scalar. dtype = typing.cast(bigframes.dtypes.Dtype, self._dtype) - block, constant_col_id = block.create_constant(other, dtype=dtype) - value_ids = [*value_ids, constant_col_id] + value_ids = [ + *value_ids, + ex.const(other, dtype=dtype if cast_scalars else None), + ] return (value_ids, block) def _throw_if_null_index(self, opname: str): diff --git a/bigframes/series.py b/bigframes/series.py index 3a75ab9ccc..82fb6c5089 100644 --- a/bigframes/series.py +++ b/bigframes/series.py @@ -445,23 +445,13 @@ def between(self, left, right, inclusive="both"): ) def case_when(self, caselist) -> Series: + cases = list(itertools.chain(*caselist, (True, self))) return self._apply_nary_op( ops.case_when_op, - tuple( - itertools.chain( - itertools.chain(*caselist), - # Fallback to current value if no other matches. - ( - # We make a Series with a constant value to avoid casts to - # types other than boolean. - Series(True, index=self.index, dtype=pandas.BooleanDtype()), - self, - ), - ), - ), + cases, # Self is already included in "others". ignore_self=True, - ) + ).rename(self.name) @validations.requires_ordering() def cumsum(self) -> Series: @@ -1116,8 +1106,8 @@ def ne(self, other: object) -> Series: def where(self, cond, other=None): value_id, cond_id, other_id, block = self._align3(cond, other) - block, result_id = block.apply_ternary_op( - value_id, cond_id, other_id, ops.where_op + block, result_id = block.project_expr( + ops.where_op.as_expr(value_id, cond_id, other_id) ) return Series(block.select_column(result_id).with_column_labels([self.name])) @@ -1129,8 +1119,8 @@ def clip(self, lower, upper): if upper is None: return self._apply_binary_op(lower, ops.maximum_op, alignment="left") value_id, lower_id, upper_id, block = self._align3(lower, upper) - block, result_id = block.apply_ternary_op( - value_id, lower_id, upper_id, ops.clip_op + block, result_id = block.project_expr( + ops.clip_op.as_expr(value_id, lower_id, upper_id), ) return Series(block.select_column(result_id).with_column_labels([self.name])) @@ -1242,8 +1232,8 @@ def __getitem__(self, indexer): return self.iloc[indexer] if isinstance(indexer, Series): (left, right, block) = self._align(indexer, "left") - block = block.filter_by_id(right) - block = block.select_column(left) + block = block.filter(right) + block = block.select_column(left.id) return Series(block) return self.loc[indexer] @@ -1262,11 +1252,6 @@ def __getattr__(self, key: str): else: raise AttributeError(key) - def _align3(self, other1: Series | scalars.Scalar, other2: Series | scalars.Scalar, how="left") -> tuple[str, str, str, blocks.Block]: # type: ignore - """Aligns the series value with 2 other scalars or series objects. Returns new values and joined tabled expression.""" - values, index = self._align_n([other1, other2], how) - return (values[0], values[1], values[2], index) - def _apply_aggregation( self, op: agg_ops.UnaryAggregateOp | agg_ops.NullaryAggregateOp ) -> Any: diff --git a/tests/system/small/test_series.py b/tests/system/small/test_series.py index b8f7926aec..793a4062c5 100644 --- a/tests/system/small/test_series.py +++ b/tests/system/small/test_series.py @@ -2709,27 +2709,30 @@ def test_between(scalars_df_index, scalars_pandas_df_index, left, right, inclusi ) -def test_case_when(scalars_df_index, scalars_pandas_df_index): +def test_series_case_when(scalars_dfs_maybe_ordered): pytest.importorskip( "pandas", minversion="2.2.0", reason="case_when added in pandas 2.2.0", ) + scalars_df, scalars_pandas_df = scalars_dfs_maybe_ordered - bf_series = scalars_df_index["int64_col"] - pd_series = scalars_pandas_df_index["int64_col"] + bf_series = scalars_df["int64_col"] + pd_series = scalars_pandas_df["int64_col"] # TODO(tswast): pandas case_when appears to assume True when a value is # null. I suspect this should be considered a bug in pandas. bf_result = bf_series.case_when( [ - ((bf_series > 100).fillna(True), 1000), + ((bf_series > 100).fillna(True), bf_series - 1), + ((bf_series > 0).fillna(True), pd.NA), ((bf_series < -100).fillna(True), -1000), ] ).to_pandas() pd_result = pd_series.case_when( [ - (pd_series > 100, 1000), + (pd_series > 100, pd_series - 1), + (pd_series > 0, pd.NA), (pd_series < -100, -1000), ] )