Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

refactor: simplify filter and join nodes #321

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 13 additions & 20 deletions 33 bigframes/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,22 @@
from dataclasses import dataclass
import io
import typing
from typing import Iterable, Literal, Sequence
from typing import Iterable, Sequence

import ibis.expr.types as ibis_types
import pandas

import bigframes.core.compile as compiling
import bigframes.core.expression as ex
import bigframes.core.guid
import bigframes.core.join_def as join_def
import bigframes.core.nodes as nodes
from bigframes.core.ordering import OrderingColumnReference
import bigframes.core.ordering as orderings
import bigframes.core.utils
from bigframes.core.window_spec import WindowSpec
import bigframes.dtypes
import bigframes.operations as ops
import bigframes.operations.aggregations as agg_ops
import bigframes.session._io.bigquery

Expand Down Expand Up @@ -114,13 +116,15 @@ def row_count(self) -> ArrayValue:
return ArrayValue(nodes.RowCountNode(child=self.node))

# Operations
def filter(self, predicate_id: str, keep_null: bool = False) -> ArrayValue:
def filter_by_id(self, predicate_id: str, keep_null: bool = False) -> ArrayValue:
"""Filter the table on a given expression, the predicate must be a boolean series aligned with the table expression."""
return ArrayValue(
nodes.FilterNode(
child=self.node, predicate_id=predicate_id, keep_null=keep_null
)
)
predicate = ex.free_var(predicate_id)
if keep_null:
predicate = ops.fillna_op.as_expr(predicate, ex.const(True))
return self.filter(predicate)

def filter(self, predicate: ex.Expression):
return ArrayValue(nodes.FilterNode(child=self.node, predicate=predicate))

