Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

refactor: define sqlglot compiler for readlocalnode #1344

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
Loading
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions 18 bigframes/core/compile/sqlglot/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from bigframes.core.compile.sqlglot.compiler import SQLGlotCompiler

# Public API of this package: only the compiler class imported above is exported.
__all__ = ["SQLGlotCompiler"]
216 changes: 216 additions & 0 deletions 216 bigframes/core/compile/sqlglot/compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import dataclasses
import functools
import io
import itertools
import typing

import pandas as pd
import pyarrow as pa
import pyarrow.feather as feather
import sqlglot as sg
import sqlglot.expressions as sge

import bigframes.core
from bigframes.core import utils
import bigframes.core.compile.sqlglot.scalar_op_compiler as scalar_op_compiler
import bigframes.core.compile.sqlglot.sqlglot_types as sgt
import bigframes.core.expression as ex
import bigframes.core.guid as guid
import bigframes.core.identifiers as ids
import bigframes.core.nodes as nodes
import bigframes.core.ordering
import bigframes.core.rewrite
import bigframes.core.rewrite as rewrites
import bigframes.dtypes as dtypes
import bigframes.operations as ops
import bigframes.operations.aggregations as agg_ops


@dataclasses.dataclass(frozen=True)
class SQLGlotCompiler:
    """Compiles a BigFrameNode tree to a SQLGlot expression tree recursively."""

    # In strict mode, ordering will always be deterministic.
    # In unstrict mode, ordering from ReadTable or after joins may be ambiguous
    # to improve query performance.
    strict: bool = True
    # Whether to always quote identifiers.
    quoted: bool = True
    # TODO: the way the scalar operations are compiled blocks a non-recursive compiler.
    # Scalar compiler for converting bigframes expressions to sqlglot expressions.
    # Left un-annotated on purpose so that it is a plain class attribute shared by
    # all instances rather than a dataclass field.
    # NOTE(review): it is constructed with its own default `quoted` value, which is
    # independent of `self.quoted` -- confirm that this is intended.
    scalar_op_compiler = scalar_op_compiler.SQLGlotScalarOpCompiler()

    # TODO: add BigQuery Dialect
    def compile_sql(
        self,
        node: nodes.BigFrameNode,
        ordered: bool,
        limit: typing.Optional[int] = None,
    ) -> sg.Expression:
        """Rewrite ``node`` and compile it to a SQLGlot select expression.

        Args:
            node: Root of the BigFrameNode tree to compile.
            ordered: Whether the result must carry an explicit ORDER BY. When
                true, a limit implied by a trailing slice is also pulled up.
            limit: Optional maximum number of rows. If a smaller limit is pulled
                up from a slice, the smaller one wins.

        Returns:
            The SQLGlot expression for the rewritten tree, with ORDER BY and
            LIMIT applied when requested.
        """
        # TODO: restrict the final projection to the ids in `node.schema.names` as
        # they are at this point; later rewrite steps might add ids.
        if ordered:
            # Need to do this before replacing unsupported ops, as that will
            # rewrite slice ops.
            node, pulled_up_limit = rewrites.pullup_limit_from_slice(node)
            if (pulled_up_limit is not None) and (
                (limit is None) or limit > pulled_up_limit
            ):
                limit = pulled_up_limit

        node = self._replace_unsupported_ops(node)
        # Prune before pulling up order to avoid unnecessary row_number() ops.
        node = rewrites.column_pruning(node)
        node, ordering = rewrites.pull_up_order(node, order_root=ordered)
        # Final pruning to clean up any leftover unused values.
        node = rewrites.column_pruning(node)

        select_node = self.compile_node(node)

        # Only an ordered compile emits ORDER BY; an unordered one must not pay
        # for (or promise) a sort.
        if ordered:
            order_expr = self.compile_row_ordering(ordering)
            if order_expr is not None:
                select_node = select_node.order_by(order_expr)

        # The limit (explicit or pulled up from a slice) was previously computed
        # but silently dropped.
        if limit is not None:
            select_node = select_node.limit(limit)

        return select_node

    def _replace_unsupported_ops(self, node: nodes.BigFrameNode):
        """Apply the bottom-up rewrites for slices and timedelta expressions."""
        # TODO: Run all replacement rules as single bottom-up pass
        node = nodes.bottom_up(node, rewrites.rewrite_slice)
        node = nodes.bottom_up(node, rewrites.rewrite_timedelta_expressions)
        return node

    def compile_row_ordering(
        self, node: bigframes.core.ordering.RowOrdering
    ) -> typing.Optional[sge.Order]:
        """Build an ``ORDER BY`` expression from a row ordering.

        Returns:
            ``None`` when the ordering has no columns, otherwise an ``sge.Order``
            with one ``sge.Ordered`` term per ordering column.
        """
        if len(node.all_ordering_columns) == 0:
            return None

        ordering_expr = [
            sge.Ordered(
                this=sge.Column(
                    this=sge.to_identifier(
                        col_ref.scalar_expression.id.sql, quoted=self.quoted
                    )
                ),
                nulls_first=not col_ref.na_last,
                desc=not col_ref.direction.is_ascending,
            )
            for col_ref in node.all_ordering_columns
        ]
        return sge.Order(expressions=ordering_expr)

    @functools.singledispatchmethod
    def compile_node(self, node: nodes.BigFrameNode):
        """Compile ``node`` by dispatching on its concrete node type.

        This base implementation is only reached for node types that have no
        registered handler.

        Raises:
            ValueError: Always, because the node type is not recognized.
        """
        raise ValueError(f"Can't compile unrecognized node: {node}")

    def _compile_aliases(self, pairs) -> typing.List[sge.Alias]:
        """Compile ``(expression, output id)`` pairs into aliased SQL columns."""
        # Alias with `.sql`, the same attribute used wherever columns are
        # referenced (deref expressions and ordering columns), so that
        # references resolve to the aliases produced here.
        return [
            sge.Alias(
                this=self.scalar_op_compiler.compile_expression(expr),
                alias=sge.to_identifier(out_id.sql, quoted=self.quoted),
            )
            for expr, out_id in pairs
        ]

    @compile_node.register
    def compile_selection(self, node: nodes.SelectionNode):
        """Compile the child and replace its projection with the selected columns."""
        child = self.compile_node(node.child)
        return child.select(
            *self._compile_aliases(node.input_output_pairs), append=False
        )

    @compile_node.register
    def compile_projection(self, node: nodes.ProjectionNode):
        """Compile the child and append the newly assigned columns to its projection."""
        child = self.compile_node(node.child)
        return child.select(*self._compile_aliases(node.assignments), append=True)

    @compile_node.register
    def compile_readlocal(self, node: nodes.ReadLocalNode):
        """Compile local feather data into ``SELECT * FROM UNNEST(ARRAY<STRUCT<...>>[...])``."""
        array_as_pd = pd.read_feather(
            io.BytesIO(node.feather_bytes),
            columns=[item.source_id for item in node.scan_list.items],
        )
        scan_list_items = node.scan_list.items

        # In the ordered mode, add the offsets column containing the row index
        # (0 to N-1).
        if node.offsets_col:
            offsets_col_name = node.offsets_col.sql
            array_as_pd[offsets_col_name] = range(len(array_as_pd))
            scan_list_items = scan_list_items + (
                nodes.ScanItem(
                    ids.ColumnId(offsets_col_name), dtypes.INT_DTYPE, offsets_col_name
                ),
            )

        # Convert timedeltas to microseconds for compatibility with BigQuery.
        # The return value is not needed here.
        # NOTE(review): presumably the helper mutates the frame in place -- confirm.
        utils.replace_timedeltas_with_micros(array_as_pd)

        array_expr = sge.DataType(
            this=sge.DataType.Type.STRUCT,
            expressions=[
                sge.ColumnDef(
                    this=sge.to_identifier(item.source_id, quoted=self.quoted),
                    kind=sgt.SQLGlotType.from_bigframes_dtype(item.dtype),
                )
                for item in scan_list_items
            ],
            nested=True,
        )
        # `itertuples` keeps each column's own dtype. `iterrows` would build one
        # Series per row and upcast every value to a common dtype (for example
        # ints to floats), and is slower as well.
        array_values = [
            sge.Tuple(
                expressions=tuple(
                    self.literal(
                        value=value,
                        dtype=sgt.SQLGlotType.from_bigframes_dtype(item.dtype),
                    )
                    for value, item in zip(row, scan_list_items)
                )
            )
            for row in array_as_pd.itertuples(index=False, name=None)
        ]
        expr = sge.Unnest(
            expressions=[
                sge.DataType(
                    this=sge.DataType.Type.ARRAY,
                    expressions=[array_expr],
                    nested=True,
                    values=array_values,
                ),
            ],
        )
        return sg.select(sge.Star()).from_(expr)

    def cast(self, arg, to) -> sge.Cast:
        """Return ``CAST(arg AS to)``, converting ``arg`` to an expression first."""
        # `copy` is not an argument of the Cast node; passing it to the
        # constructor would store it as a bogus AST arg, so it is not forwarded.
        return sge.Cast(this=sge.convert(arg), to=to)

    def literal(self, value, dtype) -> sge.Expression:
        """Convert a Python value into a SQL literal.

        Missing values (``None`` and scalar pandas nulls such as ``NA``, ``NaT``
        and ``NaN``) become a ``NULL`` cast to ``dtype`` so the literal stays typed.
        """
        # `is_scalar` guards `isna`, which returns an array for list-like values.
        if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
            return self.cast(sge.Null(), dtype)

        # TODO: handle other types like visit_DefaultLiteral
        return sge.convert(value)
