Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

fix: Fix miscasting issues with case_when #1003

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions 9 bigframes/core/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
import bigframes.operations.aggregations as agg_ops


def const(value: typing.Hashable, dtype: dtypes.ExpressionType = None) -> Expression:
def const(
value: typing.Hashable, dtype: dtypes.ExpressionType = None
) -> ScalarConstantExpression:
return ScalarConstantExpression(value, dtype or dtypes.infer_literal_type(value))


Expand Down Expand Up @@ -141,6 +143,9 @@ class ScalarConstantExpression(Expression):
def is_const(self) -> bool:
return True

def rename(self, name_mapping: Mapping[str, str]) -> ScalarConstantExpression:
return self

def output_type(
self, input_types: dict[str, bigframes.dtypes.Dtype]
) -> dtypes.ExpressionType:
Expand All @@ -167,7 +172,7 @@ class UnboundVariableExpression(Expression):
def unbound_variables(self) -> typing.Tuple[str, ...]:
return (self.id,)

def rename(self, name_mapping: Mapping[str, str]) -> Expression:
def rename(self, name_mapping: Mapping[str, str]) -> UnboundVariableExpression:
if self.id in name_mapping:
return UnboundVariableExpression(name_mapping[self.id])
else:
Expand Down
74 changes: 57 additions & 17 deletions 74 bigframes/operations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

import typing
from typing import List, Sequence
from typing import List, Sequence, Union

import bigframes_vendored.constants as constants
import bigframes_vendored.pandas.pandas._typing as vendored_pandas_typing
Expand Down Expand Up @@ -180,9 +180,10 @@ def _apply_binary_op(
(self_col, other_col, block) = self._align(other_series, how=alignment)

name = self._name
# Drop name if both objects have name attr, but they don't match
if (
hasattr(other, "name")
and other.name != self._name
and other_series.name != self._name
and alignment == "outer"
):
name = None
Expand All @@ -208,41 +209,78 @@ def _apply_nary_op(
ignore_self=False,
):
"""Applies an n-ary operator to the series and others."""
values, block = self._align_n(others, ignore_self=ignore_self)
block, result_id = block.apply_nary_op(
values,
op,
self._name,
values, block = self._align_n(
others, ignore_self=ignore_self, cast_scalars=False
)
block, result_id = block.project_expr(op.as_expr(*values))
return series.Series(block.select_column(result_id))

def _apply_binary_aggregation(
self, other: series.Series, stat: agg_ops.BinaryAggregateOp
) -> float:
(left, right, block) = self._align(other, how="outer")
assert isinstance(left, ex.UnboundVariableExpression)
assert isinstance(right, ex.UnboundVariableExpression)
return block.get_binary_stat(left.id, right.id, stat)

AlignedExprT = Union[ex.ScalarConstantExpression, ex.UnboundVariableExpression]

return block.get_binary_stat(left, right, stat)
@typing.overload
def _align(
self, other: series.Series, how="outer"
) -> tuple[
ex.UnboundVariableExpression,
ex.UnboundVariableExpression,
blocks.Block,
]:
...

def _align(self, other: series.Series, how="outer") -> tuple[str, str, blocks.Block]: # type: ignore
@typing.overload
def _align(
self, other: typing.Union[series.Series, scalars.Scalar], how="outer"
) -> tuple[ex.UnboundVariableExpression, AlignedExprT, blocks.Block,]:
...

def _align(
self, other: typing.Union[series.Series, scalars.Scalar], how="outer"
) -> tuple[ex.UnboundVariableExpression, AlignedExprT, blocks.Block,]:
"""Aligns the series value with another scalar or series object. Returns new left column id, right column id and joined tabled expression."""
values, block = self._align_n(
[
other,
],
how,
)
return (values[0], values[1], block)
return (typing.cast(ex.UnboundVariableExpression, values[0]), values[1], block)

def _align3(self, other1: series.Series | scalars.Scalar, other2: series.Series | scalars.Scalar, how="left") -> tuple[ex.UnboundVariableExpression, AlignedExprT, AlignedExprT, blocks.Block]: # type: ignore
"""Aligns the series value with 2 other scalars or series objects. Returns new values and joined tabled expression."""
values, index = self._align_n([other1, other2], how)
return (
typing.cast(ex.UnboundVariableExpression, values[0]),
values[1],
values[2],
index,
)

def _align_n(
self,
others: typing.Sequence[typing.Union[series.Series, scalars.Scalar]],
how="outer",
ignore_self=False,
) -> tuple[typing.Sequence[str], blocks.Block]:
cast_scalars: bool = True,
) -> tuple[
typing.Sequence[
Union[ex.ScalarConstantExpression, ex.UnboundVariableExpression]
],
blocks.Block,
]:
if ignore_self:
value_ids: List[str] = []
value_ids: List[
Union[ex.ScalarConstantExpression, ex.UnboundVariableExpression]
] = []
else:
value_ids = [self._value_column]
value_ids = [ex.free_var(self._value_column)]

block = self._block
for other in others:
Expand All @@ -252,14 +290,16 @@ def _align_n(
get_column_right,
) = block.join(other._block, how=how)
value_ids = [
*[get_column_left[value] for value in value_ids],
get_column_right[other._value_column],
*[value.rename(get_column_left) for value in value_ids],
ex.free_var(get_column_right[other._value_column]),
]
else:
# Will throw if can't interpret as scalar.
dtype = typing.cast(bigframes.dtypes.Dtype, self._dtype)
block, constant_col_id = block.create_constant(other, dtype=dtype)
value_ids = [*value_ids, constant_col_id]
value_ids = [
*value_ids,
ex.const(other, dtype=dtype if cast_scalars else None),
]
return (value_ids, block)

def _throw_if_null_index(self, opname: str):
Expand Down
33 changes: 9 additions & 24 deletions 33 bigframes/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,23 +445,13 @@ def between(self, left, right, inclusive="both"):
)

