From a338d30013bf4d5806618691c6f74d659541fc79 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Mon, 17 Jun 2024 23:34:50 +0000 Subject: [PATCH 1/6] fix: Self-join optimization doesn't needlessly invalidate caching --- bigframes/core/__init__.py | 12 ++--- bigframes/core/rewrite.py | 90 ++++++++++++++++++++++++++------------ 2 files changed, 66 insertions(+), 36 deletions(-) diff --git a/bigframes/core/__init__.py b/bigframes/core/__init__.py index 2dc6184afc..08beb825c2 100644 --- a/bigframes/core/__init__.py +++ b/bigframes/core/__init__.py @@ -520,11 +520,11 @@ def try_align_as_projection( join_type: join_def.JoinType, mappings: typing.Tuple[join_def.JoinColumnMapping, ...], ) -> typing.Optional[ArrayValue]: - left_side = bigframes.core.rewrite.SquashedSelect.from_node(self.node) - right_side = bigframes.core.rewrite.SquashedSelect.from_node(other.node) - result = left_side.maybe_merge(right_side, join_type, mappings) + result = bigframes.core.rewrite.join_as_projection( + self.node, other.node, mappings, join_type + ) if result is not None: - return ArrayValue(result.expand()) + return ArrayValue(result) return None def explode(self, column_ids: typing.Sequence[str]) -> ArrayValue: @@ -543,7 +543,3 @@ def _uniform_sampling(self, fraction: float) -> ArrayValue: The row numbers of result is non-deterministic, avoid to use. """ return ArrayValue(nodes.RandomSampleNode(self.node, fraction)) - - def merge_projections(self) -> ArrayValue: - new_node = bigframes.core.rewrite.maybe_squash_projection(self.node) - return ArrayValue(new_node) diff --git a/bigframes/core/rewrite.py b/bigframes/core/rewrite.py index 15999c0558..b553533dfe 100644 --- a/bigframes/core/rewrite.py +++ b/bigframes/core/rewrite.py @@ -26,10 +26,17 @@ Selection = Tuple[Tuple[scalar_exprs.Expression, str], ...] +REWRITABLE_NODE_TYPES = ( + nodes.ProjectionNode, + nodes.FilterNode, + nodes.ReversedNode, + nodes.OrderByNode, +) + @dataclasses.dataclass(frozen=True) class SquashedSelect: - """Squash together as many nodes as possible, separating out the projection, filter and reordering expressions.""" + """Squash nodes together until target node, separating out the projection, filter and reordering expressions.""" root: nodes.BigFrameNode columns: Tuple[Tuple[scalar_exprs.Expression, str], ...] @@ -39,24 +46,24 @@ class SquashedSelect: @classmethod def from_node( - cls, node: nodes.BigFrameNode, projections_only: bool = False + cls, node: nodes.BigFrameNode, target: nodes.BigFrameNode ) -> SquashedSelect: - if isinstance(node, nodes.ProjectionNode): - return cls.from_node(node.child, projections_only=projections_only).project( - node.assignments - ) - elif not projections_only and isinstance(node, nodes.FilterNode): - return cls.from_node(node.child).filter(node.predicate) - elif not projections_only and isinstance(node, nodes.ReversedNode): - return cls.from_node(node.child).reverse() - elif not projections_only and isinstance(node, nodes.OrderByNode): - return cls.from_node(node.child).order_with(node.by) - else: + if node == target: selection = tuple( (scalar_exprs.UnboundVariableExpression(id), id) for id in get_node_column_ids(node) ) return cls(node, selection, None, ()) + if isinstance(node, nodes.ProjectionNode): + return cls.from_node(node.child, target).project(node.assignments) + elif isinstance(node, nodes.FilterNode): + return cls.from_node(node.child, target).filter(node.predicate) + elif isinstance(node, nodes.ReversedNode): + return cls.from_node(node.child, target).reverse() + elif isinstance(node, nodes.OrderByNode): + return cls.from_node(node.child, target).order_with(node.by) + else: + raise ValueError(f"Cannot rewrite node {node}") @property def column_lookup(self) -> Mapping[str, scalar_exprs.Expression]: @@ -196,28 +203,33 @@ def expand(self) -> nodes.BigFrameNode: return nodes.ProjectionNode(child=root, assignments=self.columns) -def maybe_squash_projection(node: nodes.BigFrameNode) -> nodes.BigFrameNode: - if isinstance(node, nodes.ProjectionNode) and isinstance( - node.child, nodes.ProjectionNode - ): - # Conservative approach, only squash consecutive projections, even though could also squash filters, reorderings - return SquashedSelect.from_node(node, projections_only=True).expand() - return node +def maybe_rewrite_join(join_node: nodes.JoinNode) -> nodes.BigFrameNode: + rewritten = join_as_projection( + join_node.left_child, + join_node.right_child, + join_node.join.mappings, + join_node.join.type, + ) + return rewritten if rewritten is not None else join_node -def maybe_rewrite_join(join_node: nodes.JoinNode) -> nodes.BigFrameNode: - left_side = SquashedSelect.from_node(join_node.left_child) - right_side = SquashedSelect.from_node(join_node.right_child) - if left_side.can_join(right_side, join_node.join): - merged = left_side.maybe_merge( - right_side, join_node.join.type, join_node.join.mappings - ) +def join_as_projection( + l_node: nodes.BigFrameNode, + r_node: nodes.BigFrameNode, + mappings: Tuple[join_defs.JoinColumnMapping, ...], + how: join_defs.JoinType, +) -> Optional[nodes.BigFrameNode]: + rewrite_common_node = common_subtree(l_node, r_node) + if rewrite_common_node is not None: + left_side = SquashedSelect.from_node(l_node, rewrite_common_node) + right_side = SquashedSelect.from_node(r_node, rewrite_common_node) + merged = left_side.maybe_merge(right_side, how, mappings) assert ( merged is not None ), "Couldn't merge nodes. This shouldn't happen. Please share full stacktrace with the BigQuery DataFrames team at bigframes-feedback@google.com." return merged.expand() else: - return join_node + return None def remap_names( @@ -311,3 +323,25 @@ def get_node_column_ids(node: nodes.BigFrameNode) -> Tuple[str, ...]: import bigframes.core return tuple(bigframes.core.ArrayValue(node).column_ids) + + +def common_subtree( + l_tree: nodes.BigFrameNode, r_tree: nodes.BigFrameNode +) -> Optional[nodes.BigFrameNode]: + """Find common subtree between join subtrees""" + l_node = l_tree + l_nodes: set[nodes.BigFrameNode] = set() + while isinstance(l_node, REWRITABLE_NODE_TYPES): + l_nodes.add(l_node) + l_node = l_node.child + l_nodes.add(l_node) + + r_node = r_tree + while isinstance(r_node, REWRITABLE_NODE_TYPES): + if r_node in l_nodes: + return r_node + r_node = r_node.child + + if r_node in l_nodes: + return r_node + return None From 71d1f1bddd0b26969cd156e47013cbee2d3b1111 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Tue, 18 Jun 2024 17:15:53 +0000 Subject: [PATCH 2/6] ensure join condition is matching same expression --- bigframes/core/rewrite.py | 25 ++++++++++++++----------- tests/system/small/test_dataframe.py | 5 +++++ 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/bigframes/core/rewrite.py b/bigframes/core/rewrite.py index b553533dfe..4aca7986a5 100644 --- a/bigframes/core/rewrite.py +++ b/bigframes/core/rewrite.py @@ -123,14 +123,14 @@ def can_join( return False return True - def maybe_merge( + def merge( self, right: SquashedSelect, join_type: join_defs.JoinType, mappings: Tuple[join_defs.JoinColumnMapping, ...], - ) -> Optional[SquashedSelect]: + ) -> SquashedSelect: if self.root != right.root: - return None + raise ValueError("Cannot merge expressions with different roots") # Mask columns and remap names to expected schema lselection = self.columns rselection = right.columns @@ -204,13 +204,16 @@ def expand(self) -> nodes.BigFrameNode: def maybe_rewrite_join(join_node: nodes.JoinNode) -> nodes.BigFrameNode: - rewritten = join_as_projection( - join_node.left_child, - join_node.right_child, - join_node.join.mappings, - join_node.join.type, - ) - return rewritten if rewritten is not None else join_node + rewrite_common_node = common_subtree(join_node.left_child, join_node.right_child) + if rewrite_common_node is None: + return join_node + left_side = SquashedSelect.from_node(join_node.left_child, rewrite_common_node) + right_side = SquashedSelect.from_node(join_node.right_child, rewrite_common_node) + if left_side.can_join(right_side, join_node.join): + return left_side.merge( + right_side, join_node.join.type, join_node.join.mappings + ).expand() + return join_node def join_as_projection( @@ -223,7 +226,7 @@ def join_as_projection( if rewrite_common_node is not None: left_side = SquashedSelect.from_node(l_node, rewrite_common_node) right_side = SquashedSelect.from_node(r_node, rewrite_common_node) - merged = left_side.maybe_merge(right_side, how, mappings) + merged = left_side.merge(right_side, how, mappings) assert ( merged is not None ), "Couldn't merge nodes. This shouldn't happen. Please share full stacktrace with the BigQuery DataFrames team at bigframes-feedback@google.com." diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index 782ef2d5ea..ff804665fa 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -865,6 +865,11 @@ def test_assign_same_table_different_index_performs_self_join( bf_result = bf_df.assign(new_col=bf_df_2[column_name] * 10).to_pandas() pd_result = pd_df.assign(new_col=pd_df_2[column_name] * 10) + print("pandas") + print(pd_result.to_string()) + print("bigframes") + print(bf_result.to_string()) + pandas.testing.assert_frame_equal(bf_result, pd_result) From 38be0a21bf5753a9d767ae737ceba5f226d04de6 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Tue, 18 Jun 2024 17:20:20 +0000 Subject: [PATCH 3/6] remove print statements --- tests/system/small/test_dataframe.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index ff804665fa..782ef2d5ea 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -865,11 +865,6 @@ def test_assign_same_table_different_index_performs_self_join( bf_result = bf_df.assign(new_col=bf_df_2[column_name] * 10).to_pandas() pd_result = pd_df.assign(new_col=pd_df_2[column_name] * 10) - print("pandas") - print(pd_result.to_string()) - print("bigframes") - print(bf_result.to_string()) - pandas.testing.assert_frame_equal(bf_result, pd_result) From 7904989ee770ca23ec5b19b814bb9065e7e3ab34 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Tue, 18 Jun 2024 17:29:38 +0000 Subject: [PATCH 4/6] add test --- tests/system/small/test_dataframe.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index 782ef2d5ea..0aac9e2578 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -4328,6 +4328,18 @@ def test_df_cached(scalars_df_index): pandas.testing.assert_frame_equal(df.to_pandas(), df_cached_copy.to_pandas()) +def test_df_cache_with_implicit_join(scalars_df_index): + """expectation is that cache will be used, but no explicit join will be performed""" + df = scalars_df_index[["int64_col", "int64_too"]].sort_index().reset_index() + 3 + df.cache() + bf_result = df + (df * 2) + sql = bf_result.sql + + # Very crude asserts, want sql to not use join and not use base table, only reference cached table + assert "JOIN" not in sql + assert "bigframes_testing" not in sql + + def test_df_dot_inline(session): df1 = pd.DataFrame([[1, 2, 3], [2, 5, 7]]) df2 = pd.DataFrame([[2, 4, 8], [1, 5, 10], [3, 6, 9]]) From 5fcde7529974e29d4ad5487158097a4308dd18d9 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Thu, 20 Jun 2024 21:31:39 +0000 Subject: [PATCH 5/6] rename methods --- bigframes/core/rewrite.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/bigframes/core/rewrite.py b/bigframes/core/rewrite.py index 4aca7986a5..aff44ee683 100644 --- a/bigframes/core/rewrite.py +++ b/bigframes/core/rewrite.py @@ -45,7 +45,7 @@ class SquashedSelect: reverse_root: bool = False @classmethod - def from_node( + def from_node_span( cls, node: nodes.BigFrameNode, target: nodes.BigFrameNode ) -> SquashedSelect: if node == target: @@ -55,13 +55,13 @@ def from_node( ) return cls(node, selection, None, ()) if isinstance(node, nodes.ProjectionNode): - return cls.from_node(node.child, target).project(node.assignments) + return cls.from_node_span(node.child, target).project(node.assignments) elif isinstance(node, nodes.FilterNode): - return cls.from_node(node.child, target).filter(node.predicate) + return cls.from_node_span(node.child, target).filter(node.predicate) elif isinstance(node, nodes.ReversedNode): - return cls.from_node(node.child, target).reverse() + return cls.from_node_span(node.child, target).reverse() elif isinstance(node, nodes.OrderByNode): - return cls.from_node(node.child, target).order_with(node.by) + return cls.from_node_span(node.child, target).order_with(node.by) else: raise ValueError(f"Cannot rewrite node {node}") @@ -105,9 +105,10 @@ def order_with(self, by: Tuple[order.OrderingExpression, ...]): self.root, self.columns, self.predicate, new_ordering, self.reverse_root ) - def can_join( + def can_merge( self, right: SquashedSelect, join_def: join_defs.JoinDefinition ) -> bool: + """Determines whether the two selections can be merged into a single selection.""" if join_def.type == "cross": # Cannot convert cross join to projection return False @@ -207,9 +208,11 @@ def maybe_rewrite_join(join_node: nodes.JoinNode) -> nodes.BigFrameNode: rewrite_common_node = common_subtree(join_node.left_child, join_node.right_child) if rewrite_common_node is None: return join_node - left_side = SquashedSelect.from_node(join_node.left_child, rewrite_common_node) - right_side = SquashedSelect.from_node(join_node.right_child, rewrite_common_node) - if left_side.can_join(right_side, join_node.join): + left_side = SquashedSelect.from_node_span(join_node.left_child, rewrite_common_node) + right_side = SquashedSelect.from_node_span( + join_node.right_child, rewrite_common_node + ) + if left_side.can_merge(right_side, join_node.join): return left_side.merge( right_side, join_node.join.type, join_node.join.mappings ).expand() @@ -224,8 +227,8 @@ def join_as_projection( ) -> Optional[nodes.BigFrameNode]: rewrite_common_node = common_subtree(l_node, r_node) if rewrite_common_node is not None: - left_side = SquashedSelect.from_node(l_node, rewrite_common_node) - right_side = SquashedSelect.from_node(r_node, rewrite_common_node) + left_side = SquashedSelect.from_node_span(l_node, rewrite_common_node) + right_side = SquashedSelect.from_node_span(r_node, rewrite_common_node) merged = left_side.merge(right_side, how, mappings) assert ( merged is not None From 56c8459319dc4e9d3321ebb698e285cbe4871e82 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Thu, 20 Jun 2024 21:57:35 +0000 Subject: [PATCH 6/6] rename common_subtree to common_selection_root --- bigframes/core/rewrite.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/bigframes/core/rewrite.py b/bigframes/core/rewrite.py index aff44ee683..101d5cc882 100644 --- a/bigframes/core/rewrite.py +++ b/bigframes/core/rewrite.py @@ -205,7 +205,9 @@ def expand(self) -> nodes.BigFrameNode: def maybe_rewrite_join(join_node: nodes.JoinNode) -> nodes.BigFrameNode: - rewrite_common_node = common_subtree(join_node.left_child, join_node.right_child) + rewrite_common_node = common_selection_root( + join_node.left_child, join_node.right_child + ) if rewrite_common_node is None: return join_node left_side = SquashedSelect.from_node_span(join_node.left_child, rewrite_common_node) @@ -225,7 +227,7 @@ def join_as_projection( mappings: Tuple[join_defs.JoinColumnMapping, ...], how: join_defs.JoinType, ) -> Optional[nodes.BigFrameNode]: - rewrite_common_node = common_subtree(l_node, r_node) + rewrite_common_node = common_selection_root(l_node, r_node) if rewrite_common_node is not None: left_side = SquashedSelect.from_node_span(l_node, rewrite_common_node) right_side = SquashedSelect.from_node_span(r_node, rewrite_common_node) @@ -331,7 +333,7 @@ def get_node_column_ids(node: nodes.BigFrameNode) -> Tuple[str, ...]: return tuple(bigframes.core.ArrayValue(node).column_ids) -def common_subtree( +def common_selection_root( l_tree: nodes.BigFrameNode, r_tree: nodes.BigFrameNode ) -> Optional[nodes.BigFrameNode]: """Find common subtree between join subtrees"""