diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py
index 598c32670e..8c90828091 100644
--- a/bigframes/core/blocks.py
+++ b/bigframes/core/blocks.py
@@ -2286,13 +2286,13 @@ def to_sql_query(
             idx_labels,
         )
 
-    def cached(self, *, optimize_offsets=False, force: bool = False) -> None:
+    def cached(self, *, force: bool = False, session_aware: bool = False) -> None:
        """Write the block to a session table."""
         # use a heuristic for whether something needs to be cached
         if (not force) and self.session._is_trivially_executable(self.expr):
             return
-        if optimize_offsets:
-            self.session._cache_with_offsets(self.expr)
+        elif session_aware:
+            self.session._cache_with_session_awareness(self.expr)
         else:
             self.session._cache_with_cluster_cols(
                 self.expr, cluster_cols=self.index_columns
diff --git a/bigframes/core/pruning.py b/bigframes/core/pruning.py
new file mode 100644
index 0000000000..55165a616c
--- /dev/null
+++ b/bigframes/core/pruning.py
@@ -0,0 +1,77 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import bigframes.core.expression as ex
+import bigframes.core.schema as schemata
+import bigframes.dtypes
+import bigframes.operations as ops
+
+LOW_CARDINALITY_TYPES = [bigframes.dtypes.BOOL_DTYPE]
+
+COMPARISON_OP_TYPES = tuple(
+    type(i)
+    for i in (
+        ops.eq_op,
+        ops.eq_null_match_op,
+        ops.ne_op,
+        ops.gt_op,
+        ops.ge_op,
+        ops.lt_op,
+        ops.le_op,
+    )
+)
+
+
+def cluster_cols_for_predicate(
+    predicate: ex.Expression, schema: schemata.ArraySchema
+) -> list[str]:
+    """Try to determine cluster column candidates that work with the given predicate."""
+    # TODO: Prioritize based on predicted selectivity (eg. equality conditions are probably very selective)
+    if isinstance(predicate, ex.UnboundVariableExpression):
+        cols = [predicate.id]
+    elif isinstance(predicate, ex.OpExpression):
+        op = predicate.op
+        # TODO: Support geo predicates, which support pruning if clustered (other than st_disjoint)
+        # https://cloud.google.com/bigquery/docs/reference/standard-sql/geography_functions
+        if isinstance(op, COMPARISON_OP_TYPES):
+            cols = cluster_cols_for_comparison(predicate.inputs[0], predicate.inputs[1])
+        elif isinstance(op, (type(ops.invert_op))):
+            cols = cluster_cols_for_predicate(predicate.inputs[0], schema)
+        elif isinstance(op, (type(ops.and_op), type(ops.or_op))):
+            left_cols = cluster_cols_for_predicate(predicate.inputs[0], schema)
+            right_cols = cluster_cols_for_predicate(predicate.inputs[1], schema)
+            cols = [*left_cols, *[col for col in right_cols if col not in left_cols]]
+        else:
+            cols = []
+    else:
+        # Constant
+        cols = []
+    return [
+        col for col in cols if bigframes.dtypes.is_clusterable(schema.get_type(col))
+    ]
+
+
+def cluster_cols_for_comparison(
+    left_ex: ex.Expression, right_ex: ex.Expression
+) -> list[str]:
+    # TODO: Try to normalize expressions such that one side is a single variable.
+    # eg. Convert -cola>=3 to cola<=-3, and colb+3 < 4 to colb < 1
+    if left_ex.is_const:
+        # There are some invertible ops that would also be ok
+        if isinstance(right_ex, ex.UnboundVariableExpression):
+            return [right_ex.id]
+    elif right_ex.is_const:
+        if isinstance(left_ex, ex.UnboundVariableExpression):
+            return [left_ex.id]
+    return []
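To make the recursion concrete: for a conjunction of single-column comparisons, both columns survive as cluster candidates. This sketch mirrors the planner unit tests further down; the column names are illustrative, and the schema argument would come from the node being cached:

    import bigframes.core.expression as ex
    import bigframes.operations as ops

    # (col_a > 3) AND (col_b == 1): each side compares one variable against a
    # constant, so cluster_cols_for_predicate(predicate, schema) would yield
    # ["col_a", "col_b"], assuming both columns have clusterable types.
    predicate = ops.and_op.as_expr(
        ops.gt_op.as_expr("col_a", ex.const(3)),
        ops.eq_op.as_expr("col_b", ex.const(1)),
    )

By contrast, a comparison of two columns such as ``ops.gt_op.as_expr("col_a", "col_b")`` yields no candidates, since it cannot be converted into a fixed value range on either column.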
diff --git a/bigframes/core/tree_properties.py b/bigframes/core/tree_properties.py
index 2847a8f7f1..846cf50d77 100644
--- a/bigframes/core/tree_properties.py
+++ b/bigframes/core/tree_properties.py
@@ -15,7 +15,7 @@
 
 import functools
 import itertools
-from typing import Callable, Dict, Optional
+from typing import Callable, Dict, Optional, Sequence
 
 import bigframes.core.nodes as nodes
 
@@ -91,6 +91,43 @@ def _node_counts_inner(
     )
 
 
+def count_nodes(forest: Sequence[nodes.BigFrameNode]) -> dict[nodes.BigFrameNode, int]:
+    """
+    Counts the number of instances of each subtree present within a forest.
+
+    Memoizes internally to accelerate execution, but the cache is not persisted (not reused between invocations).
+
+    Args:
+        forest (Sequence of BigFrameNode):
+            The roots of each tree in the forest
+
+    Returns:
+        dict[BigFrameNode, int]: The number of occurrences of each subtree.
+    """
+
+    def _combine_counts(
+        left: Dict[nodes.BigFrameNode, int], right: Dict[nodes.BigFrameNode, int]
+    ) -> Dict[nodes.BigFrameNode, int]:
+        return {
+            key: left.get(key, 0) + right.get(key, 0)
+            for key in itertools.chain(left.keys(), right.keys())
+        }
+
+    empty_counts: Dict[nodes.BigFrameNode, int] = {}
+
+    @functools.cache
+    def _node_counts_inner(
+        subtree: nodes.BigFrameNode,
+    ) -> Dict[nodes.BigFrameNode, int]:
+        """Helper function to count occurrences of each subtree, including duplicates."""
+        child_counts = [_node_counts_inner(child) for child in subtree.child_nodes]
+        node_counts = functools.reduce(_combine_counts, child_counts, empty_counts)
+        return _combine_counts(node_counts, {subtree: 1})
+
+    counts = [_node_counts_inner(root) for root in forest]
+    return functools.reduce(_combine_counts, counts, empty_counts)
+
+
 def replace_nodes(
     root: nodes.BigFrameNode,
     replacements: dict[nodes.BigFrameNode, nodes.BigFrameNode],
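The counting scheme is easiest to see on a toy tree. A self-contained sketch of the same merge-and-memoize pattern (the `Node` type here is illustrative, not a bigframes class):

    import functools
    from dataclasses import dataclass
    from typing import Dict, Tuple

    @dataclass(frozen=True)  # frozen -> hashable, so nodes can be dict/cache keys
    class Node:
        name: str
        children: Tuple["Node", ...] = ()

    def count_nodes(forest) -> Dict[Node, int]:
        def combine(left, right):
            return {k: left.get(k, 0) + right.get(k, 0) for k in {*left, *right}}

        @functools.cache
        def counts(node: Node) -> Dict[Node, int]:
            # Merge the children's count dictionaries bottom-up, then add this node.
            child_counts = functools.reduce(combine, map(counts, node.children), {})
            return combine(child_counts, {node: 1})

        return functools.reduce(combine, map(counts, forest), {})

    shared = Node("scan")
    forest = [Node("filter", (shared,)), Node("project", (shared,))]
    assert count_nodes(forest)[shared] == 2  # the shared subtree is counted once per root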
diff --git a/bigframes/dtypes.py b/bigframes/dtypes.py
index ced1c215e5..5de8f896a9 100644
--- a/bigframes/dtypes.py
+++ b/bigframes/dtypes.py
@@ -74,52 +74,95 @@ class SimpleDtypeInfo:
     logical_bytes: int = (
         8  # this is approximate only, some types are variably sized, also, compression
     )
+    orderable: bool = False
+    clusterable: bool = False
 
 
 # TODO: Missing BQ types: INTERVAL, JSON, RANGE
 # TODO: Add mappings to python types
 SIMPLE_TYPES = (
     SimpleDtypeInfo(
-        dtype=INT_DTYPE, arrow_dtype=pa.int64(), type_kind=("INT64", "INTEGER")
+        dtype=INT_DTYPE,
+        arrow_dtype=pa.int64(),
+        type_kind=("INT64", "INTEGER"),
+        orderable=True,
+        clusterable=True,
     ),
     SimpleDtypeInfo(
-        dtype=FLOAT_DTYPE, arrow_dtype=pa.float64(), type_kind=("FLOAT64", "FLOAT")
+        dtype=FLOAT_DTYPE,
+        arrow_dtype=pa.float64(),
+        type_kind=("FLOAT64", "FLOAT"),
+        orderable=True,
     ),
     SimpleDtypeInfo(
         dtype=BOOL_DTYPE,
         arrow_dtype=pa.bool_(),
         type_kind=("BOOL", "BOOLEAN"),
         logical_bytes=1,
+        orderable=True,
+        clusterable=True,
     ),
-    SimpleDtypeInfo(dtype=STRING_DTYPE, arrow_dtype=pa.string(), type_kind=("STRING",)),
     SimpleDtypeInfo(
-        dtype=DATE_DTYPE, arrow_dtype=pa.date32(), type_kind=("DATE",), logical_bytes=4
+        dtype=STRING_DTYPE,
+        arrow_dtype=pa.string(),
+        type_kind=("STRING",),
+        orderable=True,
+        clusterable=True,
     ),
-    SimpleDtypeInfo(dtype=TIME_DTYPE, arrow_dtype=pa.time64("us"), type_kind=("TIME",)),
     SimpleDtypeInfo(
-        dtype=DATETIME_DTYPE, arrow_dtype=pa.timestamp("us"), type_kind=("DATETIME",)
+        dtype=DATE_DTYPE,
+        arrow_dtype=pa.date32(),
+        type_kind=("DATE",),
+        logical_bytes=4,
+        orderable=True,
+        clusterable=True,
+    ),
+    SimpleDtypeInfo(
+        dtype=TIME_DTYPE,
+        arrow_dtype=pa.time64("us"),
+        type_kind=("TIME",),
+        orderable=True,
+    ),
+    SimpleDtypeInfo(
+        dtype=DATETIME_DTYPE,
+        arrow_dtype=pa.timestamp("us"),
+        type_kind=("DATETIME",),
+        orderable=True,
+        clusterable=True,
     ),
     SimpleDtypeInfo(
         dtype=TIMESTAMP_DTYPE,
         arrow_dtype=pa.timestamp("us", tz="UTC"),
         type_kind=("TIMESTAMP",),
+        orderable=True,
+        clusterable=True,
+    ),
+    SimpleDtypeInfo(
+        dtype=BYTES_DTYPE, arrow_dtype=pa.binary(), type_kind=("BYTES",), orderable=True
     ),
-    SimpleDtypeInfo(dtype=BYTES_DTYPE, arrow_dtype=pa.binary(), type_kind=("BYTES",)),
     SimpleDtypeInfo(
         dtype=NUMERIC_DTYPE,
         arrow_dtype=pa.decimal128(38, 9),
         type_kind=("NUMERIC",),
         logical_bytes=16,
+        orderable=True,
+        clusterable=True,
     ),
     SimpleDtypeInfo(
         dtype=BIGNUMERIC_DTYPE,
         arrow_dtype=pa.decimal256(76, 38),
         type_kind=("BIGNUMERIC",),
         logical_bytes=32,
+        orderable=True,
+        clusterable=True,
     ),
     # Geo has no corresponding arrow dtype
     SimpleDtypeInfo(
-        dtype=GEO_DTYPE, arrow_dtype=None, type_kind=("GEOGRAPHY",), logical_bytes=40
+        dtype=GEO_DTYPE,
+        arrow_dtype=None,
+        type_kind=("GEOGRAPHY",),
+        logical_bytes=40,
+        clusterable=True,
     ),
 )
 
@@ -209,9 +252,25 @@ def is_comparable(type: ExpressionType) -> bool:
     return (type is not None) and is_orderable(type)
 
 
+_ORDERABLE_SIMPLE_TYPES = set(
+    mapping.dtype for mapping in SIMPLE_TYPES if mapping.orderable
+)
+
+
 def is_orderable(type: ExpressionType) -> bool:
     # On BQ side, ARRAY, STRUCT, GEOGRAPHY, JSON are not orderable
-    return not is_array_like(type) and not is_struct_like(type) and (type != GEO_DTYPE)
+    return type in _ORDERABLE_SIMPLE_TYPES
+
+
+_CLUSTERABLE_SIMPLE_TYPES = set(
+    mapping.dtype for mapping in SIMPLE_TYPES if mapping.clusterable
+)
+
+
+def is_clusterable(type: ExpressionType) -> bool:
+    # https://cloud.google.com/bigquery/docs/clustered-tables#cluster_column_types
+    # This is based on the default BigQuery type mapping; in theory a value could be represented as a non-default BQ type to allow clustering.
+    return type in _CLUSTERABLE_SIMPLE_TYPES
 
 
 def is_bool_coercable(type: ExpressionType) -> bool:
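The two flags encode distinct BigQuery restrictions rather than one property. A few consequences, read straight from the table above (the asserts are for exposition only):

    import bigframes.dtypes as dtypes

    # FLOAT64 and TIME sort fine but are not valid BigQuery clustering columns;
    # GEOGRAPHY is the reverse: clusterable but not orderable.
    assert dtypes.is_orderable(dtypes.FLOAT_DTYPE)
    assert not dtypes.is_clusterable(dtypes.FLOAT_DTYPE)
    assert dtypes.is_orderable(dtypes.TIME_DTYPE)
    assert not dtypes.is_clusterable(dtypes.TIME_DTYPE)
    assert dtypes.is_clusterable(dtypes.GEO_DTYPE)
    assert not dtypes.is_orderable(dtypes.GEO_DTYPE)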
diff --git a/bigframes/series.py b/bigframes/series.py
index eda95fa1e8..57543abef3 100644
--- a/bigframes/series.py
+++ b/bigframes/series.py
@@ -623,6 +623,40 @@ def head(self, n: int = 5) -> Series:
     def tail(self, n: int = 5) -> Series:
         return typing.cast(Series, self.iloc[-n:])
 
+    def peek(self, n: int = 5, *, force: bool = True) -> pandas.Series:
+        """
+        Preview n arbitrary elements from the series without guarantees about row selection or ordering.
+
+        ``Series.peek(force=False)`` will always be very fast, but will not succeed if the data requires
+        a full data scan. Using ``force=True`` will always succeed, but may perform queries.
+        Query results will be cached so that future steps will benefit from these queries.
+
+        Args:
+            n (int, default 5):
+                The number of rows to select from the series. Which N rows are returned is non-deterministic.
+            force (bool, default True):
+                If the data cannot be peeked efficiently, the series will instead be fully materialized as part
+                of the operation if ``force=True``. If ``force=False``, the operation will throw a ValueError.
+        Returns:
+            pandas.Series: A pandas Series with n rows.
+
+        Raises:
+            ValueError: If force=False and data cannot be efficiently peeked.
+        """
+        maybe_result = self._block.try_peek(n)
+        if maybe_result is None:
+            if force:
+                self._cached()
+                maybe_result = self._block.try_peek(n, force=True)
+                assert maybe_result is not None
+            else:
+                raise ValueError(
+                    "Cannot peek efficiently when data has aggregates, joins or window functions applied. Use force=True to fully compute the series."
+                )
+        as_series = maybe_result.squeeze(axis=1)
+        as_series.name = self.name
+        return as_series
+
     def nlargest(self, n: int = 5, keep: str = "first") -> Series:
         if keep not in ("first", "last", "all"):
             raise ValueError("'keep must be one of 'first', 'last', or 'all'")
@@ -1419,7 +1453,7 @@ def apply(
 
         # return Series with materialized result so that any error in the remote
         # function is caught early
-        materialized_series = result_series._cached()
+        materialized_series = result_series._cached(session_aware=False)
         return materialized_series
 
     def combine(
@@ -1794,10 +1828,11 @@ def cache(self):
         Returns:
             Series: Self
         """
-        return self._cached(force=True)
+        # Do not use session-aware caching when the user explicitly requests caching
+        return self._cached(force=True, session_aware=False)
 
-    def _cached(self, *, force: bool = True) -> Series:
-        self._block.cached(force=force)
+    def _cached(self, *, force: bool = True, session_aware: bool = True) -> Series:
+        self._block.cached(force=force, session_aware=session_aware)
         return self
 
     def _optimize_query_complexity(self):
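A usage sketch for the new API (the table name is illustrative; the behavior mirrors the system tests below):

    import bigframes.pandas as bpd

    s = bpd.read_gbq("project.dataset.table")["float64_col"]

    # Plain filters and projections can be peeked without materializing anything.
    s[s > 0].peek(n=3, force=False)

    # A window function defeats the fast path: with the default force=True the
    # series is cached first (now session-aware), then the peek is served from
    # the cached table.
    s.cumsum().peek(n=3)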
diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py
index 3e8133df48..a4c926de72 100644
--- a/bigframes/session/__init__.py
+++ b/bigframes/session/__init__.py
@@ -16,7 +16,6 @@
 
 from __future__ import annotations
 
-import collections.abc
 import copy
 import datetime
 import itertools
@@ -84,6 +83,7 @@
 import bigframes.core.guid
 import bigframes.core.nodes as nodes
 import bigframes.core.ordering as order
+import bigframes.core.pruning
 import bigframes.core.schema as schemata
 import bigframes.core.tree_properties as traversals
 import bigframes.core.tree_properties as tree_properties
@@ -100,6 +100,7 @@
 import bigframes.session._io.bigquery as bf_io_bigquery
 import bigframes.session._io.bigquery.read_gbq_table as bf_read_gbq_table
 import bigframes.session.clients
+import bigframes.session.planner
 import bigframes.version
 
 # Avoid circular imports.
@@ -342,13 +343,15 @@ def session_id(self):
     @property
     def objects(
         self,
-    ) -> collections.abc.Set[
+    ) -> Iterable[
         Union[
             bigframes.core.indexes.Index, bigframes.series.Series, dataframe.DataFrame
         ]
     ]:
-        # Create a set with strong references, be careful not to hold onto this needlessly, as will prevent garbage collection.
-        return set(i() for i in self._objects if i() is not None)  # type: ignore
+        still_alive = [i for i in self._objects if i() is not None]
+        self._objects = still_alive
+        # Create a tuple of strong references; be careful not to hold onto these needlessly, as doing so will prevent garbage collection.
+        return tuple(i() for i in self._objects if i() is not None)  # type: ignore
 
     @property
     def _project(self):
@@ -1876,21 +1879,34 @@ def _cache_with_offsets(self, array_value: core.ArrayValue):
             raise ValueError(
                 "Caching with offsets only supported in strictly ordered mode."
             )
+        offset_column = bigframes.core.guid.generate_guid("bigframes_offsets")
         sql = bigframes.core.compile.compile_unordered(
             self._with_cached_executions(
-                array_value.promote_offsets("bigframes_offsets").node
+                array_value.promote_offsets(offset_column).node
             )
         )
         tmp_table = self._sql_to_temp_table(
-            sql, cluster_cols=["bigframes_offsets"], api_name="cached"
+            sql, cluster_cols=[offset_column], api_name="cached"
         )
         cached_replacement = array_value.as_cached(
             cache_table=self.bqclient.get_table(tmp_table),
-            ordering=order.ExpressionOrdering.from_offset_col("bigframes_offsets"),
+            ordering=order.ExpressionOrdering.from_offset_col(offset_column),
         ).node
         self._cached_executions[array_value.node] = cached_replacement
 
+    def _cache_with_session_awareness(self, array_value: core.ArrayValue) -> None:
+        # Gather the roots of all live objects in the session
+        forest = [obj._block.expr.node for obj in self.objects]
+        # The planner counts subtree occurrences across this forest to pick the cache target
+        target, cluster_cols = bigframes.session.planner.session_aware_cache_plan(
+            array_value.node, forest
+        )
+        if len(cluster_cols) > 0:
+            self._cache_with_cluster_cols(core.ArrayValue(target), cluster_cols)
+        else:
+            self._cache_with_offsets(core.ArrayValue(target))
+
     def _simplify_with_caching(self, array_value: core.ArrayValue):
         """Attempts to handle the complexity by caching duplicated subtrees and breaking the query into pieces."""
         # Apply existing caching first
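The reworked `objects` property prunes dead weakrefs on access and hands back strong references. A minimal standalone sketch of that pattern (the `Registry` class is illustrative, not a bigframes API):

    import weakref

    class Registry:
        def __init__(self):
            self._objects: list[weakref.ref] = []

        def register(self, obj) -> None:
            # Track without keeping the object alive.
            self._objects.append(weakref.ref(obj))

        @property
        def objects(self) -> tuple:
            # Drop refs whose targets have been garbage collected, then
            # dereference the survivors. Returning a tuple rather than a set
            # avoids requiring the tracked objects to be hashable; callers
            # should not hold onto the result longer than needed.
            self._objects = [ref for ref in self._objects if ref() is not None]
            return tuple(ref() for ref in self._objects if ref() is not None)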
diff --git a/bigframes/session/planner.py b/bigframes/session/planner.py
new file mode 100644
index 0000000000..2a74521b43
--- /dev/null
+++ b/bigframes/session/planner.py
@@ -0,0 +1,74 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import itertools
+from typing import Sequence, Tuple
+
+import bigframes.core.expression as ex
+import bigframes.core.nodes as nodes
+import bigframes.core.pruning as predicate_pruning
+import bigframes.core.tree_properties as traversals
+
+
+def session_aware_cache_plan(
+    root: nodes.BigFrameNode, session_forest: Sequence[nodes.BigFrameNode]
+) -> Tuple[nodes.BigFrameNode, list[str]]:
+    """
+    Determines the best node to cache given a target and a list of object roots for objects in a session.
+
+    Returns the node to cache and a list of up to four clustering columns.
+    """
+    node_counts = traversals.count_nodes(session_forest)
+    # These node types are cheap to re-compute, so it makes more sense to cache their children.
+    de_cachable_types = (nodes.FilterNode, nodes.ProjectionNode)
+    caching_target = cur_node = root
+    caching_target_refs = node_counts.get(caching_target, 0)
+
+    filters: list[
+        ex.Expression
+    ] = []  # accumulate filters into this as we traverse downwards
+    clusterable_cols: set[str] = set()
+    while isinstance(cur_node, de_cachable_types):
+        if isinstance(cur_node, nodes.FilterNode):
+            # Filter node doesn't define any variables, so no need to chain expressions
+            filters.append(cur_node.predicate)
+        elif isinstance(cur_node, nodes.ProjectionNode):
+            # Projection defines the variables that are used in the filter expressions, need to substitute variables with their scalar expressions
+            # that instead reference variables in the child node.
+            bindings = {name: expr for expr, name in cur_node.assignments}
+            filters = [i.bind_all_variables(bindings) for i in filters]
+        else:
+            raise ValueError(f"Unexpected de-cachable node: {cur_node}")
+
+        cur_node = cur_node.child
+        cur_node_refs = node_counts.get(cur_node, 0)
+        if cur_node_refs > caching_target_refs:
+            caching_target, caching_target_refs = cur_node, cur_node_refs
+            schema = cur_node.schema
+            # Cluster cols only consider the target object and not other session objects
+            clusterable_cols = set(
+                itertools.chain.from_iterable(
+                    map(
+                        lambda f: predicate_pruning.cluster_cols_for_predicate(
+                            f, schema
+                        ),
+                        filters,
+                    )
+                )
+            )
+    # BQ supports up to 4 cluster columns; for now, just prioritize by alphabetical ordering
+    # TODO: Prioritize clustering columns by estimated filter selectivity
+    return caching_target, sorted(list(clusterable_cols))[:4]
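The projection-substitution step is worth seeing in isolation. A sketch using the expression helpers this patch already relies on (column names illustrative):

    import bigframes.core.expression as ex
    import bigframes.operations as ops

    # A filter on a derived column col_c, where a projection defined col_c as col_a + 1.
    predicate = ops.gt_op.as_expr("col_c", ex.const(3))
    bindings = {"col_c": ops.add_op.as_expr("col_a", ex.const(1))}

    # After substitution the predicate reads (col_a + 1) > 3, i.e. it references
    # only the child node's columns, so it can still be mined for cluster-column
    # candidates once the traversal moves below the projection.
    rewritten = predicate.bind_all_variables(bindings)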
diff --git a/tests/system/small/test_series.py b/tests/system/small/test_series.py
index 3e21418f2f..cb28686d59 100644
--- a/tests/system/small/test_series.py
+++ b/tests/system/small/test_series.py
@@ -1970,6 +1970,70 @@ def test_head_then_series_operation(scalars_dfs):
     )
 
 
+def test_series_peek(scalars_dfs):
+    scalars_df, scalars_pandas_df = scalars_dfs
+    peek_result = scalars_df["float64_col"].peek(n=3, force=False)
+    pd.testing.assert_series_equal(
+        peek_result,
+        scalars_pandas_df["float64_col"].reindex_like(peek_result),
+    )
+
+
+def test_series_peek_multi_index(scalars_dfs):
+    scalars_df, scalars_pandas_df = scalars_dfs
+    bf_series = scalars_df.set_index(["string_col", "bool_col"])["float64_col"]
+    bf_series.name = ("2-part", "name")
+    pd_series = scalars_pandas_df.set_index(["string_col", "bool_col"])["float64_col"]
+    pd_series.name = ("2-part", "name")
+    peek_result = bf_series.peek(n=3, force=False)
+    pd.testing.assert_series_equal(
+        peek_result,
+        pd_series.reindex_like(peek_result),
+    )
+
+
+def test_series_peek_filtered(scalars_dfs):
+    scalars_df, scalars_pandas_df = scalars_dfs
+    peek_result = scalars_df[scalars_df.int64_col > 0]["float64_col"].peek(
+        n=3, force=False
+    )
+    pd_result = scalars_pandas_df[scalars_pandas_df.int64_col > 0]["float64_col"]
+    pd.testing.assert_series_equal(
+        peek_result,
+        pd_result.reindex_like(peek_result),
+    )
+
+
+@skip_legacy_pandas
+def test_series_peek_force(scalars_dfs):
+    scalars_df, scalars_pandas_df = scalars_dfs
+
+    cumsum_df = scalars_df[["int64_col", "int64_too"]].cumsum()
+    df_filtered = cumsum_df[cumsum_df.int64_col > 0]["int64_too"]
+    peek_result = df_filtered.peek(n=3, force=True)
+    pd_cumsum_df = scalars_pandas_df[["int64_col", "int64_too"]].cumsum()
+    pd_result = pd_cumsum_df[pd_cumsum_df.int64_col > 0]["int64_too"]
+    pd.testing.assert_series_equal(
+        peek_result,
+        pd_result.reindex_like(peek_result),
+    )
+
+
+@skip_legacy_pandas
+def test_series_peek_force_float(scalars_dfs):
+    scalars_df, scalars_pandas_df = scalars_dfs
+
+    cumsum_df = scalars_df[["int64_col", "float64_col"]].cumsum()
+    df_filtered = cumsum_df[cumsum_df.float64_col > 0]["float64_col"]
+    peek_result = df_filtered.peek(n=3, force=True)
+    pd_cumsum_df = scalars_pandas_df[["int64_col", "float64_col"]].cumsum()
+    pd_result = pd_cumsum_df[pd_cumsum_df.float64_col > 0]["float64_col"]
+    pd.testing.assert_series_equal(
+        peek_result,
+        pd_result.reindex_like(peek_result),
+    )
+
+
 def test_shift(scalars_df_index, scalars_pandas_df_index):
     col_name = "int64_col"
     bf_result = scalars_df_index[col_name].shift().to_pandas()
diff --git a/tests/unit/test_planner.py b/tests/unit/test_planner.py
new file mode 100644
index 0000000000..2e276d0f1a
--- /dev/null
+++ b/tests/unit/test_planner.py
@@ -0,0 +1,121 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import unittest.mock as mock
+
+import google.cloud.bigquery
+import pandas as pd
+
+import bigframes.core as core
+import bigframes.core.expression as ex
+import bigframes.core.schema
+import bigframes.operations as ops
+import bigframes.session.planner as planner
+
+TABLE_REF = google.cloud.bigquery.TableReference.from_string("project.dataset.table")
+SCHEMA = (
+    google.cloud.bigquery.SchemaField("col_a", "INTEGER"),
+    google.cloud.bigquery.SchemaField("col_b", "INTEGER"),
+)
+TABLE = google.cloud.bigquery.Table(
+    table_ref=TABLE_REF,
+    schema=SCHEMA,
+)
+FAKE_SESSION = mock.create_autospec(bigframes.Session, instance=True)
+type(FAKE_SESSION)._strictly_ordered = mock.PropertyMock(return_value=True)
+LEAF: core.ArrayValue = core.ArrayValue.from_table(
+    session=FAKE_SESSION,
+    table=TABLE,
+    schema=bigframes.core.schema.ArraySchema.from_bq_table(TABLE),
+)
+
+
+def test_session_aware_caching_project_filter():
+    """
+    Test that if a node is filtered by a column, the node is cached pre-filter and clustered by the filter column.
+    """
+    session_objects = [LEAF, LEAF.assign_constant("col_c", 4, pd.Int64Dtype())]
+    target = LEAF.assign_constant("col_c", 4, pd.Int64Dtype()).filter(
+        ops.gt_op.as_expr("col_a", ex.const(3))
+    )
+    result, cluster_cols = planner.session_aware_cache_plan(
+        target.node, [obj.node for obj in session_objects]
+    )
+    assert result == LEAF.node
+    assert cluster_cols == ["col_a"]
+
+
+def test_session_aware_caching_project_multi_filter():
+    """
+    Test that if a node is filtered by multiple columns, all of them are in the cluster cols.
+    """
+    session_objects = [LEAF, LEAF.assign_constant("col_c", 4, pd.Int64Dtype())]
+    predicate_1a = ops.gt_op.as_expr("col_a", ex.const(3))
+    predicate_1b = ops.lt_op.as_expr("col_a", ex.const(55))
+    predicate_1 = ops.and_op.as_expr(predicate_1a, predicate_1b)
+    predicate_3 = ops.eq_op.as_expr("col_b", ex.const(1))
+    target = (
+        LEAF.filter(predicate_1)
+        .assign_constant("col_c", 4, pd.Int64Dtype())
+        .filter(predicate_3)
+    )
+    result, cluster_cols = planner.session_aware_cache_plan(
+        target.node, [obj.node for obj in session_objects]
+    )
+    assert result == LEAF.node
+    assert cluster_cols == ["col_a", "col_b"]
+
+
+def test_session_aware_caching_unusable_filter():
+    """
+    Test that if a node is filtered by multiple columns in the same comparison, the node is cached pre-filter and not clustered by either column.
+
+    Most filters with multiple column references cannot be used for scan pruning, as they cannot be converted to fixed value ranges.
+    """
+    session_objects = [LEAF, LEAF.assign_constant("col_c", 4, pd.Int64Dtype())]
+    target = LEAF.assign_constant("col_c", 4, pd.Int64Dtype()).filter(
+        ops.gt_op.as_expr("col_a", "col_b")
+    )
+    result, cluster_cols = planner.session_aware_cache_plan(
+        target.node, [obj.node for obj in session_objects]
+    )
+    assert result == LEAF.node
+    assert cluster_cols == []
+
+
+def test_session_aware_caching_fork_after_window_op():
+    """
+    Test that caching happens only after a windowed operation, but before filtering and projecting.
+
+    Windowing is expensive, so caching should always compute the window function, in order to avoid later recomputation.
+    """
+    other = LEAF.promote_offsets("offsets_col").assign_constant(
+        "col_d", 5, pd.Int64Dtype()
+    )
+    target = (
+        LEAF.promote_offsets("offsets_col")
+        .assign_constant("col_c", 4, pd.Int64Dtype())
+        .filter(
+            ops.eq_op.as_expr("col_a", ops.add_op.as_expr(ex.const(4), ex.const(3)))
+        )
+    )
+    result, cluster_cols = planner.session_aware_cache_plan(
+        target.node,
+        [
+            other.node,
+        ],
+    )
+    assert result == LEAF.promote_offsets("offsets_col").node
+    assert cluster_cols == ["col_a"]
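Finally, a sketch of how the two caching paths introduced by this patch are reached from user code (table name illustrative):

    import bigframes.pandas as bpd

    s = bpd.read_gbq("project.dataset.table")["col_a"]

    # Explicit: Series.cache() calls _cached(force=True, session_aware=False),
    # so the expression is cached as-is, clustered on its index columns.
    s.cache()

    # Implicit: a forced peek of a non-peekable expression calls _cached() with
    # the default session_aware=True, letting the planner choose both the node
    # to cache and the clustering columns.
    s.cumsum().peek(n=3)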