def order_by(self, by: Sequence[OrderingColumnReference]) -> ArrayValue:
return ArrayValue(nodes.OrderByNode(child=self.node, by=tuple(by)))
Expand Down Expand Up @@ -356,26 +360,15 @@ def unpivot(

def join(
self,
self_column_ids: typing.Sequence[str],
other: ArrayValue,
other_column_ids: typing.Sequence[str],
*,
how: Literal[
"inner",
"left",
"outer",
"right",
"cross",
],
join_def: join_def.JoinDefinition,
allow_row_identity_join: bool = True,
):
return ArrayValue(
nodes.JoinNode(
left_child=self.node,
right_child=other.node,
left_column_ids=tuple(self_column_ids),
right_column_ids=tuple(other_column_ids),
how=how,
join=join_def,
allow_row_identity_join=allow_row_identity_join,
)
)
Expand Down
39 changes: 29 additions & 10 deletions 39 bigframes/core/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import bigframes.core.expression as ex
import bigframes.core.guid as guid
import bigframes.core.indexes as indexes
import bigframes.core.joins.name_resolution as join_names
import bigframes.core.join_def as join_defs
import bigframes.core.ordering as ordering
import bigframes.core.utils
import bigframes.core.utils as utils
Expand Down Expand Up @@ -826,7 +826,7 @@ def assign_label(self, column_id: str, new_label: Label) -> Block:

def filter(self, column_id: str, keep_null: bool = False):
return Block(
self._expr.filter(column_id, keep_null),
self._expr.filter_by_id(column_id, keep_null),
index_columns=self.index_columns,
column_labels=self.column_labels,
index_labels=self.index.names,
Expand Down Expand Up @@ -1542,19 +1542,38 @@ def merge(
sort: bool,
suffixes: tuple[str, str] = ("_x", "_y"),
) -> Block:
joined_expr = self.expr.join(
left_join_ids,
other.expr,
right_join_ids,
how=how,
)
get_column_left, get_column_right = join_names.JOIN_NAME_REMAPPER(
self.expr.column_ids, other.expr.column_ids
left_mappings = [
join_defs.JoinColumnMapping(
source_table=join_defs.JoinSide.LEFT,
source_id=id,
destination_id=guid.generate_guid(),
)
for id in self.expr.column_ids
]
right_mappings = [
join_defs.JoinColumnMapping(
source_table=join_defs.JoinSide.RIGHT,
source_id=id,
destination_id=guid.generate_guid(),
)
for id in other.expr.column_ids
]

join_def = join_defs.JoinDefinition(
conditions=tuple(
join_defs.JoinCondition(left, right)
for left, right in zip(left_join_ids, right_join_ids)
),
mappings=(*left_mappings, *right_mappings),
type=how,
)
joined_expr = self.expr.join(other.expr, join_def=join_def)
result_columns = []
matching_join_labels = []

coalesced_ids = []
get_column_left = join_def.get_left_mapping()
get_column_right = join_def.get_right_mapping()
for left_id, right_id in zip(left_join_ids, right_join_ids):
coalesced_id = guid.generate_guid()
joined_expr = joined_expr.project_to_id(
Expand Down
32 changes: 8 additions & 24 deletions 32 bigframes/core/compile/compiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def _reduced_predicate(self) -> typing.Optional[ibis_types.BooleanValue]:
)

@abc.abstractmethod
def filter(self: T, predicate_id: str, keep_null: bool = False) -> T:
"""Filter the table on a given expression, the predicate must be a boolean series aligned with the table expression."""
def filter(self: T, predicate: ex.Expression) -> T:
"""Filter the table on a given expression, the predicate must be a boolean expression."""
...

@abc.abstractmethod
Expand Down Expand Up @@ -305,17 +305,9 @@ def _to_ibis_expr(
table = table.filter(ibis.random() < ibis.literal(fraction))
return table

def filter(self, predicate_id: str, keep_null: bool = False) -> UnorderedIR:
condition = typing.cast(
ibis_types.BooleanValue, self._get_ibis_column(predicate_id)
)
if keep_null:
condition = typing.cast(
ibis_types.BooleanValue,
condition.fillna(
typing.cast(ibis_types.BooleanScalar, ibis_types.literal(True))
),
)
def filter(self, predicate: ex.Expression) -> UnorderedIR:
bindings = {col: self._get_ibis_column(col) for col in self.column_ids}
condition = op_compiler.compile_expression(predicate, bindings)
return self._filter(condition)

def _filter(self, predicate_value: ibis_types.BooleanValue) -> UnorderedIR:
Expand Down Expand Up @@ -1140,17 +1132,9 @@ def _to_ibis_expr(
table = table.filter(ibis.random() < ibis.literal(fraction))
return table

def filter(self, predicate_id: str, keep_null: bool = False) -> OrderedIR:
condition = typing.cast(
ibis_types.BooleanValue, self._get_ibis_column(predicate_id)
)
if keep_null:
condition = typing.cast(
ibis_types.BooleanValue,
condition.fillna(
typing.cast(ibis_types.BooleanScalar, ibis_types.literal(True))
),
)
def filter(self, predicate: ex.Expression) -> OrderedIR:
bindings = {col: self._get_ibis_column(col) for col in self.column_ids}
condition = op_compiler.compile_expression(predicate, bindings)
return self._filter(condition)

def _filter(self, predicate_value: ibis_types.BooleanValue) -> OrderedIR:
Expand Down
18 changes: 7 additions & 11 deletions 18 bigframes/core/compile/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,22 +59,18 @@ def compile_join(node: nodes.JoinNode, ordered: bool = True):
left_ordered = compile_ordered(node.left_child)
right_ordered = compile_ordered(node.right_child)
return bigframes.core.compile.single_column.join_by_column_ordered(
left_ordered,
node.left_column_ids,
right_ordered,
node.right_column_ids,
how=node.how,
left=left_ordered,
right=right_ordered,
join=node.join,
allow_row_identity_join=node.allow_row_identity_join,
)
else:
left_unordered = compile_unordered(node.left_child)
right_unordered = compile_unordered(node.right_child)
return bigframes.core.compile.single_column.join_by_column_unordered(
left_unordered,
node.left_column_ids,
right_unordered,
node.right_column_ids,
how=node.how,
left=left_unordered,
right=right_unordered,
join=node.join,
allow_row_identity_join=node.allow_row_identity_join,
)

Expand Down Expand Up @@ -113,7 +109,7 @@ def compile_promote_offsets(node: nodes.PromoteOffsetsNode, ordered: bool = True

@_compile_node.register
def compile_filter(node: nodes.FilterNode, ordered: bool = True):
return compile_node(node.child, ordered).filter(node.predicate_id, node.keep_null)
return compile_node(node.child, ordered).filter(node.predicate)


@_compile_node.register
Expand Down
41 changes: 23 additions & 18 deletions 41 bigframes/core/compile/row_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import bigframes.constants as constants
import bigframes.core.compile.compiled as compiled
import bigframes.core.join_def as join_def
import bigframes.core.joins as joining
import bigframes.core.ordering as orderings

Expand All @@ -33,11 +34,10 @@
def join_by_row_identity_unordered(
left: compiled.UnorderedIR,
right: compiled.UnorderedIR,
*,
how: str,
join_def: join_def.JoinDefinition,
) -> compiled.UnorderedIR:
"""Compute join when we are joining by row identity not a specific column."""
if how not in SUPPORTED_ROW_IDENTITY_HOW:
if join_def.type not in SUPPORTED_ROW_IDENTITY_HOW:
raise NotImplementedError(
f"Only how='outer','left','inner' currently supported. {constants.FEEDBACK_LINK}"
)
Expand All @@ -60,17 +60,20 @@ def join_by_row_identity_unordered(
combined_predicates = []
if left_predicates or right_predicates:
joined_predicates = _join_predicates(
left_predicates, right_predicates, join_type=how
left_predicates, right_predicates, join_type=join_def.type
)
combined_predicates = list(joined_predicates) # builder expects mutable list

left_mask = left_relative_predicates if how in ["right", "outer"] else None
right_mask = right_relative_predicates if how in ["left", "outer"] else None
left_mask = (
left_relative_predicates if join_def.type in ["right", "outer"] else None
)
right_mask = (
right_relative_predicates if join_def.type in ["left", "outer"] else None
)

# Public mapping must use JOIN_NAME_REMAPPER to stay in sync with consumers of join result
map_left_id, map_right_id = joining.JOIN_NAME_REMAPPER(
left.column_ids, right.column_ids
)
map_left_id = join_def.get_left_mapping()
map_right_id = join_def.get_right_mapping()
joined_columns = [
_mask_value(left._get_ibis_column(key), left_mask).name(map_left_id[key])
for key in left.column_ids
Expand All @@ -90,11 +93,10 @@ def join_by_row_identity_unordered(
def join_by_row_identity_ordered(
left: compiled.OrderedIR,
right: compiled.OrderedIR,
*,
how: str,
join_def: join_def.JoinDefinition,
) -> compiled.OrderedIR:
"""Compute join when we are joining by row identity not a specific column."""
if how not in SUPPORTED_ROW_IDENTITY_HOW:
if join_def.type not in SUPPORTED_ROW_IDENTITY_HOW:
raise NotImplementedError(
f"Only how='outer','left','inner' currently supported. {constants.FEEDBACK_LINK}"
)
Expand All @@ -117,17 +119,20 @@ def join_by_row_identity_ordered(
combined_predicates = []
if left_predicates or right_predicates:
joined_predicates = _join_predicates(
left_predicates, right_predicates, join_type=how
left_predicates, right_predicates, join_type=join_def.type
)
combined_predicates = list(joined_predicates) # builder expects mutable list

left_mask = left_relative_predicates if how in ["right", "outer"] else None
right_mask = right_relative_predicates if how in ["left", "outer"] else None
left_mask = (
left_relative_predicates if join_def.type in ["right", "outer"] else None
)
right_mask = (
right_relative_predicates if join_def.type in ["left", "outer"] else None
)

# Public mapping must use JOIN_NAME_REMAPPER to stay in sync with consumers of join result
lpublicmapping, rpublicmapping = joining.JOIN_NAME_REMAPPER(
left.column_ids, right.column_ids
)
lpublicmapping = join_def.get_left_mapping()
rpublicmapping = join_def.get_right_mapping()
lhiddenmapping, rhiddenmapping = joining.JoinNameRemapper(namespace="hidden")(
left._hidden_column_ids, right._hidden_column_ids
)
Expand Down
Loading
Morty Proxy This is a proxified and sanitized view of the page, visit original site.