Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

fix: Self-join optimization doesn't needlessly invalidate caching #797

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions 12 bigframes/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,11 +505,11 @@ def try_align_as_projection(
join_type: join_def.JoinType,
mappings: typing.Tuple[join_def.JoinColumnMapping, ...],
) -> typing.Optional[ArrayValue]:
left_side = bigframes.core.rewrite.SquashedSelect.from_node(self.node)
right_side = bigframes.core.rewrite.SquashedSelect.from_node(other.node)
result = left_side.maybe_merge(right_side, join_type, mappings)
result = bigframes.core.rewrite.join_as_projection(
self.node, other.node, mappings, join_type
)
if result is not None:
return ArrayValue(result.expand())
return ArrayValue(result)
return None

def explode(self, column_ids: typing.Sequence[str]) -> ArrayValue:
Expand All @@ -528,7 +528,3 @@ def _uniform_sampling(self, fraction: float) -> ArrayValue:
The row numbers of result is non-deterministic, avoid to use.
"""
return ArrayValue(nodes.RandomSampleNode(self.node, fraction))

def merge_projections(self) -> ArrayValue:
new_node = bigframes.core.rewrite.maybe_squash_projection(self.node)
return ArrayValue(new_node)
108 changes: 75 additions & 33 deletions 108 bigframes/core/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,17 @@

Selection = Tuple[Tuple[scalar_exprs.Expression, str], ...]

REWRITABLE_NODE_TYPES = (
nodes.ProjectionNode,
nodes.FilterNode,
nodes.ReversedNode,
nodes.OrderByNode,
)


@dataclasses.dataclass(frozen=True)
class SquashedSelect:
"""Squash together as many nodes as possible, separating out the projection, filter and reordering expressions."""
"""Squash nodes together until target node, separating out the projection, filter and reordering expressions."""

root: nodes.BigFrameNode
columns: Tuple[Tuple[scalar_exprs.Expression, str], ...]
Expand All @@ -38,25 +45,25 @@ class SquashedSelect:
reverse_root: bool = False

@classmethod
def from_node(
cls, node: nodes.BigFrameNode, projections_only: bool = False
def from_node_span(
cls, node: nodes.BigFrameNode, target: nodes.BigFrameNode
) -> SquashedSelect:
if isinstance(node, nodes.ProjectionNode):
return cls.from_node(node.child, projections_only=projections_only).project(
node.assignments
)
elif not projections_only and isinstance(node, nodes.FilterNode):
return cls.from_node(node.child).filter(node.predicate)
elif not projections_only and isinstance(node, nodes.ReversedNode):
return cls.from_node(node.child).reverse()
elif not projections_only and isinstance(node, nodes.OrderByNode):
return cls.from_node(node.child).order_with(node.by)
else:
if node == target:
selection = tuple(
(scalar_exprs.UnboundVariableExpression(id), id)
for id in get_node_column_ids(node)
)
return cls(node, selection, None, ())
if isinstance(node, nodes.ProjectionNode):
return cls.from_node_span(node.child, target).project(node.assignments)
elif isinstance(node, nodes.FilterNode):
return cls.from_node_span(node.child, target).filter(node.predicate)
elif isinstance(node, nodes.ReversedNode):
return cls.from_node_span(node.child, target).reverse()
elif isinstance(node, nodes.OrderByNode):
return cls.from_node_span(node.child, target).order_with(node.by)
else:
raise ValueError(f"Cannot rewrite node {node}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be better to raise InternalError or add an assert that the target node must be the subtree of the given node? Or rename from_node into from_subnode?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be entirely impossible to reach this error through any public api. renamed to from_node_span


@property
def column_lookup(self) -> Mapping[str, scalar_exprs.Expression]:
Expand Down Expand Up @@ -98,9 +105,10 @@ def order_with(self, by: Tuple[order.OrderingExpression, ...]):
self.root, self.columns, self.predicate, new_ordering, self.reverse_root
)

def can_join(
def can_merge(
self, right: SquashedSelect, join_def: join_defs.JoinDefinition
) -> bool:
"""Determines whether the two selections can be merged into a single selection."""
if join_def.type == "cross":
# Cannot convert cross join to projection
return False
Expand All @@ -116,14 +124,14 @@ def can_join(
return False
return True

def maybe_merge(
def merge(
self,
right: SquashedSelect,
join_type: join_defs.JoinType,
mappings: Tuple[join_defs.JoinColumnMapping, ...],
) -> Optional[SquashedSelect]:
) -> SquashedSelect:
if self.root != right.root:
return None
raise ValueError("Cannot merge expressions with different roots")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be better to raise InternalError and ask customer to share their user case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So far we don't do this for other similar errors - couldn't find InternalError anywhere in the code base. Maybe we should automatically wrap all exceptions with feedback link? Seems out of scope

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be SystemError in the python exceptions?

# Mask columns and remap names to expected schema
lselection = self.columns
rselection = right.columns
Expand Down Expand Up @@ -196,28 +204,40 @@ def expand(self) -> nodes.BigFrameNode:
return nodes.ProjectionNode(child=root, assignments=self.columns)


def maybe_squash_projection(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
if isinstance(node, nodes.ProjectionNode) and isinstance(
node.child, nodes.ProjectionNode
):
# Conservative approach, only squash consecutive projections, even though could also squash filters, reorderings
return SquashedSelect.from_node(node, projections_only=True).expand()
return node


def maybe_rewrite_join(join_node: nodes.JoinNode) -> nodes.BigFrameNode:
left_side = SquashedSelect.from_node(join_node.left_child)
right_side = SquashedSelect.from_node(join_node.right_child)
if left_side.can_join(right_side, join_node.join):
merged = left_side.maybe_merge(
rewrite_common_node = common_selection_root(
join_node.left_child, join_node.right_child
)
if rewrite_common_node is None:
return join_node
left_side = SquashedSelect.from_node_span(join_node.left_child, rewrite_common_node)
right_side = SquashedSelect.from_node_span(
join_node.right_child, rewrite_common_node
)
if left_side.can_merge(right_side, join_node.join):
return left_side.merge(
right_side, join_node.join.type, join_node.join.mappings
)
).expand()
return join_node


def join_as_projection(
l_node: nodes.BigFrameNode,
r_node: nodes.BigFrameNode,
mappings: Tuple[join_defs.JoinColumnMapping, ...],
how: join_defs.JoinType,
) -> Optional[nodes.BigFrameNode]:
rewrite_common_node = common_selection_root(l_node, r_node)
if rewrite_common_node is not None:
left_side = SquashedSelect.from_node_span(l_node, rewrite_common_node)
right_side = SquashedSelect.from_node_span(r_node, rewrite_common_node)
merged = left_side.merge(right_side, how, mappings)
assert (
merged is not None
), "Couldn't merge nodes. This shouldn't happen. Please share full stacktrace with the BigQuery DataFrames team at bigframes-feedback@google.com."
return merged.expand()
else:
return join_node
return None


def remap_names(
Expand Down Expand Up @@ -311,3 +331,25 @@ def get_node_column_ids(node: nodes.BigFrameNode) -> Tuple[str, ...]:
import bigframes.core

return tuple(bigframes.core.ArrayValue(node).column_ids)


def common_selection_root(
l_tree: nodes.BigFrameNode, r_tree: nodes.BigFrameNode
) -> Optional[nodes.BigFrameNode]:
"""Find common subtree between join subtrees"""
chelsea-lin marked this conversation as resolved.
Show resolved Hide resolved
l_node = l_tree
l_nodes: set[nodes.BigFrameNode] = set()
while isinstance(l_node, REWRITABLE_NODE_TYPES):
l_nodes.add(l_node)
l_node = l_node.child
l_nodes.add(l_node)

r_node = r_tree
while isinstance(r_node, REWRITABLE_NODE_TYPES):
if r_node in l_nodes:
return r_node
r_node = r_node.child

if r_node in l_nodes:
chelsea-lin marked this conversation as resolved.
Show resolved Hide resolved
return r_node
return None
12 changes: 12 additions & 0 deletions 12 tests/system/small/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4328,6 +4328,18 @@ def test_df_cached(scalars_df_index):
pandas.testing.assert_frame_equal(df.to_pandas(), df_cached_copy.to_pandas())


def test_df_cache_with_implicit_join(scalars_df_index):
"""expectation is that cache will be used, but no explicit join will be performed"""
df = scalars_df_index[["int64_col", "int64_too"]].sort_index().reset_index() + 3
df.cache()
bf_result = df + (df * 2)
sql = bf_result.sql

# Very crude asserts, want sql to not use join and not use base table, only reference cached table
assert "JOIN" not in sql
assert "bigframes_testing" not in sql


def test_df_dot_inline(session):
df1 = pd.DataFrame([[1, 2, 3], [2, 5, 7]])
df2 = pd.DataFrame([[2, 4, 8], [1, 5, 10], [3, 6, 9]])
Expand Down
Morty Proxy This is a proxified and sanitized view of the page, visit original site.