diff --git a/bigframes/core/__init__.py b/bigframes/core/__init__.py
index 485a9d79a7..71b1214d01 100644
--- a/bigframes/core/__init__.py
+++ b/bigframes/core/__init__.py
@@ -268,7 +268,13 @@ def promote_offsets(self) -> Tuple[ArrayValue, str]:
     def concat(self, other: typing.Sequence[ArrayValue]) -> ArrayValue:
         """Append together multiple ArrayValue objects."""
         return ArrayValue(
-            nodes.ConcatNode(children=tuple([self.node, *[val.node for val in other]]))
+            nodes.ConcatNode(
+                children=tuple([self.node, *[val.node for val in other]]),
+                output_ids=tuple(
+                    ids.ColumnId(bigframes.core.guid.generate_guid())
+                    for _ in self.column_ids
+                ),
+            )
         )

     def compute_values(self, assignments: Sequence[ex.Expression]):
diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py
index b0a8903e19..2648c9993f 100644
--- a/bigframes/core/blocks.py
+++ b/bigframes/core/blocks.py
@@ -3139,7 +3139,7 @@ def _pd_index_to_array_value(
     rows = []
     labels_as_tuples = utils.index_as_tuples(index)
     for row_offset in range(len(index)):
-        id_gen = bigframes.core.identifiers.standard_identifiers()
+        id_gen = bigframes.core.identifiers.standard_id_strings()
         row_label = labels_as_tuples[row_offset]
         row_label = (row_label,) if not isinstance(row_label, tuple) else row_label
         row = {}
diff --git a/bigframes/core/compile/api.py b/bigframes/core/compile/api.py
index 86c8fca25a..61eaa63f85 100644
--- a/bigframes/core/compile/api.py
+++ b/bigframes/core/compile/api.py
@@ -18,14 +18,15 @@
 import google.cloud.bigquery as bigquery

 import bigframes.core.compile.compiler as compiler
-import bigframes.core.rewrite as rewrites

 if TYPE_CHECKING:
     import bigframes.core.nodes
     import bigframes.core.ordering
     import bigframes.core.schema

-_STRICT_COMPILER = compiler.Compiler(strict=True)
+_STRICT_COMPILER = compiler.Compiler(
+    strict=True, enable_pruning=True, enable_densify_ids=True
+)


 class SQLCompiler:
@@ -34,7 +35,7 @@ def __init__(self, strict: bool = True):
     def compile_peek(self, node: bigframes.core.nodes.BigFrameNode, n_rows: int) -> str:
         """Compile node into sql that selects N arbitrary rows, may not execute deterministically."""
-        return self._compiler.compile_unordered_ir(node).peek_sql(n_rows)
+        return self._compiler.compile_peek_sql(node, n_rows)

     def compile_unordered(
         self,
@@ -44,9 +45,8 @@ def compile_unordered(
     ) -> str:
         """Compile node into sql where rows are unsorted, and no ordering information is preserved."""
         # TODO: Enable limit pullup, but only if not being used to write to clustered table.
-        return self._compiler.compile_unordered_ir(node).to_sql(
-            col_id_overrides=col_id_overrides
-        )
+        output_ids = [col_id_overrides.get(id, id) for id in node.schema.names]
+        return self._compiler.compile_sql(node, ordered=False, output_ids=output_ids)

     def compile_ordered(
         self,
@@ -56,10 +56,8 @@ def compile_ordered(
     ) -> str:
         """Compile node into sql where rows are sorted with ORDER BY."""
         # If we are ordering the query anyways, compiling the slice as a limit is probably a good idea.
-        new_node, limit = rewrites.pullup_limit_from_slice(node)
-        return self._compiler.compile_ordered_ir(new_node).to_sql(
-            col_id_overrides=col_id_overrides, ordered=True, limit=limit
-        )
+        output_ids = [col_id_overrides.get(id, id) for id in node.schema.names]
+        return self._compiler.compile_sql(node, ordered=True, output_ids=output_ids)

     def compile_raw(
         self,
@@ -68,13 +66,12 @@ def compile_raw(
         str, Sequence[bigquery.SchemaField], bigframes.core.ordering.RowOrdering
     ]:
         """Compile node into sql that exposes all columns, including hidden ordering-only columns."""
-        ir = self._compiler.compile_ordered_ir(node)
-        sql, schema = ir.raw_sql_and_schema()
-        return sql, schema, ir._ordering
+        return self._compiler.compile_raw(node)


 def test_only_try_evaluate(node: bigframes.core.nodes.BigFrameNode):
     """Use only for unit testing paths - not fully featured. Will throw exception if fails."""
+    node = _STRICT_COMPILER._preprocess(node)
     ibis = _STRICT_COMPILER.compile_ordered_ir(node)._to_ibis_expr(
         ordering_mode="unordered"
     )
@@ -85,9 +82,10 @@ def test_only_ibis_inferred_schema(node: bigframes.core.nodes.BigFrameNode):
     """Use only for testing paths to ensure ibis inferred schema does not diverge from bigframes inferred schema."""
     import bigframes.core.schema

+    node = _STRICT_COMPILER._preprocess(node)
     compiled = _STRICT_COMPILER.compile_unordered_ir(node)
     items = tuple(
-        bigframes.core.schema.SchemaItem(id, compiled.get_column_type(id))
-        for id in compiled.column_ids
+        bigframes.core.schema.SchemaItem(name, compiled.get_column_type(ibis_id))
+        for name, ibis_id in zip(node.schema.names, compiled.column_ids)
     )
     return bigframes.core.schema.ArraySchema(items)
diff --git a/bigframes/core/compile/compiled.py b/bigframes/core/compile/compiled.py
index d02a2c444c..d2783a07e2 100644
--- a/bigframes/core/compile/compiled.py
+++ b/bigframes/core/compile/compiled.py
@@ -202,7 +202,12 @@ def _aggregate_base(
         )
         # Must have deterministic ordering, so order by the unique "by" column
         ordering = TotalOrdering(
-            tuple([OrderingExpression(column_id) for column_id in by_column_ids]),
+            tuple(
+                [
+                    OrderingExpression(ex.DerefOp(ref.id.local_normalized))
+                    for ref in by_column_ids
+                ]
+            ),
             total_ordering_columns=frozenset(
                 [ex.DerefOp(ref.id.local_normalized) for ref in by_column_ids]
             ),
@@ -266,31 +271,26 @@ def peek_sql(self, n: int):

     def to_sql(
         self,
         offset_column: typing.Optional[str] = None,
-        col_id_overrides: typing.Mapping[str, str] = {},
         ordered: bool = False,
     ) -> str:
         if offset_column or ordered:
             raise ValueError("Cannot produce sorted sql in partial ordering mode")
-        sql = ibis_bigquery.Backend().compile(
-            self._to_ibis_expr(
-                col_id_overrides=col_id_overrides,
-            )
-        )
+        sql = ibis_bigquery.Backend().compile(self._to_ibis_expr())
         return typing.cast(str, sql)

-    def row_count(self) -> OrderedIR:
+    def row_count(self, name: str) -> OrderedIR:
         original_table = self._to_ibis_expr()
         ibis_table = original_table.agg(
             [
-                original_table.count().name("count"),
+                original_table.count().name(name),
             ]
         )
         return OrderedIR(
             ibis_table,
-            (ibis_table["count"],),
+            (ibis_table[name],),
             ordering=TotalOrdering(
-                ordering_value_columns=(ascending_over("count"),),
-                total_ordering_columns=frozenset([ex.deref("count")]),
+                ordering_value_columns=(ascending_over(name),),
+                total_ordering_columns=frozenset([ex.deref(name)]),
             ),
         )

@@ -299,7 +299,6 @@ def _to_ibis_expr(
         *,
         expose_hidden_cols: bool = False,
         fraction: Optional[float] = None,
-        col_id_overrides: typing.Mapping[str, str] = {},
     ):
         """
        Creates an Ibis table expression representing the DataFrame.
@@ -320,8 +319,6 @@ def _to_ibis_expr(
             If True, include the hidden ordering columns in the results.
             Only compatible with `order_by` and `unordered`
             ``ordering_mode``.
-        col_id_overrides:
-            overrides the column ids for the result
         Returns:
             An ibis expression representing the data help by the ArrayValue object.
         """
@@ -346,10 +343,6 @@ def _to_ibis_expr(
         if self._reduced_predicate is not None:
             table = table.filter(base_table[PREDICATE_COLUMN])
         table = table.drop(*columns_to_drop)
-        if col_id_overrides:
-            table = table.rename(
-                {value: key for key, value in col_id_overrides.items()}
-            )
         if fraction is not None:
             table = table.filter(ibis.random() < ibis.literal(fraction))
         return table
@@ -941,7 +934,6 @@ def _reproject_to_table(self) -> OrderedIR:

     def to_sql(
         self,
-        col_id_overrides: typing.Mapping[str, str] = {},
         ordered: bool = False,
         limit: Optional[int] = None,
     ) -> str:
@@ -951,17 +943,13 @@ def to_sql(
             sql = ibis_bigquery.Backend().compile(
                 baked_ir._to_ibis_expr(
                     ordering_mode="unordered",
-                    col_id_overrides=col_id_overrides,
                     expose_hidden_cols=True,
                 )
             )
-            output_columns = [
-                col_id_overrides.get(col, col) for col in baked_ir.column_ids
-            ]
             sql = (
                 bigframes.core.compile.googlesql.Select()
                 .from_(sql)
-                .select(output_columns)
+                .select(self.column_ids)
                 .sql()
             )

@@ -979,7 +967,6 @@ def to_sql(
             sql = ibis_bigquery.Backend().compile(
                 self._to_ibis_expr(
                     ordering_mode="unordered",
-                    col_id_overrides=col_id_overrides,
                     expose_hidden_cols=False,
                 )
             )
@@ -987,16 +974,19 @@ def to_sql(

     def raw_sql_and_schema(
         self,
+        column_ids: typing.Sequence[str],
     ) -> typing.Tuple[str, typing.Sequence[google.cloud.bigquery.SchemaField]]:
         """Return sql with all hidden columns. Used to cache with ordering information.

         Also returns schema, as the extra ordering columns are determined compile-time.
         """
+        col_id_overrides = dict(zip(self.column_ids, column_ids))
         all_columns = (*self.column_ids, *self._hidden_ordering_column_names.keys())
         as_ibis = self._to_ibis_expr(
             ordering_mode="unordered",
             expose_hidden_cols=True,
-        ).select(all_columns)
+        )
+        as_ibis = as_ibis.select(all_columns).rename(col_id_overrides)

         # Ibis will produce non-nullable schema types, but bigframes should always be nullable
         fixed_ibis_schema = ibis_schema.Schema.from_tuples(
@@ -1013,7 +1003,6 @@ def _to_ibis_expr(
         *,
         expose_hidden_cols: bool = False,
         fraction: Optional[float] = None,
-        col_id_overrides: typing.Mapping[str, str] = {},
         ordering_mode: Literal["string_encoded", "unordered"],
         order_col_name: Optional[str] = ORDER_ID_COLUMN,
     ):
@@ -1043,8 +1032,6 @@ def _to_ibis_expr(
         order_col_name:
             If the ordering mode outputs a single ordering or offsets column,
             use this as the column name.
-        col_id_overrides:
-            overrides the column ids for the result
         Returns:
             An ibis expression representing the data help by the ArrayValue object.
         """
@@ -1086,10 +1073,6 @@ def _to_ibis_expr(
         if self._reduced_predicate is not None:
             table = table.filter(base_table[PREDICATE_COLUMN])
         table = table.drop(*columns_to_drop)
-        if col_id_overrides:
-            table = table.rename(
-                {value: key for key, value in col_id_overrides.items()}
-            )
         if fraction is not None:
             table = table.filter(ibis.random() < ibis.literal(fraction))
         return table
diff --git a/bigframes/core/compile/compiler.py b/bigframes/core/compile/compiler.py
index 1fa727780a..66fde9b874 100644
--- a/bigframes/core/compile/compiler.py
+++ b/bigframes/core/compile/compiler.py
@@ -18,6 +18,7 @@
 import io
 import typing

+import google.cloud.bigquery
 import ibis
 import ibis.backends
 import ibis.backends.bigquery
@@ -32,6 +33,7 @@
 import bigframes.core.compile.scalar_op_compiler as compile_scalar
 import bigframes.core.compile.schema_translator
 import bigframes.core.compile.single_column
+import bigframes.core.expression as ex
 import bigframes.core.guid as guids
 import bigframes.core.identifiers as ids
 import bigframes.core.nodes as nodes
@@ -50,31 +52,66 @@ class Compiler:
     strict: bool = True
     scalar_op_compiler = compile_scalar.ScalarOpCompiler()
     enable_pruning: bool = False
+    enable_densify_ids: bool = False
+
+    def compile_sql(
+        self, node: nodes.BigFrameNode, ordered: bool, output_ids: typing.Sequence[str]
+    ) -> str:
+        node = self.set_output_names(node, output_ids)
+        if ordered:
+            node, limit = rewrites.pullup_limit_from_slice(node)
+            return self.compile_ordered_ir(self._preprocess(node)).to_sql(
+                ordered=True, limit=limit
+            )
+        else:
+            return self.compile_unordered_ir(self._preprocess(node)).to_sql()
+
+    def compile_peek_sql(self, node: nodes.BigFrameNode, n_rows: int) -> str:
+        return self.compile_unordered_ir(self._preprocess(node)).peek_sql(n_rows)
+
+    def compile_raw(
+        self,
+        node: bigframes.core.nodes.BigFrameNode,
+    ) -> typing.Tuple[
+        str, typing.Sequence[google.cloud.bigquery.SchemaField], bf_ordering.RowOrdering
+    ]:
+        ir = self.compile_ordered_ir(self._preprocess(node))
+        sql, schema = ir.raw_sql_and_schema(column_ids=node.schema.names)
+        return sql, schema, ir._ordering

     def _preprocess(self, node: nodes.BigFrameNode):
         if self.enable_pruning:
             used_fields = frozenset(field.id for field in node.fields)
             node = node.prune(used_fields)
         node = functools.cache(rewrites.replace_slice_ops)(node)
+        if self.enable_densify_ids:
+            original_names = [id.name for id in node.ids]
+            node, _ = rewrites.remap_variables(
+                node, id_generator=ids.anonymous_serial_ids()
+            )
+            node = self.set_output_names(node, original_names)
         return node

-    def compile_ordered_ir(self, node: nodes.BigFrameNode) -> compiled.OrderedIR:
-        ir = typing.cast(
-            compiled.OrderedIR, self.compile_node(self._preprocess(node), True)
+    def set_output_names(
+        self, node: bigframes.core.nodes.BigFrameNode, output_ids: typing.Sequence[str]
+    ):
+        # TODO: Create specialized output operators that will handle final names
+        return nodes.SelectionNode(
+            node,
+            tuple(
+                (ex.DerefOp(old_id), ids.ColumnId(out_id))
+                for old_id, out_id in zip(node.ids, output_ids)
+            ),
         )
+
+    def compile_ordered_ir(self, node: nodes.BigFrameNode) -> compiled.OrderedIR:
+        ir = typing.cast(compiled.OrderedIR, self.compile_node(node, True))
         if self.strict:
             assert ir.has_total_order
         return ir

     def compile_unordered_ir(self, node: nodes.BigFrameNode) -> compiled.UnorderedIR:
-        return typing.cast(
-            compiled.UnorderedIR, self.compile_node(self._preprocess(node), False)
-        )
-
-    def compile_peak_sql(
-        self, node: nodes.BigFrameNode, n_rows: int
-    ) -> typing.Optional[str]:
-        return self.compile_unordered_ir(self._preprocess(node)).peek_sql(n_rows)
+        return typing.cast(compiled.UnorderedIR, self.compile_node(node, False))

     # TODO: Remove cache when schema no longer requires compilation to derive schema (and therefor only compiles for execution)
     @functools.lru_cache(maxsize=5000)
@@ -144,11 +181,11 @@ def compile_fromrange(self, node: nodes.FromRangeNode, ordered: bool = True):
         labels_array_table = ibis.range(
             joined_table[start_column], joined_table[end_column] + node.step, node.step
-        ).name("labels")
+        ).name(node.output_id.sql)
         labels = (
             typing.cast(ibis.expr.types.ArrayValue, labels_array_table)
             .as_table()
-            .unnest(["labels"])
+            .unnest([node.output_id.sql])
         )
         if ordered:
             return compiled.OrderedIR(
@@ -307,18 +344,19 @@ def compile_projection(self, node: nodes.ProjectionNode, ordered: bool = True):

     @_compile_node.register
     def compile_concat(self, node: nodes.ConcatNode, ordered: bool = True):
+        output_ids = [id.sql for id in node.output_ids]
         if ordered:
             compiled_ordered = [self.compile_ordered_ir(node) for node in node.children]
-            return concat_impl.concat_ordered(compiled_ordered)
+            return concat_impl.concat_ordered(compiled_ordered, output_ids)
         else:
             compiled_unordered = [
                 self.compile_unordered_ir(node) for node in node.children
             ]
-            return concat_impl.concat_unordered(compiled_unordered)
+            return concat_impl.concat_unordered(compiled_unordered, output_ids)

     @_compile_node.register
     def compile_rowcount(self, node: nodes.RowCountNode, ordered: bool = True):
-        result = self.compile_unordered_ir(node.child).row_count()
+        result = self.compile_unordered_ir(node.child).row_count(name=node.col_id.sql)
         return result if ordered else result.to_unordered()

     @_compile_node.register
diff --git a/bigframes/core/compile/concat.py b/bigframes/core/compile/concat.py
index 81d6805d22..ea4b59ca0b 100644
--- a/bigframes/core/compile/concat.py
+++ b/bigframes/core/compile/concat.py
@@ -32,6 +32,7 @@ def concat_unordered(
     items: typing.Sequence[compiled.UnorderedIR],
+    output_ids: typing.Sequence[str],
 ) -> compiled.UnorderedIR:
     """Append together multiple ArrayValue objects."""
     if len(items) == 1:
@@ -39,9 +40,8 @@ def concat_unordered(
     tables = []
     for expr in items:
         table = expr._to_ibis_expr()
-        # Rename the value columns based on horizontal offset before applying union.
         table = table.select(
-            [table[col].name(f"column_{i}") for i, col in enumerate(table.columns)]
+            [table[col].name(id) for id, col in zip(output_ids, table.columns)]
         )
         tables.append(table)
     combined_table = ibis.union(*tables)
@@ -53,6 +53,7 @@ def concat_unordered(

 def concat_ordered(
     items: typing.Sequence[compiled.OrderedIR],
+    output_ids: typing.Sequence[str],
 ) -> compiled.OrderedIR:
     """Append together multiple ArrayValue objects."""
     if len(items) == 1:
@@ -67,19 +68,22 @@ def concat_ordered(
     )
     for i, expr in enumerate(items):
         ordering_prefix = str(i).zfill(prefix_size)
+        renames = {
+            old_id: new_id for old_id, new_id in zip(expr.column_ids, output_ids)
+        }
         table = expr._to_ibis_expr(
-            ordering_mode="string_encoded", order_col_name=ORDER_ID_COLUMN
+            ordering_mode="string_encoded",
+            order_col_name=ORDER_ID_COLUMN,
         )
-        # Rename the value columns based on horizontal offset before applying union.
         table = table.select(
             [
-                table[col].name(f"column_{i}")
+                table[col].name(renames[col])
                 if col != ORDER_ID_COLUMN
                 else (
                     ordering_prefix
                     + reencode_order_string(table[ORDER_ID_COLUMN], max_encoding_size)
                 ).name(ORDER_ID_COLUMN)
-                for i, col in enumerate(table.columns)
+                for col in table.columns
             ]
         )
         tables.append(table)
diff --git a/bigframes/core/compile/single_column.py b/bigframes/core/compile/single_column.py
index 6f2f3f5b6e..2ec0796760 100644
--- a/bigframes/core/compile/single_column.py
+++ b/bigframes/core/compile/single_column.py
@@ -55,6 +55,7 @@ def join_by_column_ordered(
     l_value_mapping = dict(zip(left.column_ids, left.column_ids))
     r_value_mapping = dict(zip(right.column_ids, right.column_ids))

+    # hidden columns aren't necessarily unique, so need to remap to guids
     l_hidden_mapping = {
         id: guids.generate_guid("hidden_") for id in left._hidden_column_ids
     }
@@ -68,12 +69,14 @@ def join_by_column_ordered(
     left_table = left._to_ibis_expr(
         ordering_mode="unordered",
         expose_hidden_cols=True,
-        col_id_overrides=l_mapping,
     )
+    left_table = left_table.rename({val: key for key, val in l_hidden_mapping.items()})
     right_table = right._to_ibis_expr(
         ordering_mode="unordered",
         expose_hidden_cols=True,
-        col_id_overrides=r_mapping,
+    )
+    right_table = right_table.rename(
+        {val: key for key, val in r_hidden_mapping.items()}
     )
     join_conditions = [
         value_to_join_key(left_table[l_mapping[left_index]])
diff --git a/bigframes/core/expression.py b/bigframes/core/expression.py
index 9dee599a7c..3b7828bbf0 100644
--- a/bigframes/core/expression.py
+++ b/bigframes/core/expression.py
@@ -18,7 +18,7 @@
 import dataclasses
 import itertools
 import typing
-from typing import Mapping, Union
+from typing import Mapping, TypeVar, Union

 import bigframes.core.identifiers as ids
 import bigframes.dtypes as dtypes
@@ -56,6 +56,14 @@ def output_type(
     def column_references(self) -> typing.Tuple[ids.ColumnId, ...]:
         return ()

+    @abc.abstractmethod
+    def remap_column_refs(
+        self,
+        name_mapping: Mapping[ids.ColumnId, ids.ColumnId],
+        allow_partial_bindings: bool = False,
+    ) -> Aggregation:
+        ...
+

 @dataclasses.dataclass(frozen=True)
 class NullaryAggregation(Aggregation):
@@ -66,6 +74,13 @@ def output_type(
     ) -> dtypes.ExpressionType:
         return self.op.output_type()

+    def remap_column_refs(
+        self,
+        name_mapping: Mapping[ids.ColumnId, ids.ColumnId],
+        allow_partial_bindings: bool = False,
+    ) -> NullaryAggregation:
+        return self
+

 @dataclasses.dataclass(frozen=True)
 class UnaryAggregation(Aggregation):
@@ -81,6 +96,18 @@ def output_type(
     def column_references(self) -> typing.Tuple[ids.ColumnId, ...]:
         return self.arg.column_references

+    def remap_column_refs(
+        self,
+        name_mapping: Mapping[ids.ColumnId, ids.ColumnId],
+        allow_partial_bindings: bool = False,
+    ) -> UnaryAggregation:
+        return UnaryAggregation(
+            self.op,
+            self.arg.remap_column_refs(
+                name_mapping, allow_partial_bindings=allow_partial_bindings
+            ),
+        )
+

 @dataclasses.dataclass(frozen=True)
 class BinaryAggregation(Aggregation):
@@ -99,6 +126,24 @@ def output_type(
     def column_references(self) -> typing.Tuple[ids.ColumnId, ...]:
         return (*self.left.column_references, *self.right.column_references)

+    def remap_column_refs(
+        self,
+        name_mapping: Mapping[ids.ColumnId, ids.ColumnId],
+        allow_partial_bindings: bool = False,
+    ) -> BinaryAggregation:
+        return BinaryAggregation(
+            self.op,
+            self.left.remap_column_refs(
+                name_mapping, allow_partial_bindings=allow_partial_bindings
+            ),
+            self.right.remap_column_refs(
+                name_mapping, allow_partial_bindings=allow_partial_bindings
+            ),
+        )
+
+
+TExpression = TypeVar("TExpression", bound="Expression")
+

 @dataclasses.dataclass(frozen=True)
 class Expression(abc.ABC):
@@ -109,14 +154,18 @@ def free_variables(self) -> typing.Tuple[str, ...]:
         return ()

     @property
+    @abc.abstractmethod
     def column_references(self) -> typing.Tuple[ids.ColumnId, ...]:
-        return ()
+        ...

     def remap_column_refs(
-        self, name_mapping: Mapping[ids.ColumnId, ids.ColumnId]
-    ) -> Expression:
+        self: TExpression,
+        name_mapping: Mapping[ids.ColumnId, ids.ColumnId],
+        allow_partial_bindings: bool = False,
+    ) -> TExpression:
         return self.bind_refs(
-            {old_id: DerefOp(new_id) for old_id, new_id in name_mapping.items()}
+            {old_id: DerefOp(new_id) for old_id, new_id in name_mapping.items()},  # type: ignore
+            allow_partial_bindings=allow_partial_bindings,
         )

     @property
@@ -174,6 +223,10 @@ class ScalarConstantExpression(Expression):
     def is_const(self) -> bool:
         return True

+    @property
+    def column_references(self) -> typing.Tuple[ids.ColumnId, ...]:
+        return ()
+
     def output_type(
         self, input_types: dict[ids.ColumnId, bigframes.dtypes.Dtype]
     ) -> dtypes.ExpressionType:
@@ -211,6 +264,10 @@ def free_variables(self) -> typing.Tuple[str, ...]:
     def is_const(self) -> bool:
         return False

+    @property
+    def column_references(self) -> typing.Tuple[ids.ColumnId, ...]:
+        return ()
+
     def output_type(
         self, input_types: dict[ids.ColumnId, bigframes.dtypes.Dtype]
     ) -> dtypes.ExpressionType:
diff --git a/bigframes/core/identifiers.py b/bigframes/core/identifiers.py
index 0d2aaeb07c..8c2f7e910f 100644
--- a/bigframes/core/identifiers.py
+++ b/bigframes/core/identifiers.py
@@ -15,13 +15,14 @@
 import dataclasses
 import functools
+import itertools
 from typing import Generator


-def standard_identifiers() -> Generator[str, None, None]:
+def standard_id_strings(prefix: str = "col_") -> Generator[str, None, None]:
     i = 0
     while True:
-        yield f"col_{i}"
+        yield f"{prefix}{i}"
         i = i + 1


@@ -44,4 +45,28 @@ def local_normalized(self) -> ColumnId:
         return self  # == ColumnId(name=self.sql)

     def __lt__(self, other: ColumnId) -> bool:
-        return self.name < other.name
+        return self.sql < other.sql
+
+
+@dataclasses.dataclass(frozen=True)
+class SerialColumnId(ColumnId):
+    """Id that is assigned a unique serial within the tree."""
+
+    name: str
+    id: int
+
+    @property
+    def sql(self) -> str:
+        """Returns the unescaped SQL name."""
+        return f"{self.name}_{self.id}"
+
+    @property
+    def local_normalized(self) -> ColumnId:
+        """For use in compiler only. Normalizes to ColumnId referring to sql name."""
+        return ColumnId(name=self.sql)
+
+
+# TODO: Create serial ids locally, so can preserve name info
+def anonymous_serial_ids() -> Generator[ColumnId, None, None]:
+    for i in itertools.count():
+        yield SerialColumnId("uid", i)
diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py
index 2e23f529e2..30a130bbac 100644
--- a/bigframes/core/nodes.py
+++ b/bigframes/core/nodes.py
@@ -20,7 +20,7 @@
 import functools
 import itertools
 import typing
-from typing import Callable, cast, Iterable, Optional, Sequence, Tuple
+from typing import Callable, cast, Iterable, Mapping, Optional, Sequence, Tuple

 import google.cloud.bigquery as bq

@@ -88,6 +88,19 @@ def child_nodes(self) -> typing.Sequence[BigFrameNode]:
     def row_count(self) -> typing.Optional[int]:
         return None

+    @abc.abstractmethod
+    def remap_refs(
+        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
+    ) -> BigFrameNode:
+        """Remap variable references"""
+        ...
+
+    @property
+    @abc.abstractmethod
+    def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
+        """The variables defined in this node (as opposed to by child nodes)."""
+        ...
+
     @functools.cached_property
     def session(self):
         sessions = []
@@ -101,6 +114,17 @@ def session(self):
             return sessions[0]
         return None

+    def _validate(self):
+        """Validate the local data in the node."""
+        return
+
+    @functools.cache
+    def validate_tree(self) -> bool:
+        for child in self.child_nodes:
+            child.validate_tree()
+        self._validate()
+        return True
+
     def _as_tuple(self) -> Tuple:
         """Get all fields as tuple."""
         return tuple(getattr(self, field.name) for field in fields(self))
@@ -141,6 +165,7 @@ def fields(self) -> Iterable[Field]:

     @property
     def ids(self) -> Iterable[bfet_ids.ColumnId]:
+        """All output ids from the node."""
         return (field.id for field in self.fields)

     @property
@@ -220,6 +245,13 @@ def transform_children(
         """Apply a function to each child node."""
         ...

+    @abc.abstractmethod
+    def remap_vars(
+        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
+    ) -> BigFrameNode:
+        """Remap defined (in this node only) variables."""
+        ...
+
     @property
     def defines_namespace(self) -> bool:
         """
@@ -330,6 +362,18 @@ def row_count(self) -> typing.Optional[int]:
             (self.start, self.stop, self.step), child_length
         )

+    @property
+    def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
+        return ()
+
+    def remap_vars(
+        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
+    ) -> BigFrameNode:
+        return self
+
+    def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
+        return self
+

 @dataclass(frozen=True, eq=False)
 class JoinNode(BigFrameNode):
@@ -338,7 +382,7 @@ class JoinNode(BigFrameNode):
     conditions: typing.Tuple[typing.Tuple[ex.DerefOp, ex.DerefOp], ...]
     type: typing.Literal["inner", "outer", "left", "right", "cross"]

-    def __post_init__(self):
+    def _validate(self):
         assert not (
             set(self.left_child.ids) & set(self.right_child.ids)
         ), "Join ids collide"
@@ -386,6 +430,10 @@ def row_count(self) -> Optional[int]:

         return None

+    @property
+    def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
+        return ()
+
     def transform_children(
         self, t: Callable[[BigFrameNode], BigFrameNode]
     ) -> BigFrameNode:
@@ -397,24 +445,38 @@ def transform_children(
             return self
         return transformed

-    @property
-    def defines_namespace(self) -> bool:
-        return True
-
     def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
         # If this is a cross join, make sure to select at least one column from each side
-        new_used = used_cols.union(
+        condition_cols = used_cols.union(
             map(lambda x: x.id, itertools.chain.from_iterable(self.conditions))
         )
-        return self.transform_children(lambda x: x.prune(new_used))
+        return self.transform_children(
+            lambda x: x.prune(frozenset([*condition_cols, *used_cols]))
+        )
+
+    def remap_vars(
+        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
+    ) -> BigFrameNode:
+        return self
+
+    def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
+        new_conds = tuple(
+            (
+                l_cond.remap_column_refs(mappings, allow_partial_bindings=True),
+                r_cond.remap_column_refs(mappings, allow_partial_bindings=True),
+            )
+            for l_cond, r_cond in self.conditions
+        )
+        return replace(self, conditions=new_conds)  # type: ignore


 @dataclass(frozen=True, eq=False)
 class ConcatNode(BigFrameNode):
     # TODO: Explcitly map column ids from each child
     children: Tuple[BigFrameNode, ...]
+    output_ids: Tuple[bfet_ids.ColumnId, ...]

-    def __post_init__(self):
+    def _validate(self):
         if len(self.children) == 0:
             raise ValueError("Concat requires at least one input table. Zero provided.")
         child_schemas = [child.schema.dtypes for child in self.children]
@@ -438,8 +500,8 @@ def explicitly_ordered(self) -> bool:
     def fields(self) -> Iterable[Field]:
         # TODO: Output names should probably be aligned beforehand or be part of concat definition
         return (
-            Field(bfet_ids.ColumnId(f"column_{i}"), field.dtype)
-            for i, field in enumerate(self.children[0].fields)
+            Field(id, field.dtype)
+            for id, field in zip(self.output_ids, self.children[0].fields)
         )

     @functools.cached_property
@@ -457,6 +519,10 @@ def row_count(self) -> Optional[int]:
             total += count
         return total

+    @property
+    def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
+        return self.output_ids
+
     def transform_children(
         self, t: Callable[[BigFrameNode], BigFrameNode]
     ) -> BigFrameNode:
@@ -470,6 +536,15 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
         # TODO: Make concat prunable, probably by redefining
         return self

+    def remap_vars(
+        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
+    ) -> BigFrameNode:
+        new_ids = tuple(mappings.get(id, id) for id in self.output_ids)
+        return replace(self, output_ids=new_ids)
+
+    def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
+        return self
+

 @dataclass(frozen=True, eq=False)
 class FromRangeNode(BigFrameNode):
@@ -477,6 +552,7 @@ class FromRangeNode(BigFrameNode):
     start: BigFrameNode
     end: BigFrameNode
     step: int
+    output_id: bfet_ids.ColumnId = bfet_ids.ColumnId("labels")

     @property
     def roots(self) -> typing.Set[BigFrameNode]:
@@ -496,9 +572,7 @@ def explicitly_ordered(self) -> bool:

     @functools.cached_property
     def fields(self) -> Iterable[Field]:
-        return (
-            Field(bfet_ids.ColumnId("labels"), next(iter(self.start.fields)).dtype),
-        )
+        return (Field(self.output_id, next(iter(self.start.fields)).dtype),)

     @functools.cached_property
     def variables_introduced(self) -> int:
@@ -509,6 +583,14 @@ def variables_introduced(self) -> int:
     def row_count(self) -> Optional[int]:
         return None

+    @property
+    def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
+        return (self.output_id,)
+
+    @property
+    def defines_namespace(self) -> bool:
+        return True
+
     def transform_children(
         self, t: Callable[[BigFrameNode], BigFrameNode]
     ) -> BigFrameNode:
@@ -522,6 +604,14 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
         # TODO: Make FromRangeNode prunable (or convert to other node types)
         return self

+    def remap_vars(
+        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
+    ) -> BigFrameNode:
+        return replace(self, output_id=mappings.get(self.output_id, self.output_id))
+
+    def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
+        return self
+

 # Input Nodex
 # TODO: Most leaf nodes produce fixed column names based on the datasource
@@ -595,9 +685,16 @@ def explicitly_ordered(self) -> bool:
     def row_count(self) -> typing.Optional[int]:
         return self.n_rows

+    @property
+    def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
+        return tuple(item.id for item in self.scan_list.items)
+
     def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
+        # Don't produce empty scan list no matter what, will result in broken sql syntax
+        # TODO: Handle more elegantly
         new_scan_list = ScanList(
             tuple(item for item in self.scan_list.items if item.id in used_cols)
+            or (self.scan_list.items[0],)
         )
         return ReadLocalNode(
             self.feather_bytes,
@@ -607,6 +704,20 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
             self.session,
         )

+    def remap_vars(
+        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
+    ) -> BigFrameNode:
+        new_scan_list = ScanList(
+            tuple(
+                ScanItem(mappings.get(item.id, item.id), item.dtype, item.source_id)
+                for item in self.scan_list.items
+            )
+        )
+        return replace(self, scan_list=new_scan_list)
+
+    def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
+        return self
+

 @dataclass(frozen=True)
 class GbqTable:
@@ -663,7 +774,7 @@ class ReadTableNode(LeafNode):

     table_session: bigframes.session.Session = field()

-    def __post_init__(self):
+    def _validate(self):
         # enforce invariants
         physical_names = set(map(lambda i: i.name, self.source.table.physical_schema))
         if not set(scan.source_id for scan in self.scan_list.items).issubset(
@@ -728,11 +839,30 @@ def row_count(self) -> typing.Optional[int]:
             return self.source.table.n_rows
         return None

+    @property
+    def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
+        return tuple(item.id for item in self.scan_list.items)
+
     def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
         new_scan_list = ScanList(
             tuple(item for item in self.scan_list.items if item.id in used_cols)
+            or (self.scan_list.items[0],)
         )
-        return ReadTableNode(self.source, new_scan_list, self.table_session)
+        return replace(self, scan_list=new_scan_list)
+
+    def remap_vars(
+        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
+    ) -> BigFrameNode:
+        new_scan_list = ScanList(
+            tuple(
+                ScanItem(mappings.get(item.id, item.id), item.dtype, item.source_id)
+                for item in self.scan_list.items
+            )
+        )
+        return replace(self, scan_list=new_scan_list)
+
+    def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
+        return self


 @dataclass(frozen=True, eq=False)
@@ -741,14 +871,6 @@ class CachedTableNode(ReadTableNode):
     # note: this isn't a "child" node.
     original_node: BigFrameNode = field()

-    def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
-        new_scan_list = ScanList(
-            tuple(item for item in self.scan_list.items if item.id in used_cols)
-        )
-        return CachedTableNode(
-            self.source, new_scan_list, self.table_session, self.original_node
-        )
-

 # Unary nodes
 @dataclass(frozen=True, eq=False)
@@ -777,6 +899,10 @@ def variables_introduced(self) -> int:
     def row_count(self) -> Optional[int]:
         return self.child.row_count

+    @property
+    def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
+        return (self.col_id,)
+
     def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
         if self.col_id not in used_cols:
             return self.child.prune(used_cols)
@@ -784,6 +910,14 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
         new_used = used_cols.difference([self.col_id])
         return self.transform_children(lambda x: x.prune(new_used))

+    def remap_vars(
+        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
+    ) -> BigFrameNode:
+        return replace(self, col_id=mappings.get(self.col_id, self.col_id))
+
+    def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
+        return self
+

 @dataclass(frozen=True, eq=False)
 class FilterNode(UnaryNode):
@@ -801,11 +935,28 @@ def variables_introduced(self) -> int:
     def row_count(self) -> Optional[int]:
         return None

+    @property
+    def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
+        return ()
+
     def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
         consumed_ids = used_cols.union(self.predicate.column_references)
         pruned_child = self.child.prune(consumed_ids)
         return FilterNode(pruned_child, self.predicate)

+    def remap_vars(
+        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
+    ) -> BigFrameNode:
+        return self
+
+    def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
+        return replace(
+            self,
+            predicate=self.predicate.remap_column_refs(
+                mappings, allow_partial_bindings=True
+            ),
+        )
+

 @dataclass(frozen=True, eq=False)
 class OrderByNode(UnaryNode):
@@ -828,6 +979,10 @@ def explicitly_ordered(self) -> bool:
     def row_count(self) -> Optional[int]:
         return self.child.row_count

+    @property
+    def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
+        return ()
+
     def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
         ordering_cols = itertools.chain.from_iterable(
             map(lambda x: x.referenced_columns, self.by)
@@ -836,6 +991,25 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
         pruned_child = self.child.prune(consumed_ids)
         return OrderByNode(pruned_child, self.by)

+    def remap_vars(
+        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
+    ) -> BigFrameNode:
+        return self
+
+    def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
+        all_refs = set(
+            itertools.chain.from_iterable(map(lambda x: x.referenced_columns, self.by))
+        )
+        ref_mapping = {id: ex.DerefOp(mappings[id]) for id in all_refs}
+        new_by = cast(
+            tuple[OrderingExpression, ...],
+            tuple(
+                by_expr.bind_refs(ref_mapping, allow_partial_bindings=True)
+                for by_expr in self.by
+            ),
+        )
+        return replace(self, by=new_by)
+

 @dataclass(frozen=True, eq=False)
 class ReversedNode(UnaryNode):
@@ -855,6 +1029,18 @@ def relation_ops_created(self) -> int:
     def row_count(self) -> Optional[int]:
         return self.child.row_count

+    @property
+    def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
+        return ()
+
+    def remap_vars(
+        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
+    ) -> BigFrameNode:
+        return self
+
+    def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
+        return self
+

 @dataclass(frozen=True, eq=False)
 class SelectionNode(UnaryNode):
@@ -862,6 +1048,11 @@ class SelectionNode(UnaryNode):
         typing.Tuple[ex.DerefOp, bigframes.core.identifiers.ColumnId], ...
     ]

+    def _validate(self):
+        for ref, _ in self.input_output_pairs:
+            if ref.id not in set(self.child.ids):
+                raise ValueError(f"Reference to column not in child: {ref.id}")
+
     @functools.cached_property
     def fields(self) -> Iterable[Field]:
         return tuple(
@@ -885,15 +1076,37 @@ def defines_namespace(self) -> bool:
     def row_count(self) -> Optional[int]:
         return self.child.row_count

+    @property
+    def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
+        return tuple(id for _, id in self.input_output_pairs)
+
     def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
-        pruned_selections = tuple(
-            select for select in self.input_output_pairs if select[1] in used_cols
+        pruned_selections = (
+            tuple(
+                select for select in self.input_output_pairs if select[1] in used_cols
+            )
+            or self.input_output_pairs[:1]
         )
         consumed_ids = frozenset(i[0].id for i in pruned_selections)
         pruned_child = self.child.prune(consumed_ids)
         return SelectionNode(pruned_child, pruned_selections)

+    def remap_vars(
+        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
+    ) -> BigFrameNode:
+        new_pairs = tuple(
+            (ref, mappings.get(id, id)) for ref, id in self.input_output_pairs
+        )
+        return replace(self, input_output_pairs=new_pairs)
+
+    def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
+        new_fields = tuple(
+            (ex.remap_column_refs(mappings, allow_partial_bindings=True), id)
+            for ex, id in self.input_output_pairs
+        )
+        return replace(self, input_output_pairs=new_fields)  # type: ignore
+

 @dataclass(frozen=True, eq=False)
 class ProjectionNode(UnaryNode):
@@ -903,7 +1116,7 @@ class ProjectionNode(UnaryNode):
         typing.Tuple[ex.Expression, bigframes.core.identifiers.ColumnId], ...
     ]

-    def __post_init__(self):
+    def _validate(self):
         input_types = self.child._dtype_lookup
         for expression, id in self.assignments:
             # throws TypeError if invalid
@@ -933,6 +1146,10 @@ def variables_introduced(self) -> int:
     def row_count(self) -> Optional[int]:
         return self.child.row_count

+    @property
+    def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
+        return tuple(id for _, id in self.assignments)
+
     def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
         pruned_assignments = tuple(i for i in self.assignments if i[1] in used_cols)
         if len(pruned_assignments) == 0:
@@ -943,11 +1160,26 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
         pruned_child = self.child.prune(used_cols.union(consumed_ids))
         return ProjectionNode(pruned_child, pruned_assignments)

+    def remap_vars(
+        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
+    ) -> BigFrameNode:
+        new_fields = tuple((ex, mappings.get(id, id)) for ex, id in self.assignments)
+        return replace(self, assignments=new_fields)
+
+    def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
+        new_fields = tuple(
+            (ex.remap_column_refs(mappings, allow_partial_bindings=True), id)
+            for ex, id in self.assignments
+        )
+        return replace(self, assignments=new_fields)
+

 # TODO: Merge RowCount into Aggregate Node?
 # Row count can be compute from table metadata sometimes, so it is a bit special.
 @dataclass(frozen=True, eq=False)
 class RowCountNode(UnaryNode):
+    col_id: bfet_ids.ColumnId = bfet_ids.ColumnId("count")
+
     @property
     def row_preserving(self) -> bool:
         return False
@@ -958,7 +1190,7 @@ def non_local(self) -> bool:

     @property
     def fields(self) -> Iterable[Field]:
-        return (Field(bfet_ids.ColumnId("count"), bigframes.dtypes.INT_DTYPE),)
+        return (Field(self.col_id, bigframes.dtypes.INT_DTYPE),)

     @property
     def variables_introduced(self) -> int:
@@ -972,6 +1204,22 @@ def defines_namespace(self) -> bool:
     def row_count(self) -> Optional[int]:
         return 1

+    @property
+    def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
+        return (self.col_id,)
+
+    def remap_vars(
+        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
+    ) -> BigFrameNode:
+        return replace(self, col_id=mappings.get(self.col_id, self.col_id))
+
+    def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
+        return self
+
+    def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
+        # TODO: Handle row count pruning
+        return self
+

 @dataclass(frozen=True, eq=False)
 class AggregateNode(UnaryNode):
@@ -1017,19 +1265,22 @@ def order_ambiguous(self) -> bool:
     def explicitly_ordered(self) -> bool:
         return True

-    @property
-    def defines_namespace(self) -> bool:
-        return True
-
     @property
     def row_count(self) -> Optional[int]:
         if not self.by_column_ids:
             return 1
         return None

+    @property
+    def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
+        return tuple(id for _, id in self.aggregations)
+
     def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
         by_ids = (ref.id for ref in self.by_column_ids)
-        pruned_aggs = tuple(agg for agg in self.aggregations if agg[1] in used_cols)
+        pruned_aggs = (
+            tuple(agg for agg in self.aggregations if agg[1] in used_cols)
+            or self.aggregations[:1]
+        )
         agg_inputs = itertools.chain.from_iterable(
             agg.column_references for agg, _ in pruned_aggs
         )
@@ -1037,6 +1288,20 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
         pruned_child = self.child.prune(consumed_ids)
         return AggregateNode(pruned_child, pruned_aggs, self.by_column_ids, self.dropna)

+    def remap_vars(
+        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
+    ) -> BigFrameNode:
+        new_aggs = tuple((agg, mappings.get(id, id)) for agg, id in self.aggregations)
+        return replace(self, aggregations=new_aggs)
+
+    def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
+        new_aggs = tuple(
+            (agg.remap_column_refs(mappings, allow_partial_bindings=True), id)
+            for agg, id in self.aggregations
+        )
+        new_by_ids = tuple(id.remap_column_refs(mappings) for id in self.by_column_ids)
+        return replace(self, by_column_ids=new_by_ids, aggregations=new_aggs)
+

 @dataclass(frozen=True, eq=False)
 class WindowOpNode(UnaryNode):
@@ -1074,14 +1339,38 @@ def added_field(self) -> Field:
         new_item_dtype = self.op.output_type(input_type)
         return Field(self.output_name, new_item_dtype)

+    @property
+    def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
+        return (self.output_name,)
+
     def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
         if self.output_name not in used_cols:
-            return self.child
-        consumed_ids = used_cols.difference([self.output_name]).union(
-            [self.column_name.id]
+            return self.child.prune(used_cols)
+        consumed_ids = (
+            used_cols.difference([self.output_name])
+            .union([self.column_name.id])
+            .union(self.window_spec.all_referenced_columns)
         )
         return self.transform_children(lambda x: x.prune(consumed_ids))

+    def remap_vars(
+        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
+    ) -> BigFrameNode:
+        return replace(
+            self, output_name=mappings.get(self.output_name, self.output_name)
+        )
+
+    def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
+        return replace(
+            self,
+            column_name=self.column_name.remap_column_refs(
+                mappings, allow_partial_bindings=True
+            ),
+            window_spec=self.window_spec.remap_column_refs(
+                mappings, allow_partial_bindings=True
+            ),
+        )
+

 @dataclass(frozen=True, eq=False)
 class RandomSampleNode(UnaryNode):
@@ -1103,6 +1392,20 @@ def variables_introduced(self) -> int:
     def row_count(self) -> Optional[int]:
         return None

+    @property
+    def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
+        return ()
+
+    def remap_vars(
+        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
+    ) -> BigFrameNode:
+        return self
+
+    def remap_refs(
+        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
+    ) -> BigFrameNode:
+        return self
+

 # TODO: Explode should create a new column instead of overriding the existing one
 @dataclass(frozen=True, eq=False)
@@ -1135,16 +1438,26 @@ def relation_ops_created(self) -> int:
     def variables_introduced(self) -> int:
         return len(self.column_ids) + 1

-    @property
-    def defines_namespace(self) -> bool:
-        return True
-
     @property
     def row_count(self) -> Optional[int]:
         return None

+    @property
+    def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
+        return ()
+
     def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
         # Cannot prune explode op
-        return self.transform_children(
-            lambda x: x.prune(used_cols.union(ref.id for ref in self.column_ids))
-        )
+        consumed_ids = used_cols.union(ref.id for ref in self.column_ids)
+        return self.transform_children(lambda x: x.prune(consumed_ids))
+
+    def remap_vars(
+        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
+    ) -> BigFrameNode:
+        return self
+
+    def remap_refs(
+        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
+    ) -> BigFrameNode:
+        new_ids = tuple(id.remap_column_refs(mappings) for id in self.column_ids)
+        return replace(self, column_ids=new_ids)  # type: ignore
diff --git a/bigframes/core/ordering.py b/bigframes/core/ordering.py
index 8bba7d72b6..acfb2adb3f 100644
--- a/bigframes/core/ordering.py
+++ b/bigframes/core/ordering.py
@@ -222,6 +222,11 @@ def _truncate_ordering(
 class TotalOrdering(RowOrdering):
     """Immutable object that holds information about the ordering of rows in a ArrayValue object. Guaranteed to be unambiguous."""

+    def __post_init__(self):
+        assert set(ref.id for ref in self.total_ordering_columns).issubset(
+            self.referenced_columns
+        )
+
     # A table has a total ordering defined by the identities of a set of 1 or more columns.
     # These columns must always be part of the ordering, in order to guarantee that the ordering is total.
     # Therefore, any modifications(or drops) done to these columns must result in hidden copies being made.
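For context, a minimal sketch (not part of the diff) of how the serial ids added in the bigframes/core/identifiers.py hunk above behave, assuming those definitions:

```python
from bigframes.core import identifiers as ids

gen = ids.anonymous_serial_ids()
first, second = next(gen), next(gen)
# Each SerialColumnId renders as "<name>_<serial>", so ids stay short and dense.
assert (first.sql, second.sql) == ("uid_0", "uid_1")
# local_normalized collapses a serial id into a plain ColumnId keyed on its sql
# string, which is what the compiler uses when emitting column references.
assert first.local_normalized == ids.ColumnId(name="uid_0")
```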
diff --git a/bigframes/core/rewrite.py b/bigframes/core/rewrite.py
index 9c0eb81450..8187b16d87 100644
--- a/bigframes/core/rewrite.py
+++ b/bigframes/core/rewrite.py
@@ -16,7 +16,7 @@
 import dataclasses
 import functools
 import itertools
-from typing import cast, Mapping, Optional, Sequence, Tuple
+from typing import cast, Generator, Mapping, Optional, Sequence, Tuple

 import bigframes.core.expression as scalar_exprs
 import bigframes.core.guid as guids
@@ -578,3 +578,39 @@ def convert_complex_slice(
         )
         conditions.append(step_cond)
     return merge_predicates(conditions) or scalar_exprs.const(True)
+
+
+# TODO: May as well just outright remove selection nodes in this process.
+def remap_variables(
+    root: nodes.BigFrameNode, id_generator: Generator[ids.ColumnId, None, None]
+) -> Tuple[nodes.BigFrameNode, dict[ids.ColumnId, ids.ColumnId]]:
+    """
+    Remap all variables in the BFET using the id_generator.
+
+    Note: this will convert a DAG to a tree.
+    """
+    child_replacement_map = dict()
+    ref_mapping = dict()
+    # Sequential ids are assigned bottom-up left-to-right
+    for child in root.child_nodes:
+        new_child, child_var_mapping = remap_variables(child, id_generator=id_generator)
+        child_replacement_map[child] = new_child
+        ref_mapping.update(child_var_mapping)
+
+    # This is actually invalid until we've replaced all of children, refs and var defs
+    with_new_children = root.transform_children(
+        lambda node: child_replacement_map[node]
+    )
+
+    with_new_refs = with_new_children.remap_refs(ref_mapping)
+
+    node_var_mapping = {old_id: next(id_generator) for old_id in root.node_defined_ids}
+    with_new_vars = with_new_refs.remap_vars(node_var_mapping)
+    with_new_vars._validate()
+
+    return (
+        with_new_vars,
+        node_var_mapping
+        if root.defines_namespace
+        else (ref_mapping | node_var_mapping),
+    )
diff --git a/bigframes/core/window_spec.py b/bigframes/core/window_spec.py
index 2b9ff65084..d8098f18f7 100644
--- a/bigframes/core/window_spec.py
+++ b/bigframes/core/window_spec.py
@@ -15,7 +15,7 @@
 from dataclasses import dataclass
 import itertools
-from typing import Optional, Set, Tuple, Union
+from typing import Mapping, Optional, Set, Tuple, Union

 import bigframes.core.expression as ex
 import bigframes.core.identifiers as ids
@@ -180,3 +180,21 @@ def all_referenced_columns(self) -> Set[ids.ColumnId]:
             item.scalar_expression.column_references for item in self.ordering
         )
         return set(itertools.chain((i.id for i in self.grouping_keys), ordering_vars))
+
+    def remap_column_refs(
+        self,
+        mapping: Mapping[ids.ColumnId, ids.ColumnId],
+        allow_partial_bindings: bool = False,
+    ) -> WindowSpec:
+        return WindowSpec(
+            grouping_keys=tuple(
+                key.remap_column_refs(mapping, allow_partial_bindings)
+                for key in self.grouping_keys
+            ),
+            ordering=tuple(
+                order_part.remap_column_refs(mapping, allow_partial_bindings)
+                for order_part in self.ordering
+            ),
+            bounds=self.bounds,
+            min_periods=self.min_periods,
+        )
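For context, a minimal sketch (under the same definitions as the diff above) of the id-densification flow that `Compiler._preprocess` now runs when `enable_densify_ids` is set: remap every variable to a compact serial id bottom-up, then restore the user-visible names with a final `SelectionNode`. The helper name `densify_ids` is illustrative, not part of the change:

```python
import bigframes.core.expression as ex
import bigframes.core.identifiers as ids
import bigframes.core.nodes as nodes
import bigframes.core.rewrite as rewrites


def densify_ids(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
    # Remember the user-visible output names before rewriting.
    original_names = [id.name for id in node.ids]
    # Rewrite the tree so every variable is a SerialColumnId (uid_0, uid_1, ...),
    # assigned bottom-up left-to-right; this also converts a DAG into a tree.
    node, _ = rewrites.remap_variables(node, id_generator=ids.anonymous_serial_ids())
    # Re-attach the original names as the final output columns, mirroring
    # Compiler.set_output_names from the compiler.py hunk above.
    return nodes.SelectionNode(
        node,
        tuple(
            (ex.DerefOp(old_id), ids.ColumnId(name))
            for old_id, name in zip(node.ids, original_names)
        ),
    )
```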