def case_when(self, caselist) -> Series:
cases = list(itertools.chain(*caselist, (True, self)))
return self._apply_nary_op(
ops.case_when_op,
tuple(
itertools.chain(
itertools.chain(*caselist),
# Fallback to current value if no other matches.
(
# We make a Series with a constant value to avoid casts to
# types other than boolean.
Series(True, index=self.index, dtype=pandas.BooleanDtype()),
self,
),
),
),
cases,
# Self is already included in "others".
ignore_self=True,
)
).rename(self.name)

@validations.requires_ordering()
def cumsum(self) -> Series:
Expand Down Expand Up @@ -1116,8 +1106,8 @@ def ne(self, other: object) -> Series:

def where(self, cond, other=None):
value_id, cond_id, other_id, block = self._align3(cond, other)
block, result_id = block.apply_ternary_op(
value_id, cond_id, other_id, ops.where_op
block, result_id = block.project_expr(
ops.where_op.as_expr(value_id, cond_id, other_id)
)
return Series(block.select_column(result_id).with_column_labels([self.name]))

Expand All @@ -1129,8 +1119,8 @@ def clip(self, lower, upper):
if upper is None:
return self._apply_binary_op(lower, ops.maximum_op, alignment="left")
value_id, lower_id, upper_id, block = self._align3(lower, upper)
block, result_id = block.apply_ternary_op(
value_id, lower_id, upper_id, ops.clip_op
block, result_id = block.project_expr(
ops.clip_op.as_expr(value_id, lower_id, upper_id),
)
return Series(block.select_column(result_id).with_column_labels([self.name]))

Expand Down Expand Up @@ -1242,8 +1232,8 @@ def __getitem__(self, indexer):
return self.iloc[indexer]
if isinstance(indexer, Series):
(left, right, block) = self._align(indexer, "left")
block = block.filter_by_id(right)
block = block.select_column(left)
block = block.filter(right)
block = block.select_column(left.id)
return Series(block)
return self.loc[indexer]

Expand All @@ -1262,11 +1252,6 @@ def __getattr__(self, key: str):
else:
raise AttributeError(key)

def _align3(self, other1: Series | scalars.Scalar, other2: Series | scalars.Scalar, how="left") -> tuple[str, str, str, blocks.Block]: # type: ignore
"""Aligns the series value with 2 other scalars or series objects. Returns new values and joined tabled expression."""
values, index = self._align_n([other1, other2], how)
return (values[0], values[1], values[2], index)

def _apply_aggregation(
self, op: agg_ops.UnaryAggregateOp | agg_ops.NullaryAggregateOp
) -> Any:
Expand Down
13 changes: 8 additions & 5 deletions 13 tests/system/small/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2709,27 +2709,30 @@ def test_between(scalars_df_index, scalars_pandas_df_index, left, right, inclusi
)


def test_case_when(scalars_df_index, scalars_pandas_df_index):
def test_series_case_when(scalars_dfs_maybe_ordered):
pytest.importorskip(
"pandas",
minversion="2.2.0",
reason="case_when added in pandas 2.2.0",
)
scalars_df, scalars_pandas_df = scalars_dfs_maybe_ordered

bf_series = scalars_df_index["int64_col"]
pd_series = scalars_pandas_df_index["int64_col"]
bf_series = scalars_df["int64_col"]
pd_series = scalars_pandas_df["int64_col"]

# TODO(tswast): pandas case_when appears to assume True when a value is
# null. I suspect this should be considered a bug in pandas.
bf_result = bf_series.case_when(
[
((bf_series > 100).fillna(True), 1000),
((bf_series > 100).fillna(True), bf_series - 1),
((bf_series > 0).fillna(True), pd.NA),
((bf_series < -100).fillna(True), -1000),
]
).to_pandas()
pd_result = pd_series.case_when(
[
(pd_series > 100, 1000),
(pd_series > 100, pd_series - 1),
(pd_series > 0, pd.NA),
(pd_series < -100, -1000),
]
)
Expand Down
Morty Proxy This is a proxified and sanitized view of the page, visit original site.