Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

refactor: Simplify projection nodes #961

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 23 additions & 54 deletions 77 bigframes/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,49 +192,38 @@ def concat(self, other: typing.Sequence[ArrayValue]) -> ArrayValue:
)

def project_to_id(self, expression: ex.Expression, output_id: str):
if output_id in self.column_ids: # Mutate case
exprs = [
((expression if (col_id == output_id) else ex.free_var(col_id)), col_id)
for col_id in self.column_ids
]
else: # append case
self_projection = (
(ex.free_var(col_id), col_id) for col_id in self.column_ids
)
exprs = [*self_projection, (expression, output_id)]
return ArrayValue(
nodes.ProjectionNode(
child=self.node,
assignments=tuple(exprs),
assignments=(
(
expression,
output_id,
),
),
)
)

def assign(self, source_id: str, destination_id: str) -> ArrayValue:
if destination_id in self.column_ids: # Mutate case
exprs = [
(
(
ex.free_var(source_id)
if (col_id == destination_id)
else ex.free_var(col_id)
),
(source_id if (col_id == destination_id) else col_id),
col_id,
)
for col_id in self.column_ids
]
else: # append case
self_projection = (
(ex.free_var(col_id), col_id) for col_id in self.column_ids
)
exprs = [*self_projection, (ex.free_var(source_id), destination_id)]
self_projection = ((col_id, col_id) for col_id in self.column_ids)
exprs = [*self_projection, (source_id, destination_id)]
return ArrayValue(
nodes.ProjectionNode(
nodes.SelectionNode(
child=self.node,
assignments=tuple(exprs),
input_output_pairs=tuple(exprs),
)
)

def assign_constant(
def create_constant(
self,
destination_id: str,
value: typing.Any,
Expand All @@ -244,49 +233,31 @@ def assign_constant(
# Need to assign a data type when value is NaN.
dtype = dtype or bigframes.dtypes.DEFAULT_DTYPE

if destination_id in self.column_ids: # Mutate case
exprs = [
(
(
ex.const(value, dtype)
if (col_id == destination_id)
else ex.free_var(col_id)
),
col_id,
)
for col_id in self.column_ids
]
else: # append case
self_projection = (
(ex.free_var(col_id), col_id) for col_id in self.column_ids
)
exprs = [*self_projection, (ex.const(value, dtype), destination_id)]
return ArrayValue(
nodes.ProjectionNode(
child=self.node,
assignments=tuple(exprs),
assignments=((ex.const(value, dtype), destination_id),),
)
)

def select_columns(self, column_ids: typing.Sequence[str]) -> ArrayValue:
selections = ((ex.free_var(col_id), col_id) for col_id in column_ids)
# This basically just drops and reorders columns - logically a no-op except as a final step
selections = ((col_id, col_id) for col_id in column_ids)
return ArrayValue(
nodes.ProjectionNode(
nodes.SelectionNode(
child=self.node,
assignments=tuple(selections),
input_output_pairs=tuple(selections),
)
)

def drop_columns(self, columns: Iterable[str]) -> ArrayValue:
new_projection = (
(ex.free_var(col_id), col_id)
for col_id in self.column_ids
if col_id not in columns
(col_id, col_id) for col_id in self.column_ids if col_id not in columns
)
return ArrayValue(
nodes.ProjectionNode(
nodes.SelectionNode(
child=self.node,
assignments=tuple(new_projection),
input_output_pairs=tuple(new_projection),
)
)

Expand Down Expand Up @@ -422,15 +393,13 @@ def unpivot(
col_expr = ops.case_when_op.as_expr(*cases)
unpivot_exprs.append((col_expr, col_id))

label_exprs = ((ex.free_var(id), id) for id in index_col_ids)
# passthrough columns are unchanged, just repeated N times each
passthrough_exprs = ((ex.free_var(id), id) for id in passthrough_columns)
unpivot_col_ids = [id for id, _ in unpivot_columns]
return ArrayValue(
nodes.ProjectionNode(
child=joined_array.node,
assignments=(*label_exprs, *unpivot_exprs, *passthrough_exprs),
assignments=(*unpivot_exprs,),
)
)
).select_columns([*index_col_ids, *unpivot_col_ids, *passthrough_columns])

def _cross_join_w_labels(
self, labels_array: ArrayValue, join_side: typing.Literal["left", "right"]
Expand Down
25 changes: 17 additions & 8 deletions 25 bigframes/core/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,7 +939,7 @@ def multi_apply_unary_op(
for col_id in columns:
label = self.col_id_to_label[col_id]
block, result_id = block.project_expr(
expr.bind_all_variables({input_varname: ex.free_var(col_id)}),
expr.bind_variables({input_varname: ex.free_var(col_id)}),
label=label,
)
block = block.copy_values(result_id, col_id)
Expand Down Expand Up @@ -1006,7 +1006,7 @@ def create_constant(
dtype: typing.Optional[bigframes.dtypes.Dtype] = None,
) -> typing.Tuple[Block, str]:
result_id = guid.generate_guid()
expr = self.expr.assign_constant(result_id, scalar_constant, dtype=dtype)
expr = self.expr.create_constant(result_id, scalar_constant, dtype=dtype)
# Create index copy with label inserted
# See: https://pandas.pydata.org/docs/reference/api/pandas.Index.insert.html
labels = self.column_labels.insert(len(self.column_labels), label)
Expand Down Expand Up @@ -1067,7 +1067,7 @@ def aggregate_all_and_stack(
index_id = guid.generate_guid()
result_expr = self.expr.aggregate(
aggregations, dropna=dropna
).assign_constant(index_id, None, None)
).create_constant(index_id, None, None)
# Transpose as last operation so that final block has valid transpose cache
return Block(
result_expr,
Expand Down Expand Up @@ -1222,7 +1222,7 @@ def aggregate(
names: typing.List[Label] = []
if len(by_column_ids) == 0:
label_id = guid.generate_guid()
result_expr = result_expr.assign_constant(label_id, 0, pd.Int64Dtype())
result_expr = result_expr.create_constant(label_id, 0, pd.Int64Dtype())
index_columns = (label_id,)
names = [None]
else:
Expand Down Expand Up @@ -1614,17 +1614,22 @@ def add_prefix(self, prefix: str, axis: str | int | None = None) -> Block:
axis_number = utils.get_axis_number("rows" if (axis is None) else axis)
if axis_number == 0:
expr = self._expr
new_index_cols = []
for index_col in self._index_columns:
new_col = guid.generate_guid()
expr = expr.project_to_id(
expression=ops.add_op.as_expr(
ex.const(prefix),
ops.AsTypeOp(to_type="string").as_expr(index_col),
),
output_id=index_col,
output_id=new_col,
)
new_index_cols.append(new_col)
expr = expr.select_columns((*new_index_cols, *self.value_columns))

return Block(
expr,
index_columns=self.index_columns,
index_columns=new_index_cols,
column_labels=self.column_labels,
index_labels=self.index.names,
)
Expand All @@ -1635,17 +1640,21 @@ def add_suffix(self, suffix: str, axis: str | int | None = None) -> Block:
axis_number = utils.get_axis_number("rows" if (axis is None) else axis)
if axis_number == 0:
expr = self._expr
new_index_cols = []
for index_col in self._index_columns:
new_col = guid.generate_guid()
expr = expr.project_to_id(
expression=ops.add_op.as_expr(
ops.AsTypeOp(to_type="string").as_expr(index_col),
ex.const(suffix),
),
output_id=index_col,
output_id=new_col,
)
new_index_cols.append(new_col)
expr = expr.select_columns((*new_index_cols, *self.value_columns))
return Block(
expr,
index_columns=self.index_columns,
index_columns=new_index_cols,
column_labels=self.column_labels,
index_labels=self.index.names,
)
Expand Down
15 changes: 14 additions & 1 deletion 15 bigframes/core/compile/compiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,23 @@ def projection(
) -> T:
"""Apply an expression to the ArrayValue and assign the output to a column."""
bindings = {col: self._get_ibis_column(col) for col in self.column_ids}
values = [
new_values = [
op_compiler.compile_expression(expression, bindings).name(id)
for expression, id in expression_id_pairs
]
result = self._select(tuple([*self._columns, *new_values])) # type: ignore
return result

def selection(
    self: T,
    input_output_pairs: typing.Tuple[typing.Tuple[str, str], ...],
) -> T:
    """Select existing columns and emit each one under its output id.

    Unlike ``projection``, no new expressions are evaluated here: every pair
    is compiled as a plain reference to an existing column.

    Args:
        input_output_pairs: ``(input_id, output_id)`` pairs. Each ``input_id``
            is resolved against the current columns and the result is named
            ``output_id``.

    Returns:
        The result of ``_select`` applied to exactly the listed columns, in
        the order the pairs are given.
    """
    bindings = {col: self._get_ibis_column(col) for col in self.column_ids}
    # Local names deliberately avoid shadowing the builtins ``input`` and ``id``.
    values = [
        op_compiler.compile_expression(ex.free_var(source_id), bindings).name(
            target_id
        )
        for source_id, target_id in input_output_pairs
    ]
    result = self._select(tuple(values))  # type: ignore
    return result

Expand Down
5 changes: 5 additions & 0 deletions 5 bigframes/core/compile/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,11 @@ def compile_reversed(self, node: nodes.ReversedNode, ordered: bool = True):
else:
return self.compile_unordered_ir(node.child)

@_compile_node.register
def compile_selection(self, node: nodes.SelectionNode, ordered: bool = True):
    """Compile the child of a SelectionNode, then apply the node's selection."""
    compiled_child = self.compile_node(node.child, ordered)
    return compiled_child.selection(node.input_output_pairs)

@_compile_node.register
def compile_projection(self, node: nodes.ProjectionNode, ordered: bool = True):
result = self.compile_node(node.child, ordered)
Expand Down
29 changes: 22 additions & 7 deletions 29 bigframes/core/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,13 @@ def output_type(
...

@abc.abstractmethod
def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
"""Replace all variables with expression given in `bindings`."""
def bind_variables(
self, bindings: Mapping[str, Expression], check_bind_all: bool = True
) -> Expression:
"""Replace variables with expression given in `bindings`.

If check_bind_all is True, validate that all free variables are bound to a new value.
"""
...

@property
Expand Down Expand Up @@ -141,7 +146,9 @@ def output_type(
) -> dtypes.ExpressionType:
return self.dtype

def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
def bind_variables(
self, bindings: Mapping[str, Expression], check_bind_all: bool = True
) -> Expression:
return self

@property
Expand Down Expand Up @@ -178,11 +185,14 @@ def output_type(
else:
raise ValueError(f"Type of variable {self.id} has not been fixed.")

def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
def bind_variables(
self, bindings: Mapping[str, Expression], check_bind_all: bool = True
) -> Expression:
if self.id in bindings.keys():
return bindings[self.id]
else:
elif check_bind_all:
raise ValueError(f"Variable {self.id} remains unbound")
return self

@property
def is_bijective(self) -> bool:
Expand Down Expand Up @@ -225,10 +235,15 @@ def output_type(
)
return self.op.output_type(*operand_types)

def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
def bind_variables(
self, bindings: Mapping[str, Expression], check_bind_all: bool = True
) -> Expression:
return OpExpression(
self.op,
tuple(input.bind_all_variables(bindings) for input in self.inputs),
tuple(
input.bind_variables(bindings, check_bind_all=check_bind_all)
for input in self.inputs
),
)

@property
Expand Down
31 changes: 30 additions & 1 deletion 31 bigframes/core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,15 +622,41 @@ def relation_ops_created(self) -> int:
return 0


@dataclass(frozen=True)
class SelectionNode(UnaryNode):
    """Unary node described by ``(input id, output id)`` column pairs."""

    # Each pair is (column id in the child, id of the resulting column).
    input_output_pairs: typing.Tuple[typing.Tuple[str, str], ...]

    def __hash__(self):
        return self._node_hash

    @functools.cached_property
    def schema(self) -> schemata.ArraySchema:
        # Every output column takes the type recorded for its input column
        # in the child's schema.
        child_types = self.child.schema._mapping
        return schemata.ArraySchema(
            tuple(
                schemata.SchemaItem(output_id, child_types[input_id])
                for input_id, output_id in self.input_output_pairs
            )
        )

    @property
    def variables_introduced(self) -> int:
        # This operation only renames variables, doesn't actually create new ones
        return 0


@dataclass(frozen=True)
class ProjectionNode(UnaryNode):
"""Assigns new variables (without modifying existing ones)"""

assignments: typing.Tuple[typing.Tuple[ex.Expression, str], ...]

def __post_init__(self):
input_types = self.child.schema._mapping
for expression, id in self.assignments:
# throws TypeError if invalid
_ = expression.output_type(input_types)
# Cannot assign to existing variables - append only!
assert all(name not in self.child.schema.names for _, name in self.assignments)

def __hash__(self):
return self._node_hash
Expand All @@ -644,7 +670,10 @@ def schema(self) -> schemata.ArraySchema:
)
for ex, id in self.assignments
)
return schemata.ArraySchema(items)
schema = self.child.schema
for item in items:
schema = schema.append(item)
return schema

@property
def variables_introduced(self) -> int:
Expand Down
2 changes: 1 addition & 1 deletion 2 bigframes/core/ordering.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def bind_variables(
self, mapping: Mapping[str, expression.Expression]
) -> OrderingExpression:
return OrderingExpression(
self.scalar_expression.bind_all_variables(mapping),
self.scalar_expression.bind_variables(mapping),
self.direction,
self.na_last,
)
Expand Down
Loading
Morty Proxy This is a proxified and sanitized view of the page, visit original site.