refactor: Simplify join node definition #966
Changes from all commits: 31164aa, 430d0ac, ded41b7, 3ccd815, 2d4e7cf, 94508d5, 1eb2e09, 9f459dc
```diff
@@ -17,9 +17,8 @@
 import datetime
 import functools
 import io
-import itertools
 import typing
-from typing import Iterable, Optional, Sequence
+from typing import Iterable, Optional, Sequence, Tuple
 import warnings

 import google.cloud.bigquery
```
```diff
@@ -191,19 +190,14 @@ def concat(self, other: typing.Sequence[ArrayValue]) -> ArrayValue:
             nodes.ConcatNode(children=tuple([self.node, *[val.node for val in other]]))
         )

-    def project_to_id(self, expression: ex.Expression, output_id: str):
+    def compute_values(self, assignments: Sequence[Tuple[ex.Expression, str]]):
         return ArrayValue(
-            nodes.ProjectionNode(
-                child=self.node,
-                assignments=(
-                    (
-                        expression,
-                        output_id,
-                    ),
-                ),
-            )
+            nodes.ProjectionNode(child=self.node, assignments=tuple(assignments))
         )

+    def project_to_id(self, expression: ex.Expression, output_id: str):
+        return self.compute_values(((expression, output_id),))
+
     def assign(self, source_id: str, destination_id: str) -> ArrayValue:
         if destination_id in self.column_ids:  # Mutate case
             exprs = [
```
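For context, the new compute_values batches several (expression, output id) assignments into one ProjectionNode, and project_to_id becomes the batch-of-one case. A standalone sketch of that delegation pattern (plain Python, not the bigframes API; the Projection and Assignment names are illustrative stand-ins):

```python
from typing import Sequence, Tuple

# Stand-in for the (ex.Expression, str) assignment pairs used in the diff above.
Assignment = Tuple[str, str]  # (expression_text, output_id)


class Projection:
    """Toy analogue of ArrayValue's projection-building methods."""

    def __init__(self, assignments: Tuple[Assignment, ...] = ()) -> None:
        self.assignments = assignments

    def compute_values(self, assignments: Sequence[Assignment]) -> "Projection":
        # One node carries the whole batch, as in ProjectionNode(assignments=...).
        return Projection(tuple(assignments))

    def project_to_id(self, expression: str, output_id: str) -> "Projection":
        # A single-column projection is just a batch of size one.
        return self.compute_values(((expression, output_id),))


assert Projection().project_to_id("a + b", "c").assignments == (("a + b", "c"),)
```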
```diff
@@ -341,124 +335,33 @@ def _reproject_to_table(self) -> ArrayValue:
             )
         )

-    def unpivot(
-        self,
-        row_labels: typing.Sequence[typing.Hashable],
-        unpivot_columns: typing.Sequence[
-            typing.Tuple[str, typing.Tuple[typing.Optional[str], ...]]
-        ],
-        *,
-        passthrough_columns: typing.Sequence[str] = (),
-        index_col_ids: typing.Sequence[str] = ["index"],
-        join_side: typing.Literal["left", "right"] = "left",
-    ) -> ArrayValue:
-        """
-        Unpivot ArrayValue columns.
-
-        Args:
-            row_labels: Identifies the source of the row. Must be equal in length to the source column lists in the unpivot_columns argument.
-            unpivot_columns: Mapping of column id to list of input column ids. Lists of input columns may use None.
-            passthrough_columns: Columns that will not be unpivoted. Column id will be preserved.
-            index_col_ids: The column ids to be used for the row labels.
-
-        Returns:
-            ArrayValue: The unpivoted ArrayValue
-        """
-        # There will be N labels, used to disambiguate which of N source columns produced each output row
-        explode_offsets_id = bigframes.core.guid.generate_guid("unpivot_offsets_")
-        labels_array = self._create_unpivot_labels_array(
-            row_labels, index_col_ids, explode_offsets_id
-        )
-
-        # Unpivot creates N output rows for each input row, labels disambiguate these N rows
-        joined_array = self._cross_join_w_labels(labels_array, join_side)
-
-        # Build the output rows as a case statement that selects between the N input columns
-        unpivot_exprs = []
-        # Supports producing multiple stacked output columns for stacking only part of a hierarchical index
-        for col_id, input_ids in unpivot_columns:
-            # row explode offset used to choose the input column
-            # we use offset instead of label as labels are not necessarily unique
-            cases = itertools.chain(
-                *(
-                    (
-                        ops.eq_op.as_expr(explode_offsets_id, ex.const(i)),
-                        ex.free_var(id_or_null)
-                        if (id_or_null is not None)
-                        else ex.const(None),
-                    )
-                    for i, id_or_null in enumerate(input_ids)
-                )
-            )
-            col_expr = ops.case_when_op.as_expr(*cases)
-            unpivot_exprs.append((col_expr, col_id))
-
-        unpivot_col_ids = [id for id, _ in unpivot_columns]
-        return ArrayValue(
-            nodes.ProjectionNode(
-                child=joined_array.node,
-                assignments=(*unpivot_exprs,),
-            )
-        ).select_columns([*index_col_ids, *unpivot_col_ids, *passthrough_columns])
-
-    def _cross_join_w_labels(
-        self, labels_array: ArrayValue, join_side: typing.Literal["left", "right"]
-    ) -> ArrayValue:
-        """
-        Convert each row in self to N rows, one for each label in labels array.
-        """
-        table_join_side = (
-            join_def.JoinSide.LEFT if join_side == "left" else join_def.JoinSide.RIGHT
-        )
-        labels_join_side = table_join_side.inverse()
-        labels_mappings = tuple(
-            join_def.JoinColumnMapping(labels_join_side, id, id)
-            for id in labels_array.schema.names
-        )
-        table_mappings = tuple(
-            join_def.JoinColumnMapping(table_join_side, id, id)
-            for id in self.schema.names
-        )
-        join = join_def.JoinDefinition(
-            conditions=(), mappings=(*labels_mappings, *table_mappings), type="cross"
-        )
-        if join_side == "left":
-            joined_array = self.relational_join(labels_array, join_def=join)
-        else:
-            joined_array = labels_array.relational_join(self, join_def=join)
-        return joined_array
-
-    def _create_unpivot_labels_array(
-        self,
-        former_column_labels: typing.Sequence[typing.Hashable],
-        col_ids: typing.Sequence[str],
-        offsets_id: str,
-    ) -> ArrayValue:
-        """Create an ArrayValue from a list of label tuples."""
-        rows = []
-        for row_offset in range(len(former_column_labels)):
-            row_label = former_column_labels[row_offset]
-            row_label = (row_label,) if not isinstance(row_label, tuple) else row_label
-            row = {
-                col_ids[i]: (row_label[i] if pandas.notnull(row_label[i]) else None)
-                for i in range(len(col_ids))
-            }
-            row[offsets_id] = row_offset
-            rows.append(row)
-
-        return ArrayValue.from_pyarrow(pa.Table.from_pylist(rows), session=self.session)
-
```
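For readers unfamiliar with the deleted code path: unpivot was implemented as a cross join against a small labels relation plus a per-offset selection (the CASE expression keyed on explode_offsets_id, using offsets because labels need not be unique). A rough pure-Python illustration of that technique, using toy data and illustrative names rather than the bigframes API:

```python
rows = [{"x": 1, "y": 2}, {"x": 3, "y": 4}]
labels = ["x", "y"]                        # row_labels; offsets 0..N-1 disambiguate
unpivot_columns = [("value", ("x", "y"))]  # output column <- candidate input ids

out = []
for row in rows:                           # cross join: N output rows per input row
    for offset, label in enumerate(labels):
        rec = {"index": label}             # the index column carries the label
        for out_col, input_ids in unpivot_columns:
            src = input_ids[offset]        # CASE offset WHEN i THEN input_ids[i]
            rec[out_col] = row[src] if src is not None else None
        out.append(rec)

assert out == [
    {"index": "x", "value": 1},
    {"index": "y", "value": 2},
    {"index": "x", "value": 3},
    {"index": "y", "value": 4},
]
```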
The same hunk then rewrites relational_join to take explicit conditions and a join type instead of a JoinDefinition, and to return column-id mappings to the caller:

```diff
     def relational_join(
         self,
         other: ArrayValue,
-        join_def: join_def.JoinDefinition,
-    ) -> ArrayValue:
+        conditions: typing.Tuple[typing.Tuple[str, str], ...] = (),
+        type: typing.Literal["inner", "outer", "left", "right", "cross"] = "inner",
+    ) -> typing.Tuple[ArrayValue, typing.Tuple[dict[str, str], dict[str, str]]]:
         join_node = nodes.JoinNode(
             left_child=self.node,
             right_child=other.node,
-            join=join_def,
+            conditions=conditions,
+            type=type,
         )
-        return ArrayValue(join_node)
+        # Maps input ids to output ids for caller convenience
+        l_size = len(self.node.schema)
+        l_mapping = {
+            lcol: ocol
+            for lcol, ocol in zip(
+                self.node.schema.names, join_node.schema.names[:l_size]
+            )
+        }
+        r_mapping = {
+            rcol: ocol
+            for rcol, ocol in zip(
+                other.node.schema.names, join_node.schema.names[l_size:]
+            )
+        }
+        return ArrayValue(join_node), (l_mapping, r_mapping)

     def try_align_as_projection(
         self,
```

Review thread on the l_mapping construction:

Reviewer: I'm curious what the purpose of these mappings is? Could you give more explanation in a docstring, please? A guess: is it so we don't actually have to explicitly rename the columns in the SQL compilation step? If so, would it be better to switch to some offset-based logic now instead of mapping strings?

Author: Callers used to provide the input_id -> output_id mapping themselves through the join_def. I'm slowly taking power away from callers to provide the internal ids, so instead of accepting mappings from the caller, this method now provides them to callers. I do want to eventually move to entirely offset-based column addressing, but it's a multi-step process.
Review thread on the conditions parameter:

Reviewer: Wikipedia calls these predicates, or more specifically "join predicates". That said, I do see Google SQL calls these join conditions. Note: we will eventually want to support more than just equality, such as geospatial join predicates (https://carto.com/blog/guide-to-spatial-joins-and-predicates-with-sql), so Tuple doesn't seem like the right type.

Author: Wikipedia uses the term "condition" plenty as well; it seems to be an accepted term. As for spatial predicates, can we leave those for later? I'm not sure yet how I would want to represent those. I'm sure we will have one or two more refactors by then as we move towards offset-based indexing.
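As a concrete reading of the type under discussion: each entry in conditions pairs a left column id with a right column id, is interpreted as an equality predicate, and all entries are ANDed together; a geospatial predicate would not fit this shape, which is the reviewer's point. A pure-Python illustration of that semantics (not the bigframes evaluation path):

```python
from typing import Tuple

# The current encoding: a tuple of (left_col, right_col) equality pairs.
Conditions = Tuple[Tuple[str, str], ...]


def row_matches(left: dict, right: dict, conditions: Conditions) -> bool:
    # Every (lcol, rcol) pair must match for the joined row to be emitted.
    return all(left[lcol] == right[rcol] for lcol, rcol in conditions)


assert row_matches({"a": 1}, {"b": 1}, (("a", "b"),))
assert not row_matches({"a": 1}, {"b": 2}, (("a", "b"),))
```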