diff --git a/bigframes/core/__init__.py b/bigframes/core/__init__.py index 2508814894..89ef5f525e 100644 --- a/bigframes/core/__init__.py +++ b/bigframes/core/__init__.py @@ -505,11 +505,11 @@ def try_align_as_projection( join_type: join_def.JoinType, mappings: typing.Tuple[join_def.JoinColumnMapping, ...], ) -> typing.Optional[ArrayValue]: - left_side = bigframes.core.rewrite.SquashedSelect.from_node(self.node) - right_side = bigframes.core.rewrite.SquashedSelect.from_node(other.node) - result = left_side.maybe_merge(right_side, join_type, mappings) + result = bigframes.core.rewrite.join_as_projection( + self.node, other.node, mappings, join_type + ) if result is not None: - return ArrayValue(result.expand()) + return ArrayValue(result) return None def explode(self, column_ids: typing.Sequence[str]) -> ArrayValue: @@ -528,7 +528,3 @@ def _uniform_sampling(self, fraction: float) -> ArrayValue: The row numbers of result is non-deterministic, avoid to use. """ return ArrayValue(nodes.RandomSampleNode(self.node, fraction)) - - def merge_projections(self) -> ArrayValue: - new_node = bigframes.core.rewrite.maybe_squash_projection(self.node) - return ArrayValue(new_node) diff --git a/bigframes/core/rewrite.py b/bigframes/core/rewrite.py index 15999c0558..101d5cc882 100644 --- a/bigframes/core/rewrite.py +++ b/bigframes/core/rewrite.py @@ -26,10 +26,17 @@ Selection = Tuple[Tuple[scalar_exprs.Expression, str], ...] +REWRITABLE_NODE_TYPES = ( + nodes.ProjectionNode, + nodes.FilterNode, + nodes.ReversedNode, + nodes.OrderByNode, +) + @dataclasses.dataclass(frozen=True) class SquashedSelect: - """Squash together as many nodes as possible, separating out the projection, filter and reordering expressions.""" + """Squash nodes together until target node, separating out the projection, filter and reordering expressions.""" root: nodes.BigFrameNode columns: Tuple[Tuple[scalar_exprs.Expression, str], ...] @@ -38,25 +45,25 @@ class SquashedSelect: reverse_root: bool = False @classmethod - def from_node( - cls, node: nodes.BigFrameNode, projections_only: bool = False + def from_node_span( + cls, node: nodes.BigFrameNode, target: nodes.BigFrameNode ) -> SquashedSelect: - if isinstance(node, nodes.ProjectionNode): - return cls.from_node(node.child, projections_only=projections_only).project( - node.assignments - ) - elif not projections_only and isinstance(node, nodes.FilterNode): - return cls.from_node(node.child).filter(node.predicate) - elif not projections_only and isinstance(node, nodes.ReversedNode): - return cls.from_node(node.child).reverse() - elif not projections_only and isinstance(node, nodes.OrderByNode): - return cls.from_node(node.child).order_with(node.by) - else: + if node == target: selection = tuple( (scalar_exprs.UnboundVariableExpression(id), id) for id in get_node_column_ids(node) ) return cls(node, selection, None, ()) + if isinstance(node, nodes.ProjectionNode): + return cls.from_node_span(node.child, target).project(node.assignments) + elif isinstance(node, nodes.FilterNode): + return cls.from_node_span(node.child, target).filter(node.predicate) + elif isinstance(node, nodes.ReversedNode): + return cls.from_node_span(node.child, target).reverse() + elif isinstance(node, nodes.OrderByNode): + return cls.from_node_span(node.child, target).order_with(node.by) + else: + raise ValueError(f"Cannot rewrite node {node}") @property def column_lookup(self) -> Mapping[str, scalar_exprs.Expression]: @@ -98,9 +105,10 @@ def order_with(self, by: Tuple[order.OrderingExpression, ...]): self.root, self.columns, self.predicate, new_ordering, self.reverse_root ) - def can_join( + def can_merge( self, right: SquashedSelect, join_def: join_defs.JoinDefinition ) -> bool: + """Determines whether the two selections can be merged into a single selection.""" if join_def.type == "cross": # Cannot convert cross join to projection return False @@ -116,14 +124,14 @@ def can_join( return False return True - def maybe_merge( + def merge( self, right: SquashedSelect, join_type: join_defs.JoinType, mappings: Tuple[join_defs.JoinColumnMapping, ...], - ) -> Optional[SquashedSelect]: + ) -> SquashedSelect: if self.root != right.root: - return None + raise ValueError("Cannot merge expressions with different roots") # Mask columns and remap names to expected schema lselection = self.columns rselection = right.columns @@ -196,28 +204,40 @@ def expand(self) -> nodes.BigFrameNode: return nodes.ProjectionNode(child=root, assignments=self.columns) -def maybe_squash_projection(node: nodes.BigFrameNode) -> nodes.BigFrameNode: - if isinstance(node, nodes.ProjectionNode) and isinstance( - node.child, nodes.ProjectionNode - ): - # Conservative approach, only squash consecutive projections, even though could also squash filters, reorderings - return SquashedSelect.from_node(node, projections_only=True).expand() - return node - - def maybe_rewrite_join(join_node: nodes.JoinNode) -> nodes.BigFrameNode: - left_side = SquashedSelect.from_node(join_node.left_child) - right_side = SquashedSelect.from_node(join_node.right_child) - if left_side.can_join(right_side, join_node.join): - merged = left_side.maybe_merge( + rewrite_common_node = common_selection_root( + join_node.left_child, join_node.right_child + ) + if rewrite_common_node is None: + return join_node + left_side = SquashedSelect.from_node_span(join_node.left_child, rewrite_common_node) + right_side = SquashedSelect.from_node_span( + join_node.right_child, rewrite_common_node + ) + if left_side.can_merge(right_side, join_node.join): + return left_side.merge( right_side, join_node.join.type, join_node.join.mappings - ) + ).expand() + return join_node + + +def join_as_projection( + l_node: nodes.BigFrameNode, + r_node: nodes.BigFrameNode, + mappings: Tuple[join_defs.JoinColumnMapping, ...], + how: join_defs.JoinType, +) -> Optional[nodes.BigFrameNode]: + rewrite_common_node = common_selection_root(l_node, r_node) + if rewrite_common_node is not None: + left_side = SquashedSelect.from_node_span(l_node, rewrite_common_node) + right_side = SquashedSelect.from_node_span(r_node, rewrite_common_node) + merged = left_side.merge(right_side, how, mappings) assert ( merged is not None ), "Couldn't merge nodes. This shouldn't happen. Please share full stacktrace with the BigQuery DataFrames team at bigframes-feedback@google.com." return merged.expand() else: - return join_node + return None def remap_names( @@ -311,3 +331,25 @@ def get_node_column_ids(node: nodes.BigFrameNode) -> Tuple[str, ...]: import bigframes.core return tuple(bigframes.core.ArrayValue(node).column_ids) + + +def common_selection_root( + l_tree: nodes.BigFrameNode, r_tree: nodes.BigFrameNode +) -> Optional[nodes.BigFrameNode]: + """Find common subtree between join subtrees""" + l_node = l_tree + l_nodes: set[nodes.BigFrameNode] = set() + while isinstance(l_node, REWRITABLE_NODE_TYPES): + l_nodes.add(l_node) + l_node = l_node.child + l_nodes.add(l_node) + + r_node = r_tree + while isinstance(r_node, REWRITABLE_NODE_TYPES): + if r_node in l_nodes: + return r_node + r_node = r_node.child + + if r_node in l_nodes: + return r_node + return None diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index 782ef2d5ea..0aac9e2578 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -4328,6 +4328,18 @@ def test_df_cached(scalars_df_index): pandas.testing.assert_frame_equal(df.to_pandas(), df_cached_copy.to_pandas()) +def test_df_cache_with_implicit_join(scalars_df_index): + """expectation is that cache will be used, but no explicit join will be performed""" + df = scalars_df_index[["int64_col", "int64_too"]].sort_index().reset_index() + 3 + df.cache() + bf_result = df + (df * 2) + sql = bf_result.sql + + # Very crude asserts, want sql to not use join and not use base table, only reference cached table + assert "JOIN" not in sql + assert "bigframes_testing" not in sql + + def test_df_dot_inline(session): df1 = pd.DataFrame([[1, 2, 3], [2, 5, 7]]) df2 = pd.DataFrame([[2, 4, 8], [1, 5, 10], [3, 6, 9]])