83 changes: 83 additions & 0 deletions 83 bigframes/core/compile/sqlglot/scalar_op_compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

import dataclasses
import functools
import io
import itertools
import typing
from typing import cast, Sequence, Tuple, TYPE_CHECKING

import pandas as pd
import pyarrow as pa
import pyarrow.feather as feather
import sqlglot as sg
import sqlglot.expressions as sge

import bigframes.core
import bigframes.core.compile.sqlglot.sqlglot_types as sgt
import bigframes.core.expression as ex
import bigframes.core.guid as guid
import bigframes.core.nodes as nodes
import bigframes.core.rewrite
import bigframes.dtypes as dtypes
import bigframes.operations as ops
import bigframes.operations.aggregations as agg_ops


@dataclasses.dataclass(frozen=True)
class SQLGlotScalarOpCompiler:
    """Scalar Op Compiler for converting BigFrames scalar op expressions to SQLGlot
    expressions."""

    # Whether to always quote identifiers.
    quoted: bool = True

    @functools.singledispatchmethod
    def compile_expression(self, expr: ex.Expression):
        """Compile ``expr`` by dispatching on its concrete expression type.

        Raises:
            NotImplementedError: If no handler is registered for the type.
        """
        raise NotImplementedError(f"Cannot compile expression: {expr}")

    @compile_expression.register
    def compile_deref_op(self, expr: ex.DerefOp):
        """Compile a column dereference into a column reference."""
        # A reference to an existing column is an `sge.Column`; `sge.ColumnDef`
        # is the node for declaring a column (name plus type) in a definition.
        return sge.Column(this=sge.to_identifier(expr.id.sql, quoted=self.quoted))

    @compile_expression.register
    def compile_op_expression(self, expr: ex.OpExpression):
        """Compile an operator expression via the matching ``compile_<OpClass>`` method.

        The inputs are compiled first (recursively); the handler is then looked
        up by the operator's class name and called with the operator followed
        by the compiled inputs.

        Raises:
            NotImplementedError: If there is no handler for the operator, or the
                operator is not a unary, binary, ternary or n-ary op.
        """
        op = expr.op
        # TODO: This recursion can block a non-recursive compiler.
        args = tuple(map(self.compile_expression, expr.inputs))

        method_name = f"compile_{type(op).__name__}"
        method = getattr(self, method_name, None)
        if method is None:
            raise NotImplementedError(f"Cannot compile operator {method_name}")

        # Pass exactly as many compiled inputs as the operator's arity.
        if isinstance(op, ops.UnaryOp):
            return method(op, args[0])
        if isinstance(op, ops.BinaryOp):
            return method(op, args[0], args[1])
        if isinstance(op, ops.TernaryOp):
            return method(op, args[0], args[1], args[2])
        if isinstance(op, ops.NaryOp):
            return method(op, *args)
        raise NotImplementedError(f"Cannot compile operator {method_name}")

    # TODO: add parenthesize for operators
    def compile_AddOp(self, op: ops.AddOp, left: sge.Expression, right: sge.Expression):
        """Compile an addition of two already-compiled operands."""
        return sge.Add(this=left, expression=right)
Loading
Loading
Morty Proxy This is a proxified and sanitized view of the page, visit original site.