diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4fd6488c9c..2d11c951a1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,4 +39,5 @@ repos: hooks: - id: mypy additional_dependencies: [types-requests, types-tabulate, pandas-stubs] - args: ["--check-untyped-defs", "--explicit-package-bases", '--exclude="^third_party"', "--ignore-missing-imports"] + exclude: "^third_party" + args: ["--check-untyped-defs", "--explicit-package-bases", "--ignore-missing-imports"] diff --git a/CHANGELOG.md b/CHANGELOG.md index f3dae5af71..a989d8af66 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,31 @@ [1]: https://pypi.org/project/bigframes/#history +## [1.17.0](https://github.com/googleapis/python-bigquery-dataframes/compare/v1.16.0...v1.17.0) (2024-09-11) + + +### Features + +* Add `__version__` alias to bigframes.pandas ([#967](https://github.com/googleapis/python-bigquery-dataframes/issues/967)) ([9ce10b4](https://github.com/googleapis/python-bigquery-dataframes/commit/9ce10b4248f106ac9e09fc0fe686cece86827337)) +* Add Gemini 1.5 stable models support ([#945](https://github.com/googleapis/python-bigquery-dataframes/issues/945)) ([c1cde19](https://github.com/googleapis/python-bigquery-dataframes/commit/c1cde19769c169b962b58b25f0be61c8c41edb95)) +* Allow setting table labels in `to_gbq` ([#941](https://github.com/googleapis/python-bigquery-dataframes/issues/941)) ([cccc6ca](https://github.com/googleapis/python-bigquery-dataframes/commit/cccc6ca8c1271097bbe15e3d9ccdcfd7c633227a)) +* Define list accessor for bigframes Series ([#946](https://github.com/googleapis/python-bigquery-dataframes/issues/946)) ([8e8279d](https://github.com/googleapis/python-bigquery-dataframes/commit/8e8279d4da90feb5766f266b49cb417f8cbec6c9)) +* Enable read_csv() to process other files ([#940](https://github.com/googleapis/python-bigquery-dataframes/issues/940)) ([3b35860](https://github.com/googleapis/python-bigquery-dataframes/commit/3b35860776033fc8e71e471422c6d2b9366a7c9f)) +* Include the bigframes package version alongside the feedback link in error messages ([#936](https://github.com/googleapis/python-bigquery-dataframes/issues/936)) ([7b59b6d](https://github.com/googleapis/python-bigquery-dataframes/commit/7b59b6dc6f0cedfee713b5b273d46fa84b70bfa4)) + + +### Bug Fixes + +* Astype Decimal to Int64 conversion. 
([#957](https://github.com/googleapis/python-bigquery-dataframes/issues/957)) ([27764a6](https://github.com/googleapis/python-bigquery-dataframes/commit/27764a64f90092374458fafbe393bc6c30c85681)) +* Make `read_gbq_function` work for multi-param functions ([#947](https://github.com/googleapis/python-bigquery-dataframes/issues/947)) ([c750be6](https://github.com/googleapis/python-bigquery-dataframes/commit/c750be6093941677572a10c36a92984e954de32c)) +* Support `read_gbq_function` for axis=1 application ([#950](https://github.com/googleapis/python-bigquery-dataframes/issues/950)) ([86e54b1](https://github.com/googleapis/python-bigquery-dataframes/commit/86e54b13d2b91517b1df2d9c1f852a8e1925309a)) + + +### Documentation + +* Add docstring returns section to Options ([#937](https://github.com/googleapis/python-bigquery-dataframes/issues/937)) ([a2640a2](https://github.com/googleapis/python-bigquery-dataframes/commit/a2640a2d731c8d0aba1307311092f5e85b8ba077)) +* Update title of pypi notebook example to reflect use of the PyPI public dataset ([#952](https://github.com/googleapis/python-bigquery-dataframes/issues/952)) ([cd62e60](https://github.com/googleapis/python-bigquery-dataframes/commit/cd62e604967adac0c2f8600408bd9ce7886f2f98)) + ## [1.16.0](https://github.com/googleapis/python-bigquery-dataframes/compare/v1.15.0...v1.16.0) (2024-09-04) diff --git a/bigframes/_config/__init__.py b/bigframes/_config/__init__.py index c9b2a3f95a..ac58c19fa5 100644 --- a/bigframes/_config/__init__.py +++ b/bigframes/_config/__init__.py @@ -73,7 +73,12 @@ def _init_bigquery_thread_local(self): @property def bigquery(self) -> bigquery_options.BigQueryOptions: - """Options to use with the BigQuery engine.""" + """Options to use with the BigQuery engine. + + Returns: + bigframes._config.bigquery_options.BigQueryOptions: + Options for BigQuery engine. + """ if self._local.bigquery_options is not None: # The only way we can get here is if someone called # _init_bigquery_thread_local. @@ -83,7 +88,12 @@ def bigquery(self) -> bigquery_options.BigQueryOptions: @property def display(self) -> display_options.DisplayOptions: - """Options controlling object representation.""" + """Options controlling object representation. + + Returns: + bigframes._config.display_options.DisplayOptions: + Options for controlling object representation. + """ return self._local.display_options @property @@ -95,12 +105,21 @@ def sampling(self) -> sampling_options.SamplingOptions: (e.g., to_pandas, to_numpy, values) or implicitly (e.g., matplotlib plotting). This option can be overriden by parameters in specific functions. + + Returns: + bigframes._config.sampling_options.SamplingOptions: + Options for controlling downsampling. """ return self._local.sampling_options @property def compute(self) -> compute_options.ComputeOptions: - """Thread-local options controlling object computation.""" + """Thread-local options controlling object computation. + + Returns: + bigframes._config.compute_options.ComputeOptions: + Thread-local options for controlling object computation + """ return self._local.compute_options @property @@ -109,6 +128,11 @@ def is_bigquery_thread_local(self) -> bool: A thread-local session can be started by using `with bigframes.option_context("bigquery.some_option", "some-value"):`. + + Returns: + bool: + A boolean value, where a value is True if a thread-local session + is in use; otherwise False. 
""" return self._local.bigquery_options is not None diff --git a/bigframes/constants.py b/bigframes/constants.py index 3c18fd20bd..d6fe699713 100644 --- a/bigframes/constants.py +++ b/bigframes/constants.py @@ -21,6 +21,7 @@ import bigframes_vendored.constants +BF_VERSION = bigframes_vendored.constants.BF_VERSION FEEDBACK_LINK = bigframes_vendored.constants.FEEDBACK_LINK ABSTRACT_METHOD_ERROR_MESSAGE = ( bigframes_vendored.constants.ABSTRACT_METHOD_ERROR_MESSAGE diff --git a/bigframes/core/__init__.py b/bigframes/core/__init__.py index f3c75f7143..f65509e5b7 100644 --- a/bigframes/core/__init__.py +++ b/bigframes/core/__init__.py @@ -192,20 +192,15 @@ def concat(self, other: typing.Sequence[ArrayValue]) -> ArrayValue: ) def project_to_id(self, expression: ex.Expression, output_id: str): - if output_id in self.column_ids: # Mutate case - exprs = [ - ((expression if (col_id == output_id) else ex.free_var(col_id)), col_id) - for col_id in self.column_ids - ] - else: # append case - self_projection = ( - (ex.free_var(col_id), col_id) for col_id in self.column_ids - ) - exprs = [*self_projection, (expression, output_id)] return ArrayValue( nodes.ProjectionNode( child=self.node, - assignments=tuple(exprs), + assignments=( + ( + expression, + output_id, + ), + ), ) ) @@ -213,28 +208,22 @@ def assign(self, source_id: str, destination_id: str) -> ArrayValue: if destination_id in self.column_ids: # Mutate case exprs = [ ( - ( - ex.free_var(source_id) - if (col_id == destination_id) - else ex.free_var(col_id) - ), + (source_id if (col_id == destination_id) else col_id), col_id, ) for col_id in self.column_ids ] else: # append case - self_projection = ( - (ex.free_var(col_id), col_id) for col_id in self.column_ids - ) - exprs = [*self_projection, (ex.free_var(source_id), destination_id)] + self_projection = ((col_id, col_id) for col_id in self.column_ids) + exprs = [*self_projection, (source_id, destination_id)] return ArrayValue( - nodes.ProjectionNode( + nodes.SelectionNode( child=self.node, - assignments=tuple(exprs), + input_output_pairs=tuple(exprs), ) ) - def assign_constant( + def create_constant( self, destination_id: str, value: typing.Any, @@ -244,49 +233,31 @@ def assign_constant( # Need to assign a data type when value is NaN. 
dtype = dtype or bigframes.dtypes.DEFAULT_DTYPE - if destination_id in self.column_ids: # Mutate case - exprs = [ - ( - ( - ex.const(value, dtype) - if (col_id == destination_id) - else ex.free_var(col_id) - ), - col_id, - ) - for col_id in self.column_ids - ] - else: # append case - self_projection = ( - (ex.free_var(col_id), col_id) for col_id in self.column_ids - ) - exprs = [*self_projection, (ex.const(value, dtype), destination_id)] return ArrayValue( nodes.ProjectionNode( child=self.node, - assignments=tuple(exprs), + assignments=((ex.const(value, dtype), destination_id),), ) ) def select_columns(self, column_ids: typing.Sequence[str]) -> ArrayValue: - selections = ((ex.free_var(col_id), col_id) for col_id in column_ids) + # This basically just drops and reorders columns - logically a no-op except as a final step + selections = ((col_id, col_id) for col_id in column_ids) return ArrayValue( - nodes.ProjectionNode( + nodes.SelectionNode( child=self.node, - assignments=tuple(selections), + input_output_pairs=tuple(selections), ) ) def drop_columns(self, columns: Iterable[str]) -> ArrayValue: new_projection = ( - (ex.free_var(col_id), col_id) - for col_id in self.column_ids - if col_id not in columns + (col_id, col_id) for col_id in self.column_ids if col_id not in columns ) return ArrayValue( - nodes.ProjectionNode( + nodes.SelectionNode( child=self.node, - assignments=tuple(new_projection), + input_output_pairs=tuple(new_projection), ) ) @@ -422,15 +393,13 @@ def unpivot( col_expr = ops.case_when_op.as_expr(*cases) unpivot_exprs.append((col_expr, col_id)) - label_exprs = ((ex.free_var(id), id) for id in index_col_ids) - # passthrough columns are unchanged, just repeated N times each - passthrough_exprs = ((ex.free_var(id), id) for id in passthrough_columns) + unpivot_col_ids = [id for id, _ in unpivot_columns] return ArrayValue( nodes.ProjectionNode( child=joined_array.node, - assignments=(*label_exprs, *unpivot_exprs, *passthrough_exprs), + assignments=(*unpivot_exprs,), ) - ) + ).select_columns([*index_col_ids, *unpivot_col_ids, *passthrough_columns]) def _cross_join_w_labels( self, labels_array: ArrayValue, join_side: typing.Literal["left", "right"] diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py index a309671842..4db171ec70 100644 --- a/bigframes/core/blocks.py +++ b/bigframes/core/blocks.py @@ -939,7 +939,7 @@ def multi_apply_unary_op( for col_id in columns: label = self.col_id_to_label[col_id] block, result_id = block.project_expr( - expr.bind_all_variables({input_varname: ex.free_var(col_id)}), + expr.bind_variables({input_varname: ex.free_var(col_id)}), label=label, ) block = block.copy_values(result_id, col_id) @@ -1006,7 +1006,7 @@ def create_constant( dtype: typing.Optional[bigframes.dtypes.Dtype] = None, ) -> typing.Tuple[Block, str]: result_id = guid.generate_guid() - expr = self.expr.assign_constant(result_id, scalar_constant, dtype=dtype) + expr = self.expr.create_constant(result_id, scalar_constant, dtype=dtype) # Create index copy with label inserted # See: https://pandas.pydata.org/docs/reference/api/pandas.Index.insert.html labels = self.column_labels.insert(len(self.column_labels), label) @@ -1067,7 +1067,7 @@ def aggregate_all_and_stack( index_id = guid.generate_guid() result_expr = self.expr.aggregate( aggregations, dropna=dropna - ).assign_constant(index_id, None, None) + ).create_constant(index_id, None, None) # Transpose as last operation so that final block has valid transpose cache return Block( result_expr, @@ -1222,7 +1222,7 @@ def 
aggregate( names: typing.List[Label] = [] if len(by_column_ids) == 0: label_id = guid.generate_guid() - result_expr = result_expr.assign_constant(label_id, 0, pd.Int64Dtype()) + result_expr = result_expr.create_constant(label_id, 0, pd.Int64Dtype()) index_columns = (label_id,) names = [None] else: @@ -1614,17 +1614,22 @@ def add_prefix(self, prefix: str, axis: str | int | None = None) -> Block: axis_number = utils.get_axis_number("rows" if (axis is None) else axis) if axis_number == 0: expr = self._expr + new_index_cols = [] for index_col in self._index_columns: + new_col = guid.generate_guid() expr = expr.project_to_id( expression=ops.add_op.as_expr( ex.const(prefix), ops.AsTypeOp(to_type="string").as_expr(index_col), ), - output_id=index_col, + output_id=new_col, ) + new_index_cols.append(new_col) + expr = expr.select_columns((*new_index_cols, *self.value_columns)) + return Block( expr, - index_columns=self.index_columns, + index_columns=new_index_cols, column_labels=self.column_labels, index_labels=self.index.names, ) @@ -1635,17 +1640,21 @@ def add_suffix(self, suffix: str, axis: str | int | None = None) -> Block: axis_number = utils.get_axis_number("rows" if (axis is None) else axis) if axis_number == 0: expr = self._expr + new_index_cols = [] for index_col in self._index_columns: + new_col = guid.generate_guid() expr = expr.project_to_id( expression=ops.add_op.as_expr( ops.AsTypeOp(to_type="string").as_expr(index_col), ex.const(suffix), ), - output_id=index_col, + output_id=new_col, ) + new_index_cols.append(new_col) + expr = expr.select_columns((*new_index_cols, *self.value_columns)) return Block( expr, - index_columns=self.index_columns, + index_columns=new_index_cols, column_labels=self.column_labels, index_labels=self.index.names, ) @@ -2420,9 +2429,11 @@ def _is_monotonic( block, last_notna_id = self.apply_unary_op(column_ids[0], ops.notnull_op) for column_id in column_ids[1:]: block, notna_id = block.apply_unary_op(column_id, ops.notnull_op) + old_last_notna_id = last_notna_id block, last_notna_id = block.apply_binary_op( - last_notna_id, notna_id, ops.and_op + old_last_notna_id, notna_id, ops.and_op ) + block.drop_columns([notna_id, old_last_notna_id]) # loop over all columns to check monotonicity last_result_id = None @@ -2434,21 +2445,27 @@ def _is_monotonic( column_id, lag_result_id, ops.gt_op if increasing else ops.lt_op ) block, equal_id = block.apply_binary_op(column_id, lag_result_id, ops.eq_op) + block = block.drop_columns([lag_result_id]) if last_result_id is None: block, last_result_id = block.apply_binary_op( equal_id, strict_monotonic_id, ops.or_op ) - continue - block, equal_monotonic_id = block.apply_binary_op( - equal_id, last_result_id, ops.and_op - ) - block, last_result_id = block.apply_binary_op( - equal_monotonic_id, strict_monotonic_id, ops.or_op - ) + block = block.drop_columns([equal_id, strict_monotonic_id]) + else: + block, equal_monotonic_id = block.apply_binary_op( + equal_id, last_result_id, ops.and_op + ) + block = block.drop_columns([equal_id, last_result_id]) + block, last_result_id = block.apply_binary_op( + equal_monotonic_id, strict_monotonic_id, ops.or_op + ) + block = block.drop_columns([equal_monotonic_id, strict_monotonic_id]) block, monotonic_result_id = block.apply_binary_op( last_result_id, last_notna_id, ops.and_op # type: ignore ) + if last_result_id is not None: + block = block.drop_columns([last_result_id, last_notna_id]) result = block.get_stat(monotonic_result_id, agg_ops.all_op) self._stats_cache[column_name].update({op_name: 
result})
return result
diff --git a/bigframes/core/compile/compiled.py b/bigframes/core/compile/compiled.py
index 512238440c..9a9f598e89 100644
--- a/bigframes/core/compile/compiled.py
+++ b/bigframes/core/compile/compiled.py
@@ -134,10 +134,23 @@ def projection(
) -> T:
"""Apply an expression to the ArrayValue and assign the output to a column."""
bindings = {col: self._get_ibis_column(col) for col in self.column_ids}
- values = [
+ new_values = [
op_compiler.compile_expression(expression, bindings).name(id)
for expression, id in expression_id_pairs
]
+ result = self._select(tuple([*self._columns, *new_values])) # type: ignore
+ return result
+
+ def selection(
+ self: T,
+ input_output_pairs: typing.Tuple[typing.Tuple[str, str], ...],
+ ) -> T:
+ """Select columns by id, renaming each to its given output id."""
+ bindings = {col: self._get_ibis_column(col) for col in self.column_ids}
+ values = [
+ op_compiler.compile_expression(ex.free_var(input), bindings).name(id)
+ for input, id in input_output_pairs
+ ]
result = self._select(tuple(values)) # type: ignore
return result
diff --git a/bigframes/core/compile/compiler.py b/bigframes/core/compile/compiler.py
index 3fedf5c0c8..80d5f5a893 100644
--- a/bigframes/core/compile/compiler.py
+++ b/bigframes/core/compile/compiler.py
@@ -264,6 +264,11 @@ def compile_reversed(self, node: nodes.ReversedNode, ordered: bool = True):
else:
return self.compile_unordered_ir(node.child)
+ @_compile_node.register
+ def compile_selection(self, node: nodes.SelectionNode, ordered: bool = True):
+ result = self.compile_node(node.child, ordered)
+ return result.selection(node.input_output_pairs)
+
@_compile_node.register
def compile_projection(self, node: nodes.ProjectionNode, ordered: bool = True):
result = self.compile_node(node.child, ordered)
diff --git a/bigframes/core/compile/ibis_types.py b/bigframes/core/compile/ibis_types.py
index 0b3038c9c7..f4ec295d5f 100644
--- a/bigframes/core/compile/ibis_types.py
+++ b/bigframes/core/compile/ibis_types.py
@@ -144,10 +144,12 @@ def cast_ibis_value(
),
ibis_dtypes.Decimal(precision=38, scale=9): (
ibis_dtypes.float64,
+ ibis_dtypes.int64,
ibis_dtypes.Decimal(precision=76, scale=38),
),
ibis_dtypes.Decimal(precision=76, scale=38): (
ibis_dtypes.float64,
+ ibis_dtypes.int64,
ibis_dtypes.Decimal(precision=38, scale=9),
),
ibis_dtypes.time: (
diff --git a/bigframes/core/expression.py b/bigframes/core/expression.py
index c216c29717..bbd23b689c 100644
--- a/bigframes/core/expression.py
+++ b/bigframes/core/expression.py
@@ -110,8 +110,13 @@ def output_type(
...
@abc.abstractmethod
- def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
- """Replace all variables with expression given in `bindings`."""
+ def bind_variables(
+ self, bindings: Mapping[str, Expression], check_bind_all: bool = True
+ ) -> Expression:
+ """Replace variables with expression given in `bindings`.
+
+ If check_bind_all is True, validate that all free variables are bound to a new value.
+ """
...
@property @@ -141,7 +146,9 @@ def output_type( ) -> dtypes.ExpressionType: return self.dtype - def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression: + def bind_variables( + self, bindings: Mapping[str, Expression], check_bind_all: bool = True + ) -> Expression: return self @property @@ -178,11 +185,14 @@ def output_type( else: raise ValueError(f"Type of variable {self.id} has not been fixed.") - def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression: + def bind_variables( + self, bindings: Mapping[str, Expression], check_bind_all: bool = True + ) -> Expression: if self.id in bindings.keys(): return bindings[self.id] - else: + elif check_bind_all: raise ValueError(f"Variable {self.id} remains unbound") + return self @property def is_bijective(self) -> bool: @@ -225,10 +235,15 @@ def output_type( ) return self.op.output_type(*operand_types) - def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression: + def bind_variables( + self, bindings: Mapping[str, Expression], check_bind_all: bool = True + ) -> Expression: return OpExpression( self.op, - tuple(input.bind_all_variables(bindings) for input in self.inputs), + tuple( + input.bind_variables(bindings, check_bind_all=check_bind_all) + for input in self.inputs + ), ) @property diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py index 73780719a9..27e76c7910 100644 --- a/bigframes/core/nodes.py +++ b/bigframes/core/nodes.py @@ -622,8 +622,32 @@ def relation_ops_created(self) -> int: return 0 +@dataclass(frozen=True) +class SelectionNode(UnaryNode): + input_output_pairs: typing.Tuple[typing.Tuple[str, str], ...] + + def __hash__(self): + return self._node_hash + + @functools.cached_property + def schema(self) -> schemata.ArraySchema: + input_types = self.child.schema._mapping + items = tuple( + schemata.SchemaItem(output, input_types[input]) + for input, output in self.input_output_pairs + ) + return schemata.ArraySchema(items) + + @property + def variables_introduced(self) -> int: + # This operation only renames variables, doesn't actually create new ones + return 0 + + @dataclass(frozen=True) class ProjectionNode(UnaryNode): + """Assigns new variables (without modifying existing ones)""" + assignments: typing.Tuple[typing.Tuple[ex.Expression, str], ...] def __post_init__(self): @@ -631,6 +655,8 @@ def __post_init__(self): for expression, id in self.assignments: # throws TypeError if invalid _ = expression.output_type(input_types) + # Cannot assign to existing variables - append only! 
+ assert all(name not in self.child.schema.names for _, name in self.assignments) def __hash__(self): return self._node_hash @@ -644,7 +670,10 @@ def schema(self) -> schemata.ArraySchema: ) for ex, id in self.assignments ) - return schemata.ArraySchema(items) + schema = self.child.schema + for item in items: + schema = schema.append(item) + return schema @property def variables_introduced(self) -> int: diff --git a/bigframes/core/ordering.py b/bigframes/core/ordering.py index bff7e2ce44..a57d7a18d6 100644 --- a/bigframes/core/ordering.py +++ b/bigframes/core/ordering.py @@ -63,7 +63,7 @@ def bind_variables( self, mapping: Mapping[str, expression.Expression] ) -> OrderingExpression: return OrderingExpression( - self.scalar_expression.bind_all_variables(mapping), + self.scalar_expression.bind_variables(mapping), self.direction, self.na_last, ) diff --git a/bigframes/core/rewrite.py b/bigframes/core/rewrite.py index 60ed4069a9..0e73166ea5 100644 --- a/bigframes/core/rewrite.py +++ b/bigframes/core/rewrite.py @@ -27,6 +27,7 @@ Selection = Tuple[Tuple[scalar_exprs.Expression, str], ...] REWRITABLE_NODE_TYPES = ( + nodes.SelectionNode, nodes.ProjectionNode, nodes.FilterNode, nodes.ReversedNode, @@ -54,7 +55,12 @@ def from_node_span( for id in get_node_column_ids(node) ) return cls(node, selection, None, ()) - if isinstance(node, nodes.ProjectionNode): + + if isinstance(node, nodes.SelectionNode): + return cls.from_node_span(node.child, target).select( + node.input_output_pairs + ) + elif isinstance(node, nodes.ProjectionNode): return cls.from_node_span(node.child, target).project(node.assignments) elif isinstance(node, nodes.FilterNode): return cls.from_node_span(node.child, target).filter(node.predicate) @@ -69,22 +75,39 @@ def from_node_span( def column_lookup(self) -> Mapping[str, scalar_exprs.Expression]: return {col_id: expr for expr, col_id in self.columns} + def select(self, input_output_pairs: Tuple[Tuple[str, str], ...]) -> SquashedSelect: + new_columns = tuple( + ( + scalar_exprs.free_var(input).bind_variables(self.column_lookup), + output, + ) + for input, output in input_output_pairs + ) + return SquashedSelect( + self.root, new_columns, self.predicate, self.ordering, self.reverse_root + ) + def project( self, projection: Tuple[Tuple[scalar_exprs.Expression, str], ...] 
) -> SquashedSelect: + existing_columns = self.columns new_columns = tuple( - (expr.bind_all_variables(self.column_lookup), id) for expr, id in projection + (expr.bind_variables(self.column_lookup), id) for expr, id in projection ) return SquashedSelect( - self.root, new_columns, self.predicate, self.ordering, self.reverse_root + self.root, + (*existing_columns, *new_columns), + self.predicate, + self.ordering, + self.reverse_root, ) def filter(self, predicate: scalar_exprs.Expression) -> SquashedSelect: if self.predicate is None: - new_predicate = predicate.bind_all_variables(self.column_lookup) + new_predicate = predicate.bind_variables(self.column_lookup) else: new_predicate = ops.and_op.as_expr( - self.predicate, predicate.bind_all_variables(self.column_lookup) + self.predicate, predicate.bind_variables(self.column_lookup) ) return SquashedSelect( self.root, self.columns, new_predicate, self.ordering, self.reverse_root @@ -204,7 +227,11 @@ def expand(self) -> nodes.BigFrameNode: root = nodes.FilterNode(child=root, predicate=self.predicate) if self.ordering: root = nodes.OrderByNode(child=root, by=self.ordering) - return nodes.ProjectionNode(child=root, assignments=self.columns) + selection = tuple((id, id) for _, id in self.columns) + return nodes.SelectionNode( + child=nodes.ProjectionNode(child=root, assignments=self.columns), + input_output_pairs=selection, + ) def join_as_projection( diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 6b782b4692..2ae6aefe1b 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -3027,6 +3027,7 @@ def to_gbq( index: bool = True, ordering_id: Optional[str] = None, clustering_columns: Union[pandas.Index, Iterable[typing.Hashable]] = (), + labels: dict[str, str] = {}, ) -> str: temp_table_ref = None @@ -3081,9 +3082,11 @@ def to_gbq( export_array, id_overrides = self._prepare_export( index=index and self._has_index, ordering_id=ordering_id ) - destination = bigquery.table.TableReference.from_string( - destination_table, - default_project=default_project, + destination: bigquery.table.TableReference = ( + bigquery.table.TableReference.from_string( + destination_table, + default_project=default_project, + ) ) _, query_job = self._session._export( export_array, @@ -3106,6 +3109,11 @@ def to_gbq( + constants.DEFAULT_EXPIRATION, ) + if len(labels) != 0: + table = bigquery.Table(result_table) + table.labels = labels + self._session.bqclient.update_table(table, ["labels"]) + return destination_table def to_numpy( diff --git a/bigframes/dtypes.py b/bigframes/dtypes.py index 45c1e7e4e2..bfed783e1e 100644 --- a/bigframes/dtypes.py +++ b/bigframes/dtypes.py @@ -189,18 +189,18 @@ class SimpleDtypeInfo: "binary[pyarrow]", ] -BOOL_BIGFRAMES_TYPES = [pd.BooleanDtype()] +BOOL_BIGFRAMES_TYPES = [BOOL_DTYPE] # Corresponds to the pandas concept of numeric type (such as when 'numeric_only' is specified in an operation) # Pandas is inconsistent, so two definitions are provided, each used in different contexts NUMERIC_BIGFRAMES_TYPES_RESTRICTIVE = [ - pd.Float64Dtype(), - pd.Int64Dtype(), + FLOAT_DTYPE, + INT_DTYPE, ] NUMERIC_BIGFRAMES_TYPES_PERMISSIVE = NUMERIC_BIGFRAMES_TYPES_RESTRICTIVE + [ - pd.BooleanDtype(), - pd.ArrowDtype(pa.decimal128(38, 9)), - pd.ArrowDtype(pa.decimal256(76, 38)), + BOOL_DTYPE, + NUMERIC_DTYPE, + BIGNUMERIC_DTYPE, ] @@ -308,10 +308,10 @@ def is_bool_coercable(type_: ExpressionType) -> bool: # special case - string[pyarrow] doesn't include the storage in its name, and both # "string" and "string[pyarrow]" are accepted 
-BIGFRAMES_STRING_TO_BIGFRAMES["string[pyarrow]"] = pd.StringDtype(storage="pyarrow") +BIGFRAMES_STRING_TO_BIGFRAMES["string[pyarrow]"] = STRING_DTYPE # special case - both "Int64" and "int64[pyarrow]" are accepted -BIGFRAMES_STRING_TO_BIGFRAMES["int64[pyarrow]"] = pd.Int64Dtype() +BIGFRAMES_STRING_TO_BIGFRAMES["int64[pyarrow]"] = INT_DTYPE # For the purposes of dataframe.memory_usage DTYPE_BYTE_SIZES = { @@ -552,14 +552,14 @@ def is_compatible(scalar: typing.Any, dtype: Dtype) -> typing.Optional[Dtype]: elif pd.api.types.is_numeric_dtype(dtype): # Implicit conversion currently only supported for numeric types if pd.api.types.is_bool(scalar): - return lcd_type(pd.BooleanDtype(), dtype) + return lcd_type(BOOL_DTYPE, dtype) if pd.api.types.is_float(scalar): - return lcd_type(pd.Float64Dtype(), dtype) + return lcd_type(FLOAT_DTYPE, dtype) if pd.api.types.is_integer(scalar): - return lcd_type(pd.Int64Dtype(), dtype) + return lcd_type(INT_DTYPE, dtype) if isinstance(scalar, decimal.Decimal): # TODO: Check context to see if can use NUMERIC instead of BIGNUMERIC - return lcd_type(pd.ArrowDtype(pa.decimal256(76, 38)), dtype) + return lcd_type(BIGNUMERIC_DTYPE, dtype) return None @@ -573,11 +573,11 @@ def lcd_type(*dtypes: Dtype) -> Dtype: return unique_dtypes.pop() # Implicit conversion currently only supported for numeric types hierarchy: list[Dtype] = [ - pd.BooleanDtype(), - pd.Int64Dtype(), - pd.ArrowDtype(pa.decimal128(38, 9)), - pd.ArrowDtype(pa.decimal256(76, 38)), - pd.Float64Dtype(), + BOOL_DTYPE, + INT_DTYPE, + NUMERIC_DTYPE, + BIGNUMERIC_DTYPE, + FLOAT_DTYPE, ] if any([dtype not in hierarchy for dtype in dtypes]): return None diff --git a/bigframes/functions/_remote_function_session.py b/bigframes/functions/_remote_function_session.py index 0ab19ca353..893b903aeb 100644 --- a/bigframes/functions/_remote_function_session.py +++ b/bigframes/functions/_remote_function_session.py @@ -176,7 +176,7 @@ def remote_function( getting and setting IAM roles on cloud resources. If this param is not provided then resource manager client from the session would be used. - dataset (str, Optional.): + dataset (str, Optional): Dataset in which to create a BigQuery remote function. It should be in `.` or `` format. If this parameter is not provided then session dataset id is used. @@ -387,7 +387,7 @@ def wrapper(func): # https://docs.python.org/3/library/inspect.html#inspect.signature signature_kwargs: Mapping[str, Any] = {"eval_str": True} else: - signature_kwargs = {} + signature_kwargs = {} # type: ignore signature = inspect.signature( func, diff --git a/bigframes/functions/remote_function.py b/bigframes/functions/remote_function.py index 7e9df74e76..39e3bfd8f0 100644 --- a/bigframes/functions/remote_function.py +++ b/bigframes/functions/remote_function.py @@ -14,6 +14,7 @@ from __future__ import annotations +import inspect import logging from typing import cast, Optional, TYPE_CHECKING import warnings @@ -107,6 +108,7 @@ def read_gbq_function( function_name: str, *, session: Session, + is_row_processor: bool = False, ): """ Read an existing BigQuery function and prepare it for use in future queries. 
@@ -149,6 +151,13 @@ def func(*ignored_args, **ignored_kwargs): expr = node(*ignored_args, **ignored_kwargs) # type: ignore return ibis_client.execute(expr) + func.__signature__ = inspect.signature(func).replace( # type: ignore + parameters=[ + inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD) + for name in ibis_signature.parameter_names + ] + ) + # TODO: Move ibis logic to compiler step func.__name__ = routine_ref.routine_id @@ -186,5 +195,6 @@ def func(*ignored_args, **ignored_kwargs): func.output_dtype = bigframes.core.compile.ibis_types.ibis_dtype_to_bigframes_dtype( # type: ignore ibis_signature.output_type ) + func.is_row_processor = is_row_processor # type: ignore func.ibis_node = node # type: ignore return func diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index 35bcf0a33c..a3cd065a55 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -55,10 +55,14 @@ _GEMINI_PRO_ENDPOINT = "gemini-pro" _GEMINI_1P5_PRO_PREVIEW_ENDPOINT = "gemini-1.5-pro-preview-0514" _GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT = "gemini-1.5-flash-preview-0514" +_GEMINI_1P5_PRO_001_ENDPOINT = "gemini-1.5-pro-001" +_GEMINI_1P5_FLASH_001_ENDPOINT = "gemini-1.5-flash-001" _GEMINI_ENDPOINTS = ( _GEMINI_PRO_ENDPOINT, _GEMINI_1P5_PRO_PREVIEW_ENDPOINT, _GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT, + _GEMINI_1P5_PRO_001_ENDPOINT, + _GEMINI_1P5_FLASH_001_ENDPOINT, ) _CLAUDE_3_SONNET_ENDPOINT = "claude-3-sonnet" @@ -728,7 +732,7 @@ class GeminiTextGenerator(base.BaseEstimator): Args: model_name (str, Default to "gemini-pro"): - The model for natural language tasks. Accepted values are "gemini-pro", "gemini-1.5-pro-preview-0514" and "gemini-1.5-flash-preview-0514". Default to "gemini-pro". + The model for natural language tasks. Accepted values are "gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514", "gemini-1.5-pro-001" and "gemini-1.5-flash-001". Default to "gemini-pro". .. note:: "gemini-1.5-pro-preview-0514" and "gemini-1.5-flash-preview-0514" is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the @@ -750,7 +754,11 @@ def __init__( self, *, model_name: Literal[ - "gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514" + "gemini-pro", + "gemini-1.5-pro-preview-0514", + "gemini-1.5-flash-preview-0514", + "gemini-1.5-pro-001", + "gemini-1.5-flash-001", ] = "gemini-pro", session: Optional[bigframes.Session] = None, connection_name: Optional[str] = None, diff --git a/bigframes/ml/loader.py b/bigframes/ml/loader.py index 7d75f4c65a..4e7e808260 100644 --- a/bigframes/ml/loader.py +++ b/bigframes/ml/loader.py @@ -63,6 +63,8 @@ llm._GEMINI_PRO_ENDPOINT: llm.GeminiTextGenerator, llm._GEMINI_1P5_PRO_PREVIEW_ENDPOINT: llm.GeminiTextGenerator, llm._GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT: llm.GeminiTextGenerator, + llm._GEMINI_1P5_PRO_001_ENDPOINT: llm.GeminiTextGenerator, + llm._GEMINI_1P5_FLASH_001_ENDPOINT: llm.GeminiTextGenerator, llm._CLAUDE_3_HAIKU_ENDPOINT: llm.Claude3TextGenerator, llm._CLAUDE_3_SONNET_ENDPOINT: llm.Claude3TextGenerator, llm._CLAUDE_3_5_SONNET_ENDPOINT: llm.Claude3TextGenerator, diff --git a/bigframes/operations/_op_converters.py b/bigframes/operations/_op_converters.py new file mode 100644 index 0000000000..3ebf22bcb6 --- /dev/null +++ b/bigframes/operations/_op_converters.py @@ -0,0 +1,37 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import bigframes.operations as ops + + +def convert_index(key: int) -> ops.ArrayIndexOp: + if key < 0: + raise NotImplementedError("Negative indexing is not supported.") + return ops.ArrayIndexOp(index=key) + + +def convert_slice(key: slice) -> ops.ArraySliceOp: + if key.step is not None and key.step != 1: + raise NotImplementedError(f"Only a step of 1 is allowed, got {key.step}") + + if (key.start is not None and key.start < 0) or ( + key.stop is not None and key.stop < 0 + ): + raise NotImplementedError("Slicing with negative numbers is not allowed.") + + return ops.ArraySliceOp( + start=key.start if key.start is not None else 0, + stop=key.stop, + step=key.step, + ) diff --git a/bigframes/operations/lists.py b/bigframes/operations/lists.py new file mode 100644 index 0000000000..16c22dfb2a --- /dev/null +++ b/bigframes/operations/lists.py @@ -0,0 +1,46 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import inspect +from typing import Union + +import bigframes_vendored.pandas.core.arrays.arrow.accessors as vendoracessors + +from bigframes.core import log_adapter +import bigframes.operations as ops +from bigframes.operations._op_converters import convert_index, convert_slice +import bigframes.operations.base +import bigframes.series as series + + +@log_adapter.class_logger +class ListAccessor( + bigframes.operations.base.SeriesMethods, vendoracessors.ListAccessor +): + __doc__ = vendoracessors.ListAccessor.__doc__ + + def len(self): + return self._apply_unary_op(ops.len_op) + + def __getitem__(self, key: Union[int, slice]) -> series.Series: + if isinstance(key, int): + return self._apply_unary_op(convert_index(key)) + elif isinstance(key, slice): + return self._apply_unary_op(convert_slice(key)) + else: + raise ValueError(f"key must be an int or slice, got {type(key).__name__}") + + __getitem__.__doc__ = inspect.getdoc(vendoracessors.ListAccessor.__getitem__) diff --git a/bigframes/operations/strings.py b/bigframes/operations/strings.py index d3e9c7edc6..4af142e0d5 100644 --- a/bigframes/operations/strings.py +++ b/bigframes/operations/strings.py @@ -23,6 +23,7 @@ from bigframes.core import log_adapter import bigframes.dataframe as df import bigframes.operations as ops +from bigframes.operations._op_converters import convert_index, convert_slice import bigframes.operations.base import bigframes.series as series @@ -40,28 +41,9 @@ class StringMethods(bigframes.operations.base.SeriesMethods, vendorstr.StringMet def __getitem__(self, key: Union[int, slice]) -> series.Series: if isinstance(key, int): - if key < 0: - raise NotImplementedError("Negative indexing is not supported.") - return self._apply_unary_op(ops.ArrayIndexOp(index=key)) + return self._apply_unary_op(convert_index(key)) elif isinstance(key, slice): - if key.step is not None and key.step != 1: - raise NotImplementedError( - f"Only a step of 1 is allowed, got {key.step}" - ) - if (key.start is not None and key.start < 0) or ( - key.stop is not None and key.stop < 0 - ): - raise NotImplementedError( - "Slicing with negative numbers is not allowed." 
- ) - - return self._apply_unary_op( - ops.ArraySliceOp( - start=key.start if key.start is not None else 0, - stop=key.stop, - step=key.step, - ) - ) + return self._apply_unary_op(convert_slice(key)) else: raise ValueError(f"key must be an int or slice, got {type(key).__name__}") diff --git a/bigframes/pandas/__init__.py b/bigframes/pandas/__init__.py index 08d808572d..3809384c95 100644 --- a/bigframes/pandas/__init__.py +++ b/bigframes/pandas/__init__.py @@ -69,6 +69,7 @@ import bigframes.session import bigframes.session._io.bigquery import bigframes.session.clients +import bigframes.version try: import resource @@ -692,10 +693,11 @@ def remote_function( remote_function.__doc__ = inspect.getdoc(bigframes.session.Session.remote_function) -def read_gbq_function(function_name: str): +def read_gbq_function(function_name: str, is_row_processor: bool = False): return global_session.with_default_session( bigframes.session.Session.read_gbq_function, function_name=function_name, + is_row_processor=is_row_processor, ) @@ -837,6 +839,7 @@ def clean_up_by_session_id( Index = bigframes.core.indexes.Index MultiIndex = bigframes.core.indexes.MultiIndex Series = bigframes.series.Series +__version__ = bigframes.version.__version__ # Other public pandas attributes NamedAgg = namedtuple("NamedAgg", ["column", "aggfunc"]) @@ -910,6 +913,7 @@ def reset_session(): "Index", "MultiIndex", "Series", + "__version__", # Other public pandas attributes "NamedAgg", "options", diff --git a/bigframes/series.py b/bigframes/series.py index a166680f85..5192a9cf49 100644 --- a/bigframes/series.py +++ b/bigframes/series.py @@ -53,6 +53,7 @@ import bigframes.operations.aggregations as agg_ops import bigframes.operations.base import bigframes.operations.datetimes as dt +import bigframes.operations.lists as lists import bigframes.operations.plotting as plotting import bigframes.operations.strings as strings import bigframes.operations.structs as structs @@ -66,6 +67,8 @@ " Try converting it to a remote function." ) +_list = list # Type alias to escape Series.list property + @log_adapter.class_logger class Series(bigframes.operations.base.SeriesMethods, vendored_pandas_series.Series): @@ -161,6 +164,10 @@ def query_job(self) -> Optional[bigquery.QueryJob]: def struct(self) -> structs.StructAccessor: return structs.StructAccessor(self._block) + @property + def list(self) -> lists.ListAccessor: + return lists.ListAccessor(self._block) + @property @validations.requires_ordering() def T(self) -> Series: @@ -1708,7 +1715,7 @@ def to_latex( buf, columns=columns, header=header, index=index, **kwargs ) - def tolist(self) -> list: + def tolist(self) -> _list: return self.to_pandas().to_list() to_list = tolist diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index fba1d41e30..045483bd53 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -1008,10 +1008,12 @@ def _check_file_size(self, filepath: str): blob = bucket.blob(blob_name) blob.reload() file_size = blob.size - else: # local file path + elif os.path.exists(filepath): # local file path file_size = os.path.getsize(filepath) + else: + file_size = None - if file_size > max_size: + if file_size is not None and file_size > max_size: # Convert to GB file_size = round(file_size / (1024**3), 1) max_size = int(max_size / 1024**3) @@ -1223,6 +1225,7 @@ def remote_function( def read_gbq_function( self, function_name: str, + is_row_processor: bool = False, ): """Loads a BigQuery function from BigQuery. 
@@ -1239,12 +1242,22 @@ def read_gbq_function( **Examples:** - Use the ``cw_lower_case_ascii_only`` function from Community UDFs. - (https://github.com/GoogleCloudPlatform/bigquery-utils/blob/master/udfs/community/cw_lower_case_ascii_only.sqlx) - >>> import bigframes.pandas as bpd >>> bpd.options.display.progress_bar = None + Use the [cw_lower_case_ascii_only](https://github.com/GoogleCloudPlatform/bigquery-utils/blob/master/udfs/community/README.md#cw_lower_case_ascii_onlystr-string) + function from Community UDFs. + + >>> func = bpd.read_gbq_function("bqutil.fn.cw_lower_case_ascii_only") + + You can run it on scalar input. Usually you would do so to verify that + it works as expected before applying to all values in a Series. + + >>> func('AURÉLIE') + 'aurÉlie' + + You can apply it to a BigQuery DataFrames Series. + >>> df = bpd.DataFrame({'id': [1, 2, 3], 'name': ['AURÉLIE', 'CÉLESTINE', 'DAPHNÉ']}) >>> df id name @@ -1254,7 +1267,6 @@ def read_gbq_function( [3 rows x 2 columns] - >>> func = bpd.read_gbq_function("bqutil.fn.cw_lower_case_ascii_only") >>> df1 = df.assign(new_name=df['name'].apply(func)) >>> df1 id name new_name @@ -1264,13 +1276,45 @@ def read_gbq_function( [3 rows x 3 columns] + You can even use a function with multiple inputs. For example, + [cw_regexp_replace_5](https://github.com/GoogleCloudPlatform/bigquery-utils/blob/master/udfs/community/README.md#cw_regexp_replace_5haystack-string-regexp-string-replacement-string-offset-int64-occurrence-int64) + from Community UDFs. + + >>> func = bpd.read_gbq_function("bqutil.fn.cw_regexp_replace_5") + >>> func('TestStr123456', 'Str', 'Cad$', 1, 1) + 'TestCad$123456' + + >>> df = bpd.DataFrame({ + ... "haystack" : ["TestStr123456", "TestStr123456Str", "TestStr123456Str"], + ... "regexp" : ["Str", "Str", "Str"], + ... "replacement" : ["Cad$", "Cad$", "Cad$"], + ... "offset" : [1, 1, 1], + ... "occurrence" : [1, 2, 1] + ... }) + >>> df + haystack regexp replacement offset occurrence + 0 TestStr123456 Str Cad$ 1 1 + 1 TestStr123456Str Str Cad$ 1 2 + 2 TestStr123456Str Str Cad$ 1 1 + + [3 rows x 5 columns] + >>> df.apply(func, axis=1) + 0 TestCad$123456 + 1 TestStr123456Cad$ + 2 TestCad$123456Str + dtype: string + Args: function_name (str): - the function's name in BigQuery in the format + The function's name in BigQuery in the format `project_id.dataset_id.function_name`, or `dataset_id.function_name` to load from the default project, or `function_name` to load from the default project and the dataset associated with the current session. + is_row_processor (bool, default False): + Whether the function is a row processor. This is set to True + for a function which receives an entire row of a DataFrame as + a pandas Series. 
Returns: callable: A function object pointing to the BigQuery function read @@ -1284,6 +1328,7 @@ def read_gbq_function( return bigframes_rf.read_gbq_function( function_name=function_name, session=self, + is_row_processor=is_row_processor, ) def _prepare_copy_job_config(self) -> bigquery.CopyJobConfig: diff --git a/bigframes/session/executor.py b/bigframes/session/executor.py index 72d5493294..424e6d7dad 100644 --- a/bigframes/session/executor.py +++ b/bigframes/session/executor.py @@ -457,9 +457,7 @@ def generate_head_plan(node: nodes.BigFrameNode, n: int): predicate = ops.lt_op.as_expr(ex.free_var(offsets_id), ex.const(n)) plan_w_head = nodes.FilterNode(plan_w_offsets, predicate) # Finally, drop the offsets column - return nodes.ProjectionNode( - plan_w_head, tuple((ex.free_var(i), i) for i in node.schema.names) - ) + return nodes.SelectionNode(plan_w_head, tuple((i, i) for i in node.schema.names)) def generate_row_count_plan(node: nodes.BigFrameNode): diff --git a/bigframes/session/loader.py b/bigframes/session/loader.py index edfd57b965..924fddce12 100644 --- a/bigframes/session/loader.py +++ b/bigframes/session/loader.py @@ -18,6 +18,7 @@ import dataclasses import datetime import itertools +import os import typing from typing import Dict, Hashable, IO, Iterable, List, Optional, Sequence, Tuple, Union @@ -421,11 +422,16 @@ def _read_bigquery_load_job( load_job = self._bqclient.load_table_from_uri( filepath_or_buffer, table, job_config=job_config ) - else: + elif os.path.exists(filepath_or_buffer): # local file path with open(filepath_or_buffer, "rb") as source_file: load_job = self._bqclient.load_table_from_file( source_file, table, job_config=job_config ) + else: + raise NotImplementedError( + f"BigQuery engine only supports a local file path or GCS path. " + f"{constants.FEEDBACK_LINK}" + ) else: load_job = self._bqclient.load_table_from_file( filepath_or_buffer, table, job_config=job_config diff --git a/bigframes/session/planner.py b/bigframes/session/planner.py index 2a74521b43..bc640ec9fa 100644 --- a/bigframes/session/planner.py +++ b/bigframes/session/planner.py @@ -33,7 +33,7 @@ def session_aware_cache_plan( """ node_counts = traversals.count_nodes(session_forest) # These node types are cheap to re-compute, so it makes more sense to cache their children. - de_cachable_types = (nodes.FilterNode, nodes.ProjectionNode) + de_cachable_types = (nodes.FilterNode, nodes.ProjectionNode, nodes.SelectionNode) caching_target = cur_node = root caching_target_refs = node_counts.get(caching_target, 0) @@ -49,7 +49,15 @@ def session_aware_cache_plan( # Projection defines the variables that are used in the filter expressions, need to substitute variables with their scalar expressions # that instead reference variables in the child node. bindings = {name: expr for expr, name in cur_node.assignments} - filters = [i.bind_all_variables(bindings) for i in filters] + filters = [ + i.bind_variables(bindings, check_bind_all=False) for i in filters + ] + elif isinstance(cur_node, nodes.SelectionNode): + bindings = { + output: ex.free_var(input) + for input, output in cur_node.input_output_pairs + } + filters = [i.bind_variables(bindings) for i in filters] else: raise ValueError(f"Unexpected de-cached node: {cur_node}") diff --git a/bigframes/version.py b/bigframes/version.py index d5b4691b98..2c0c6e4d3a 100644 --- a/bigframes/version.py +++ b/bigframes/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "1.16.0" +__version__ = "1.17.0" diff --git a/docs/reference/bigframes.pandas/series.rst b/docs/reference/bigframes.pandas/series.rst index f14eb8e862..30cf851de7 100644 --- a/docs/reference/bigframes.pandas/series.rst +++ b/docs/reference/bigframes.pandas/series.rst @@ -35,6 +35,14 @@ String handling :inherited-members: :undoc-members: +List handling +^^^^^^^^^^^^^ + +.. automodule:: bigframes.operations.lists + :members: + :inherited-members: + :undoc-members: + Struct handling ^^^^^^^^^^^^^^^ diff --git a/notebooks/dataframes/pypi.ipynb b/notebooks/dataframes/pypi.ipynb index 3777e98d42..7b16412ff5 100644 --- a/notebooks/dataframes/pypi.ipynb +++ b/notebooks/dataframes/pypi.ipynb @@ -25,7 +25,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Analyzing Python dependencies with BigQuery DataFrames\n", + "# Analyzing package downloads from PyPI with BigQuery DataFrames\n", "\n", "In this notebook, you'll use the [PyPI public dataset](https://console.cloud.google.com/marketplace/product/gcp-public-data-pypi/pypi) and the [deps.dev public dataset](https://deps.dev/) to visualize Python package downloads for a package and its dependencies.\n", "\n", diff --git a/notebooks/dataframes/struct_and_array_dtypes.ipynb b/notebooks/dataframes/struct_and_array_dtypes.ipynb index 3bcdaf40f7..def65ee6ca 100644 --- a/notebooks/dataframes/struct_and_array_dtypes.ipynb +++ b/notebooks/dataframes/struct_and_array_dtypes.ipynb @@ -6,7 +6,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Copyright 2023 Google LLC\n", + "# Copyright 2024 Google LLC\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", @@ -212,6 +212,54 @@ "cell_type": "code", "execution_count": 7, "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0 3\n", + "1 2\n", + "2 4\n", + "Name: Scores, dtype: Int64" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Find the length of each array with list accessor\n", + "df['Scores'].list.len()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0 88\n", + "1 81\n", + "2 89\n", + "Name: Scores, dtype: Int64" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Find the second element in each array with list accessor\n", + "df['Scores'].list[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, "outputs": [ { "data": { @@ -228,7 +276,7 @@ "Name: Scores, dtype: Int64" ] }, - "execution_count": 7, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -243,7 +291,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -261,7 +309,7 @@ "Name: Scores, dtype: Float64" ] }, - "execution_count": 8, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -274,7 +322,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -286,7 +334,7 @@ "Name: Scores, dtype: list[pyarrow]" ] }, - "execution_count": 9, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -299,7 +347,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -361,7 +409,7 @@ "[3 rows x 3 columns]" ] }, - "execution_count": 10, + 
"execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -394,14 +442,14 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/usr/local/google/home/chelsealin/src/bigframes/venv/lib/python3.12/site-packages/google/cloud/bigquery/_pandas_helpers.py:570: UserWarning: Pyarrow could not determine the type of columns: bigframes_unnamed_index.\n", + "/usr/local/google/home/sycai/src/python-bigquery-dataframes/venv/lib/python3.11/site-packages/google/cloud/bigquery/_pandas_helpers.py:570: UserWarning: Pyarrow could not determine the type of columns: bigframes_unnamed_index.\n", " warnings.warn(\n" ] }, @@ -460,7 +508,7 @@ "[3 rows x 2 columns]" ] }, - "execution_count": 11, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -483,7 +531,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -494,7 +542,7 @@ "dtype: object" ] }, - "execution_count": 12, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -514,7 +562,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -525,7 +573,7 @@ "dtype: object" ] }, - "execution_count": 13, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -537,7 +585,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -549,7 +597,7 @@ "Name: City, dtype: string" ] }, - "execution_count": 14, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -562,7 +610,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -620,7 +668,7 @@ "[3 rows x 2 columns]" ] }, - "execution_count": 15, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -648,7 +696,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.1" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/noxfile.py b/noxfile.py index efe5a53082..5dbcdea583 100644 --- a/noxfile.py +++ b/noxfile.py @@ -16,10 +16,13 @@ from __future__ import absolute_import +import multiprocessing import os import pathlib import re import shutil +import time +import traceback from typing import Dict, List import warnings @@ -304,6 +307,7 @@ def run_system( print_duration=False, extra_pytest_options=(), timeout_seconds=900, + num_workers=20, ): """Run the system test suite.""" constraints_path = str( @@ -323,7 +327,7 @@ def run_system( pytest_cmd = [ "py.test", "--quiet", - "-n=20", + f"-n={num_workers}", # Any individual test taking longer than 15 mins will be terminated. f"--timeout={timeout_seconds}", # Log 20 slowest tests @@ -384,9 +388,15 @@ def doctest(session: nox.sessions.Session): run_system( session=session, prefix_name="doctest", - extra_pytest_options=("--doctest-modules", "third_party"), + extra_pytest_options=( + "--doctest-modules", + "third_party", + "--ignore", + "third_party/bigframes_vendored/ibis", + ), test_folder="bigframes", check_cov=True, + num_workers=5, ) @@ -747,6 +757,12 @@ def notebook(session: nox.Session): for nb in notebooks + list(notebooks_reg): assert os.path.exists(nb), nb + # Determine whether to enable multi-process mode based on the environment + # variable. 
If BENCHMARK_AND_PUBLISH is "true", it indicates we're running
+ # a benchmark, so we disable multi-process mode. If BENCHMARK_AND_PUBLISH
+ # is "false", we enable multi-process mode for faster execution.
+ multi_process_mode = os.getenv("BENCHMARK_AND_PUBLISH", "false") == "false"
+
try:
# Populate notebook parameters and make a backup so that the notebooks
# are runnable.
@@ -755,23 +771,65 @@ def notebook(session: nox.Session):
CURRENT_DIRECTORY / "scripts" / "notebooks_fill_params.py",
*notebooks,
)
+
+ # Shared flag using multiprocessing.Manager() to indicate if
+ # any process encounters an error. This flag may be updated
+ # across different processes.
+ error_flag = multiprocessing.Manager().Value("i", False)
+ processes = []
for notebook in notebooks:
- session.run(
+ args = (
"python",
"scripts/run_and_publish_benchmark.py",
"--notebook",
f"--benchmark-path={notebook}",
)
-
+ if multi_process_mode:
+ process = multiprocessing.Process(
+ target=_run_process,
+ args=(session, args, error_flag),
+ )
+ process.start()
+ processes.append(process)
+ # Adding a small delay between starting each
+ # process to avoid potential race conditions.
+ time.sleep(1)
+ else:
+ session.run(*args)
+
+ for process in processes:
+ process.join()
+
+ processes = []
for notebook, regions in notebooks_reg.items():
for region in regions:
- session.run(
+ args = (
"python",
"scripts/run_and_publish_benchmark.py",
"--notebook",
f"--benchmark-path={notebook}",
f"--region={region}",
)
+ if multi_process_mode:
+ process = multiprocessing.Process(
+ target=_run_process,
+ args=(session, args, error_flag),
+ )
+ process.start()
+ processes.append(process)
+ # Adding a small delay between starting each
+ # process to avoid potential race conditions.
+ time.sleep(1)
+ else:
+ session.run(*args)
+
+ for process in processes:
+ process.join()
+
+ # Check the shared error flag and raise an exception if any process
+ # reported an error.
+ if error_flag.value:
+ raise Exception("Errors occurred in one or more subprocesses.")
finally:
# Prevent our notebook changes from getting checked in to git
# accidentally.
@@ -788,6 +846,15 @@ def notebook(session: nox.Session):
)
+
+
+def _run_process(session: nox.Session, args, error_flag):
+ try:
+ session.run(*args)
+ except Exception:
+ traceback_str = traceback.format_exc()
+ print(traceback_str)
+ error_flag.value = True
+
+
@nox.session(python=DEFAULT_PYTHON_VERSION)
def benchmark(session: nox.Session):
session.install("-e", ".[all]")
diff --git a/tests/system/large/test_remote_function.py b/tests/system/large/test_remote_function.py
index d6eefc1e31..77ea4627ec 100644
--- a/tests/system/large/test_remote_function.py
+++ b/tests/system/large/test_remote_function.py
@@ -1603,6 +1603,13 @@ def serialize_row(row):
# bf_result.dtype is 'string[pyarrow]' while pd_result.dtype is 'object'
# , ignore this mismatch by using check_dtype=False.
pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False) + + # Let's make sure the read_gbq_function path works for this function + serialize_row_reuse = session.read_gbq_function( + serialize_row_remote.bigframes_remote_function, is_row_processor=True + ) + bf_result = scalars_df[columns].apply(serialize_row_reuse, axis=1).to_pandas() + pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False) finally: # clean up the gcp assets created for the remote function cleanup_remote_function_assets( @@ -2085,6 +2092,13 @@ def foo(x, y, z): pandas.testing.assert_series_equal( expected_result, bf_result, check_dtype=False, check_index_type=False ) + + # Let's make sure the read_gbq_function path works for this function + foo_reuse = session.read_gbq_function(foo.bigframes_remote_function) + bf_result = bf_df.apply(foo_reuse, axis=1).to_pandas() + pandas.testing.assert_series_equal( + expected_result, bf_result, check_dtype=False, check_index_type=False + ) finally: # clean up the gcp assets created for the remote function cleanup_remote_function_assets( diff --git a/tests/system/large/test_streaming.py b/tests/system/large/test_streaming.py index 391aec8533..e4992f8573 100644 --- a/tests/system/large/test_streaming.py +++ b/tests/system/large/test_streaming.py @@ -14,10 +14,13 @@ import time +import pytest + import bigframes import bigframes.streaming +@pytest.mark.flaky(retries=3, delay=10) def test_streaming_df_to_bigtable(session_load: bigframes.Session): # launch a continuous query job_id_prefix = "test_streaming_" @@ -51,6 +54,7 @@ def test_streaming_df_to_bigtable(session_load: bigframes.Session): query_job.cancel() +@pytest.mark.flaky(retries=3, delay=10) def test_streaming_df_to_pubsub(session_load: bigframes.Session): # launch a continuous query job_id_prefix = "test_streaming_pubsub_" diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py index 43e756019d..e3d2b51081 100644 --- a/tests/system/small/ml/test_llm.py +++ b/tests/system/small/ml/test_llm.py @@ -324,7 +324,7 @@ def test_create_load_text_embedding_generator_model( ("text-embedding-004", "text-multilingual-embedding-002"), ) @pytest.mark.flaky(retries=2) -def test_gemini_text_embedding_generator_predict_default_params_success( +def test_text_embedding_generator_predict_default_params_success( llm_text_df, model_name, session, bq_connection ): text_embedding_model = llm.TextEmbeddingGenerator( @@ -340,7 +340,13 @@ def test_gemini_text_embedding_generator_predict_default_params_success( @pytest.mark.parametrize( "model_name", - ("gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"), + ( + "gemini-pro", + "gemini-1.5-pro-preview-0514", + "gemini-1.5-flash-preview-0514", + "gemini-1.5-pro-001", + "gemini-1.5-flash-001", + ), ) def test_create_load_gemini_text_generator_model( dataset_id, model_name, session, bq_connection @@ -362,7 +368,13 @@ def test_create_load_gemini_text_generator_model( @pytest.mark.parametrize( "model_name", - ("gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"), + ( + "gemini-pro", + "gemini-1.5-pro-preview-0514", + "gemini-1.5-flash-preview-0514", + "gemini-1.5-pro-001", + "gemini-1.5-flash-001", + ), ) @pytest.mark.flaky(retries=2) def test_gemini_text_generator_predict_default_params_success( @@ -379,7 +391,13 @@ def test_gemini_text_generator_predict_default_params_success( @pytest.mark.parametrize( "model_name", - ("gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"), 
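The added assertions in `tests/system/large/test_remote_function.py` above show `session.read_gbq_function` re-attaching to functions that were already deployed, including row processors applied with `axis=1`. A hedged usage sketch (it needs an authenticated BigQuery session, and `my-project.my_dataset.my_row_func` is a placeholder for an existing remote function):

```python
import bigframes.pandas as bpd

session = bpd.get_global_session()
df = bpd.DataFrame({"City": ["SEATTLE", "New York"], "Rank": [1, 2]})

# Re-attach to a deployed single-parameter SQL UDF and apply it elementwise.
lower = session.read_gbq_function("bqutil.fn.cw_lower_case_ascii_only")
print(df["City"].apply(lower).to_pandas())

# Re-attach to a previously deployed row processor (placeholder name) and
# apply it across columns, one row at a time.
serialize_row = session.read_gbq_function(
    "my-project.my_dataset.my_row_func", is_row_processor=True
)
print(df.apply(serialize_row, axis=1).to_pandas())
```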
+ ( + "gemini-pro", + "gemini-1.5-pro-preview-0514", + "gemini-1.5-flash-preview-0514", + "gemini-1.5-pro-001", + "gemini-1.5-flash-001", + ), ) @pytest.mark.flaky(retries=2) def test_gemini_text_generator_predict_with_params_success( diff --git a/tests/system/small/operations/test_lists.py b/tests/system/small/operations/test_lists.py new file mode 100644 index 0000000000..7ecf79dc6a --- /dev/null +++ b/tests/system/small/operations/test_lists.py @@ -0,0 +1,83 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import packaging.version +import pandas as pd +import pyarrow as pa +import pytest + +import bigframes.pandas as bpd + +from ...utils import assert_series_equal + + +@pytest.mark.parametrize( + ("key"), + [ + pytest.param(0, id="int"), + pytest.param(slice(None, None, None), id="default_start_slice"), + pytest.param(slice(0, None, 1), id="default_stop_slice"), + pytest.param(slice(0, 2, None), id="default_step_slice"), + ], +) +def test_getitem(key): + if packaging.version.Version(pd.__version__) < packaging.version.Version("2.2.0"): + pytest.skip( + "https://pandas.pydata.org/docs/whatsnew/v2.2.0.html#series-list-accessor-for-pyarrow-list-data" + ) + data = [[1], [2, 3], [4, 5, 6]] + s = bpd.Series(data, dtype=pd.ArrowDtype(pa.list_(pa.int64()))) + pd_s = pd.Series(data, dtype=pd.ArrowDtype(pa.list_(pa.int64()))) + + bf_result = s.list[key].to_pandas() + pd_result = pd_s.list[key] + + assert_series_equal(pd_result, bf_result, check_dtype=False, check_index_type=False) + + +@pytest.mark.parametrize( + ("key", "expectation"), + [ + # Negative index + (-1, pytest.raises(NotImplementedError)), + # Slice with negative start + (slice(-1, None, None), pytest.raises(NotImplementedError)), + # Slice with negatiev end + (slice(0, -1, None), pytest.raises(NotImplementedError)), + # Slice with step not equal to 1 + (slice(0, 2, 2), pytest.raises(NotImplementedError)), + ], +) +def test_getitem_notsupported(key, expectation): + data = [[1], [2, 3], [4, 5, 6]] + s = bpd.Series(data, dtype=pd.ArrowDtype(pa.list_(pa.int64()))) + + with expectation as e: + assert s.list[key] == e + + +def test_len(): + if packaging.version.Version(pd.__version__) < packaging.version.Version("2.2.0"): + pytest.skip( + "https://pandas.pydata.org/docs/whatsnew/v2.2.0.html#series-list-accessor-for-pyarrow-list-data" + ) + data = [[], [1], [1, 2], [1, 2, 3]] + s = bpd.Series(data, dtype=pd.ArrowDtype(pa.list_(pa.int64()))) + pd_s = pd.Series(data, dtype=pd.ArrowDtype(pa.list_(pa.int64()))) + + bf_result = s.list.len().to_pandas() + pd_result = pd_s.list.len() + + assert_series_equal(pd_result, bf_result, check_dtype=False, check_index_type=False) diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index ddcf044911..f51b597650 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -4657,6 +4657,17 @@ def test_to_gbq_and_create_dataset(session, scalars_df_index, dataset_id_not_cre assert not 
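The new `tests/system/small/operations/test_lists.py` above exercises the list accessor on bigframes Series backed by pyarrow list data (pandas 2.2+ provides the equivalent accessor that results are compared against). A short usage sketch, assuming an authenticated BigQuery session:

```python
import pandas as pd
import pyarrow as pa

import bigframes.pandas as bpd

data = [[1], [2, 3], [4, 5, 6]]
s = bpd.Series(data, dtype=pd.ArrowDtype(pa.list_(pa.int64())))

# First element of each list; only non-negative indices are supported.
first = s.list[0]

# Slices with step 1 and non-negative bounds are supported as well.
head = s.list[0:2]

# Length of each list.
lengths = s.list.len()

print(first.to_pandas())
print(lengths.to_pandas())
```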
loaded_scalars_df_index.empty +def test_to_gbq_table_labels(scalars_df_index): + destination_table = "bigframes-dev.bigframes_tests_sys.table_labels" + result_table = scalars_df_index.to_gbq( + destination_table, labels={"test": "labels"}, if_exists="replace" + ) + client = scalars_df_index._session.bqclient + table = client.get_table(result_table) + assert table.labels + assert table.labels["test"] == "labels" + + @pytest.mark.parametrize( ("col_names", "ignore_index"), [ diff --git a/tests/system/small/test_remote_function.py b/tests/system/small/test_remote_function.py index db573efa40..b000354ed4 100644 --- a/tests/system/small/test_remote_function.py +++ b/tests/system/small/test_remote_function.py @@ -671,12 +671,19 @@ def square1(x): @pytest.mark.flaky(retries=2, delay=120) -def test_read_gbq_function_runs_existing_udf(session, bigquery_client, dataset_id): +def test_read_gbq_function_runs_existing_udf(session): func = session.read_gbq_function("bqutil.fn.cw_lower_case_ascii_only") got = func("AURÉLIE") assert got == "aurÉlie" +@pytest.mark.flaky(retries=2, delay=120) +def test_read_gbq_function_runs_existing_udf_4_params(session): + func = session.read_gbq_function("bqutil.fn.cw_instr4") + got = func("TestStr123456Str", "Str", 1, 2) + assert got == 14 + + @pytest.mark.flaky(retries=2, delay=120) def test_read_gbq_function_reads_udfs(session, bigquery_client, dataset_id): dataset_ref = bigquery.DatasetReference.from_string(dataset_id) diff --git a/tests/system/small/test_series.py b/tests/system/small/test_series.py index 7458187a82..9a6783ee5c 100644 --- a/tests/system/small/test_series.py +++ b/tests/system/small/test_series.py @@ -3080,6 +3080,16 @@ def test_astype(scalars_df_index, scalars_pandas_df_index, column, to_type): pd.testing.assert_series_equal(bf_result, pd_result) +@skip_legacy_pandas +def test_astype_numeric_to_int(scalars_df_index, scalars_pandas_df_index): + column = "numeric_col" + to_type = "Int64" + bf_result = scalars_df_index[column].astype(to_type).to_pandas() + # Round to the nearest whole number to avoid TypeError + pd_result = scalars_pandas_df_index[column].round(0).astype(to_type) + pd.testing.assert_series_equal(bf_result, pd_result) + + @pytest.mark.parametrize( ("column", "to_type"), [ diff --git a/tests/system/small/test_session.py b/tests/system/small/test_session.py index 5b5db74ea6..ed3e38e6f8 100644 --- a/tests/system/small/test_session.py +++ b/tests/system/small/test_session.py @@ -1036,6 +1036,25 @@ def test_read_csv_local_w_usecols(session, scalars_pandas_df_index, engine): assert len(df.columns) == 1 +@pytest.mark.parametrize( + "engine", + [ + pytest.param( + "bigquery", + id="bq_engine", + marks=pytest.mark.xfail( + raises=NotImplementedError, + ), + ), + pytest.param(None, id="default_engine"), + ], +) +def test_read_csv_others(session, engine): + uri = "https://raw.githubusercontent.com/googleapis/python-bigquery-dataframes/main/tests/data/people.csv" + df = session.read_csv(uri, engine=engine) + assert len(df.columns) == 3 + + @pytest.mark.parametrize( "engine", [ diff --git a/tests/unit/test_constants.py b/tests/unit/test_constants.py new file mode 100644 index 0000000000..aabc09c388 --- /dev/null +++ b/tests/unit/test_constants.py @@ -0,0 +1,20 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
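`test_to_gbq_table_labels` above covers the new `labels` argument of `DataFrame.to_gbq`; the labels land on the destination table's metadata. A brief sketch (the destination table id is a placeholder and the call needs write access to the dataset):

```python
import bigframes.pandas as bpd

df = bpd.DataFrame({"City": ["Seattle", "New York"], "Rank": [1, 2]})

destination = "my-project.my_dataset.city_ranks"  # placeholder table id
result_table = df.to_gbq(
    destination,
    labels={"team": "bigframes", "env": "test"},
    if_exists="replace",
)

# The labels are attached to the destination table's metadata.
table = df._session.bqclient.get_table(result_table)
print(table.labels)
```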
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import bigframes.constants as constants + + +def test_feedback_link_includes_version(): + assert len(constants.BF_VERSION) > 0 + assert constants.BF_VERSION in constants.FEEDBACK_LINK diff --git a/tests/unit/test_formatting_helpers.py b/tests/unit/test_formatting_helpers.py index 9db9b372e2..3c966752c9 100644 --- a/tests/unit/test_formatting_helpers.py +++ b/tests/unit/test_formatting_helpers.py @@ -44,3 +44,14 @@ def test_wait_for_job_error_includes_feedback_link(): cap_exc.match("Test message 123.") cap_exc.match(constants.FEEDBACK_LINK) + + +def test_wait_for_job_error_includes_version(): + mock_job = mock.create_autospec(bigquery.LoadJob) + mock_job.result.side_effect = api_core_exceptions.BadRequest("Test message 123.") + + with pytest.raises(api_core_exceptions.BadRequest) as cap_exc: + formatting_helpers.wait_for_job(mock_job) + + cap_exc.match("Test message 123.") + cap_exc.match(constants.BF_VERSION) diff --git a/tests/unit/test_planner.py b/tests/unit/test_planner.py index 2e276d0f1a..84dd05ddaa 100644 --- a/tests/unit/test_planner.py +++ b/tests/unit/test_planner.py @@ -46,8 +46,8 @@ def test_session_aware_caching_project_filter(): """ Test that if a node is filtered by a column, the node is cached pre-filter and clustered by the filter column. """ - session_objects = [LEAF, LEAF.assign_constant("col_c", 4, pd.Int64Dtype())] - target = LEAF.assign_constant("col_c", 4, pd.Int64Dtype()).filter( + session_objects = [LEAF, LEAF.create_constant("col_c", 4, pd.Int64Dtype())] + target = LEAF.create_constant("col_c", 4, pd.Int64Dtype()).filter( ops.gt_op.as_expr("col_a", ex.const(3)) ) result, cluster_cols = planner.session_aware_cache_plan( @@ -61,14 +61,14 @@ def test_session_aware_caching_project_multi_filter(): """ Test that if a node is filtered by multiple columns, all of them are in the cluster cols """ - session_objects = [LEAF, LEAF.assign_constant("col_c", 4, pd.Int64Dtype())] + session_objects = [LEAF, LEAF.create_constant("col_c", 4, pd.Int64Dtype())] predicate_1a = ops.gt_op.as_expr("col_a", ex.const(3)) predicate_1b = ops.lt_op.as_expr("col_a", ex.const(55)) predicate_1 = ops.and_op.as_expr(predicate_1a, predicate_1b) predicate_3 = ops.eq_op.as_expr("col_b", ex.const(1)) target = ( LEAF.filter(predicate_1) - .assign_constant("col_c", 4, pd.Int64Dtype()) + .create_constant("col_c", 4, pd.Int64Dtype()) .filter(predicate_3) ) result, cluster_cols = planner.session_aware_cache_plan( @@ -84,8 +84,8 @@ def test_session_aware_caching_unusable_filter(): Most filters with multiple column references cannot be used for scan pruning, as they cannot be converted to fixed value ranges. 
""" - session_objects = [LEAF, LEAF.assign_constant("col_c", 4, pd.Int64Dtype())] - target = LEAF.assign_constant("col_c", 4, pd.Int64Dtype()).filter( + session_objects = [LEAF, LEAF.create_constant("col_c", 4, pd.Int64Dtype())] + target = LEAF.create_constant("col_c", 4, pd.Int64Dtype()).filter( ops.gt_op.as_expr("col_a", "col_b") ) result, cluster_cols = planner.session_aware_cache_plan( @@ -101,12 +101,12 @@ def test_session_aware_caching_fork_after_window_op(): Windowing is expensive, so caching should always compute the window function, in order to avoid later recomputation. """ - other = LEAF.promote_offsets("offsets_col").assign_constant( + other = LEAF.promote_offsets("offsets_col").create_constant( "col_d", 5, pd.Int64Dtype() ) target = ( LEAF.promote_offsets("offsets_col") - .assign_constant("col_c", 4, pd.Int64Dtype()) + .create_constant("col_c", 4, pd.Int64Dtype()) .filter( ops.eq_op.as_expr("col_a", ops.add_op.as_expr(ex.const(4), ex.const(3))) ) diff --git a/third_party/bigframes_vendored/constants.py b/third_party/bigframes_vendored/constants.py index 0d4a7d1df6..91084b38f9 100644 --- a/third_party/bigframes_vendored/constants.py +++ b/third_party/bigframes_vendored/constants.py @@ -16,10 +16,14 @@ This module should not depend on any others in the package. """ +import bigframes.version + +BF_VERSION = bigframes.version.__version__ FEEDBACK_LINK = ( "Share your usecase with the BigQuery DataFrames team at the " "https://bit.ly/bigframes-feedback survey." + f"You are currently running BigFrames version {BF_VERSION}" ) ABSTRACT_METHOD_ERROR_MESSAGE = ( diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/backend.py b/third_party/bigframes_vendored/ibis/backends/bigquery/backend.py new file mode 100644 index 0000000000..f917ef950d --- /dev/null +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/backend.py @@ -0,0 +1,1259 @@ +# Contains code from https://github.com/ibis-project/ibis/blob/main/ibis/backends/bigquery/__init__.py + +"""BigQuery public API.""" + +from __future__ import annotations + +import concurrent.futures +import contextlib +import glob +import os +import re +from typing import Any, Optional, TYPE_CHECKING + +from bigframes_vendored.ibis.backends.bigquery.datatypes import BigQueryType +import google.api_core.exceptions +import google.auth.credentials +import google.cloud.bigquery as bq +import google.cloud.bigquery_storage_v1 as bqstorage +import ibis +from ibis import util +from ibis.backends import CanCreateDatabase, CanCreateSchema +from ibis.backends.bigquery.client import ( + bigquery_param, + parse_project_and_dataset, + rename_partitioned_column, + schema_from_bigquery_table, +) +from ibis.backends.bigquery.datatypes import BigQuerySchema +from ibis.backends.sql import SQLBackend +import ibis.backends.sql.compilers as sc +import ibis.common.exceptions as com +import ibis.expr.operations as ops +import ibis.expr.schema as sch +import ibis.expr.types as ir +import pydata_google_auth +from pydata_google_auth import cache +import sqlglot as sg +import sqlglot.expressions as sge + +if TYPE_CHECKING: + from collections.abc import Iterable, Mapping + from pathlib import Path + from urllib.parse import ParseResult + + import pandas as pd + import polars as pl + import pyarrow as pa + + +SCOPES = ["https://www.googleapis.com/auth/bigquery"] +EXTERNAL_DATA_SCOPES = [ + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/drive", +] +CLIENT_ID = 
"546535678771-gvffde27nd83kfl6qbrnletqvkdmsese.apps.googleusercontent.com" +CLIENT_SECRET = "iU5ohAF2qcqrujegE3hQ1cPt" # noqa: S105 + + +def _create_user_agent(application_name: str) -> str: + user_agent = [] + + if application_name: + user_agent.append(application_name) + + user_agent_default_template = f"ibis/{ibis.__version__}" + user_agent.append(user_agent_default_template) + + return " ".join(user_agent) + + +def _create_client_info(application_name): + from google.api_core.client_info import ClientInfo + + return ClientInfo(user_agent=_create_user_agent(application_name)) + + +def _create_client_info_gapic(application_name): + from google.api_core.gapic_v1.client_info import ClientInfo + + return ClientInfo(user_agent=_create_user_agent(application_name)) + + +_MEMTABLE_PATTERN = re.compile( + r"^_?ibis_(?:[A-Za-z_][A-Za-z_0-9]*)_memtable_[a-z0-9]{26}$" +) + + +def _qualify_memtable( + node: sge.Expression, *, dataset: str | None, project: str | None +) -> sge.Expression: + """Add a BigQuery dataset and project to memtable references.""" + if isinstance(node, sge.Table) and _MEMTABLE_PATTERN.match(node.name) is not None: + node.args["db"] = dataset + node.args["catalog"] = project + # make sure to quote table location + node = _force_quote_table(node) + return node + + +def _remove_null_ordering_from_unsupported_window( + node: sge.Expression, +) -> sge.Expression: + """Remove null ordering in window frame clauses not supported by BigQuery. + + BigQuery has only partial support for NULL FIRST/LAST in RANGE windows so + we remove it from any window frame clause that doesn't support it. + + Here's the support matrix: + + ✅ sum(x) over (order by y desc nulls last) + 🚫 sum(x) over (order by y asc nulls last) + ✅ sum(x) over (order by y asc nulls first) + 🚫 sum(x) over (order by y desc nulls first) + """ + if isinstance(node, sge.Window): + order = node.args.get("order") + if order is not None: + for key in order.args["expressions"]: + kargs = key.args + if kargs.get("desc") is True and kargs.get("nulls_first", False): + kargs["nulls_first"] = False + elif kargs.get("desc") is False and not kargs.setdefault( + "nulls_first", True + ): + kargs["nulls_first"] = True + return node + + +def _force_quote_table(table: sge.Table) -> sge.Table: + """Force quote all the parts of a bigquery path. + + https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#identifiers + + my-table is OK, but not mydataset.my-table + + mytable-287 is OK, but not mytable-287a + + Just quote everything. 
+ """ + for key in ("this", "db", "catalog"): + if (val := table.args[key]) is not None: + if isinstance(val, sg.exp.Identifier) and not val.quoted: + val.args["quoted"] = True + else: + table.args[key] = sg.to_identifier(val, quoted=True) + return table + + +class Backend(SQLBackend, CanCreateDatabase, CanCreateSchema): + name = "bigquery" + compiler = sc.bigquery.compiler + supports_in_memory_tables = True + supports_python_udfs = False + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.__session_dataset: bq.DatasetReference | None = None + + @property + def _session_dataset(self): + if self.__session_dataset is None: + self.__session_dataset = self._make_session() + return self.__session_dataset + + def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: + raw_name = op.name + + session_dataset = self._session_dataset + project = session_dataset.project + dataset = session_dataset.dataset_id + + table_ref = bq.TableReference(session_dataset, raw_name) + try: + self.client.get_table(table_ref) + except google.api_core.exceptions.NotFound: + table_id = sg.table( + raw_name, db=dataset, catalog=project, quoted=False + ).sql(dialect=self.name) + bq_schema = BigQuerySchema.from_ibis(op.schema) + load_job = self.client.load_table_from_dataframe( + op.data.to_frame(), + table_id, + job_config=bq.LoadJobConfig( + # fail if the table already exists and contains data + write_disposition=bq.WriteDisposition.WRITE_EMPTY, + schema=bq_schema, + ), + ) + load_job.result() + + def _read_file( + self, + path: str | Path, + *, + table_name: str | None = None, + job_config: bq.LoadJobConfig, + ) -> ir.Table: + self._make_session() + + if table_name is None: + table_name = util.gen_name(f"bq_read_{job_config.source_format}") + + table_ref = self._session_dataset.table(table_name) + + database = self._session_dataset.dataset_id + catalog = self._session_dataset.project + + # drop the table if it exists + # + # we could do this with write_disposition = WRITE_TRUNCATE but then the + # concurrent append jobs aren't possible + # + # dropping the table first means all write_dispositions can be + # WRITE_APPEND + self.drop_table(table_name, database=(catalog, database), force=True) + + if os.path.isdir(path): + raise NotImplementedError("Reading from a directory is not supported.") + elif str(path).startswith("gs://"): + load_job = self.client.load_table_from_uri( + path, table_ref, job_config=job_config + ) + load_job.result() + else: + + def load(file: str) -> None: + with open(file, mode="rb") as f: + load_job = self.client.load_table_from_file( + f, table_ref, job_config=job_config + ) + load_job.result() + + job_config.write_disposition = bq.WriteDisposition.WRITE_APPEND + + with concurrent.futures.ThreadPoolExecutor() as executor: + for fut in concurrent.futures.as_completed( + executor.submit(load, file) for file in glob.glob(str(path)) + ): + fut.result() + + return self.table(table_name, database=(catalog, database)) + + def read_parquet( + self, path: str | Path, table_name: str | None = None, **kwargs: Any + ): + """Read Parquet data into a BigQuery table. + + Parameters + ---------- + path + Path to a Parquet file on GCS or the local filesystem. Globs are supported. + table_name + Optional table name + kwargs + Additional keyword arguments passed to `google.cloud.bigquery.LoadJobConfig`. 
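`_read_file` above loads GCS URIs with `load_table_from_uri` and streams each locally matched path into a separate `WRITE_APPEND` load job. A minimal standalone sketch of the local-file variant using `google-cloud-bigquery` directly (project, dataset, and file path are placeholders; the CSV settings mirror the style of job config that `read_csv` builds below):

```python
import google.cloud.bigquery as bq

client = bq.Client(project="my-project")  # placeholder project

job_config = bq.LoadJobConfig(
    source_format=bq.SourceFormat.CSV,
    autodetect=True,        # infer the schema from the file
    skip_leading_rows=1,    # treat the first row as a header
    write_disposition=bq.WriteDisposition.WRITE_APPEND,
)

table_ref = bq.DatasetReference("my-project", "my_dataset").table("people")

with open("people.csv", mode="rb") as f:
    load_job = client.load_table_from_file(f, table_ref, job_config=job_config)

load_job.result()  # block until the load completes
```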
+ + Returns + ------- + Table + An Ibis table expression + + """ + return self._read_file( + path, + table_name=table_name, + job_config=bq.LoadJobConfig( + source_format=bq.SourceFormat.PARQUET, **kwargs + ), + ) + + def read_csv( + self, path: str | Path, table_name: str | None = None, **kwargs: Any + ) -> ir.Table: + """Read CSV data into a BigQuery table. + + Parameters + ---------- + path + Path to a CSV file on GCS or the local filesystem. Globs are supported. + table_name + Optional table name + kwargs + Additional keyword arguments passed to + `google.cloud.bigquery.LoadJobConfig`. + + Returns + ------- + Table + An Ibis table expression + + """ + job_config = bq.LoadJobConfig( + source_format=bq.SourceFormat.CSV, + autodetect=True, + skip_leading_rows=1, + **kwargs, + ) + return self._read_file(path, table_name=table_name, job_config=job_config) + + def read_json( + self, path: str | Path, table_name: str | None = None, **kwargs: Any + ) -> ir.Table: + """Read newline-delimited JSON data into a BigQuery table. + + Parameters + ---------- + path + Path to a newline-delimited JSON file on GCS or the local + filesystem. Globs are supported. + table_name + Optional table name + kwargs + Additional keyword arguments passed to + `google.cloud.bigquery.LoadJobConfig`. + + Returns + ------- + Table + An Ibis table expression + + """ + job_config = bq.LoadJobConfig( + source_format=bq.SourceFormat.NEWLINE_DELIMITED_JSON, + autodetect=True, + **kwargs, + ) + return self._read_file(path, table_name=table_name, job_config=job_config) + + def _from_url(self, url: ParseResult, **kwargs): + return self.connect( + project_id=url.netloc or kwargs.get("project_id", [""])[0], + dataset_id=url.path[1:] or kwargs.get("dataset_id", [""])[0], + **kwargs, + ) + + def do_connect( + self, + project_id: str | None = None, + dataset_id: str = "", + credentials: google.auth.credentials.Credentials | None = None, + application_name: str | None = None, + auth_local_webserver: bool = True, + auth_external_data: bool = False, + auth_cache: str = "default", + partition_column: str | None = "PARTITIONTIME", + client: bq.Client | None = None, + storage_client: bqstorage.BigQueryReadClient | None = None, + location: str | None = None, + ) -> Backend: + """Create a `Backend` for use with Ibis. + + Parameters + ---------- + project_id + A BigQuery project id. + dataset_id + A dataset id that lives inside of the project indicated by + `project_id`. + credentials + Optional credentials. + application_name + A string identifying your application to Google API endpoints. + auth_local_webserver + Use a local webserver for the user authentication. Binds a + webserver to an open port on localhost between 8080 and 8089, + inclusive, to receive authentication token. If not set, defaults to + False, which requests a token via the console. + auth_external_data + Authenticate using additional scopes required to `query external + data sources + `_, + such as Google Sheets, files in Google Cloud Storage, or files in + Google Drive. If not set, defaults to False, which requests the + default BigQuery scopes. + auth_cache + Selects the behavior of the credentials cache. + + `'default'`` + Reads credentials from disk if available, otherwise + authenticates and caches credentials to disk. + + `'reauth'`` + Authenticates and caches credentials to disk. + + `'none'`` + Authenticates and does **not** cache credentials. + + Defaults to `'default'`. 
+ partition_column + Identifier to use instead of default `_PARTITIONTIME` partition + column. Defaults to `'PARTITIONTIME'`. + client + A `Client` from the `google.cloud.bigquery` package. If not + set, one is created using the `project_id` and `credentials`. + storage_client + A `BigQueryReadClient` from the + `google.cloud.bigquery_storage_v1` package. If not set, one is + created using the `project_id` and `credentials`. + location + Default location for BigQuery objects. + + Returns + ------- + Backend + An instance of the BigQuery backend. + + """ + default_project_id = client.project if client is not None else project_id + + # Only need `credentials` to create a `client` and + # `storage_client`, so only one or the other needs to be set. + if (client is None or storage_client is None) and credentials is None: + scopes = SCOPES + if auth_external_data: + scopes = EXTERNAL_DATA_SCOPES + + if auth_cache == "default": + credentials_cache = cache.ReadWriteCredentialsCache( + filename="ibis.json" + ) + elif auth_cache == "reauth": + credentials_cache = cache.WriteOnlyCredentialsCache( + filename="ibis.json" + ) + elif auth_cache == "none": + credentials_cache = cache.NOOP + else: + raise ValueError( + f"Got unexpected value for auth_cache = '{auth_cache}'. " + "Expected one of 'default', 'reauth', or 'none'." + ) + + credentials, default_project_id = pydata_google_auth.default( + scopes, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + credentials_cache=credentials_cache, + use_local_webserver=auth_local_webserver, + ) + + project_id = project_id or default_project_id + + ( + self.data_project, + self.billing_project, + self.dataset, + ) = parse_project_and_dataset(project_id, dataset_id) + + if client is not None: + self.client = client + else: + self.client = bq.Client( + project=self.billing_project, + credentials=credentials, + client_info=_create_client_info(application_name), + location=location, + ) + + if self.client.default_query_job_config is None: + self.client.default_query_job_config = bq.QueryJobConfig() + + self.client.default_query_job_config.use_legacy_sql = False + self.client.default_query_job_config.allow_large_results = True + + if storage_client is not None: + self.storage_client = storage_client + else: + self.storage_client = bqstorage.BigQueryReadClient( + credentials=credentials, + client_info=_create_client_info_gapic(application_name), + ) + + self.partition_column = partition_column + + @util.experimental + @classmethod + def from_connection( + cls, + client: bq.Client, + partition_column: str | None = "PARTITIONTIME", + storage_client: bqstorage.BigQueryReadClient | None = None, + dataset_id: str = "", + ) -> Backend: + """Create a BigQuery `Backend` from an existing `Client`. + + Parameters + ---------- + client + A `Client` from the `google.cloud.bigquery` package. + partition_column + Identifier to use instead of default `_PARTITIONTIME` partition + column. Defaults to `'PARTITIONTIME'`. + storage_client + A `BigQueryReadClient` from the `google.cloud.bigquery_storage_v1` + package. + dataset_id + A dataset id that lives inside of the project attached to `client`. 
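`from_connection` above wraps an existing `google.cloud.bigquery.Client` instead of letting the backend construct its own. A hedged sketch of that entry point (it assumes the vendored module is importable as `bigframes_vendored.ibis.backends.bigquery.backend`, and the project and dataset ids are placeholders):

```python
import google.cloud.bigquery as bq

from bigframes_vendored.ibis.backends.bigquery.backend import Backend

# Reuse a client that already carries the desired credentials and location.
client = bq.Client(project="my-project")

con = Backend.from_connection(client, dataset_id="my_dataset")
print(con.list_tables())

expr = con.table("people").limit(5)
print(con.compile(expr))
```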
+ """ + return ibis.bigquery.connect( + client=client, + partition_column=partition_column, + storage_client=storage_client, + dataset_id=dataset_id, + ) + + def disconnect(self) -> None: + self.client.close() + + def _parse_project_and_dataset(self, dataset) -> tuple[str, str]: + if isinstance(dataset, sge.Table): + dataset = dataset.sql(self.dialect) + if not dataset and not self.dataset: + raise ValueError("Unable to determine BigQuery dataset.") + project, _, dataset = parse_project_and_dataset( + self.billing_project, + dataset or f"{self.data_project}.{self.dataset}", + ) + return project, dataset + + @property + def project_id(self): + return self.data_project + + @property + def dataset_id(self): + return self.dataset + + def create_database( + self, + name: str, + catalog: str | None = None, + force: bool = False, + collate: str | None = None, + **options: Any, + ) -> None: + properties = [ + sge.Property(this=sg.to_identifier(name), value=sge.convert(value)) + for name, value in (options or {}).items() + ] + + if collate is not None: + properties.append( + sge.CollateProperty(this=sge.convert(collate), default=True) + ) + + stmt = sge.Create( + kind="SCHEMA", + this=sg.table(name, db=catalog), + exists=force, + properties=sge.Properties(expressions=properties), + ) + + self.raw_sql(stmt.sql(self.name)) + + def drop_database( + self, + name: str, + catalog: str | None = None, + force: bool = False, + cascade: bool = False, + ) -> None: + """Drop a BigQuery dataset.""" + stmt = sge.Drop( + kind="SCHEMA", + this=sg.table(name, db=catalog), + exists=force, + cascade=cascade, + ) + + self.raw_sql(stmt.sql(self.name)) + + def table( + self, name: str, database: str | None = None, schema: str | None = None + ) -> ir.Table: + table_loc = self._warn_and_create_table_loc(database, schema) + table = sg.parse_one(f"`{name}`", into=sge.Table, read=self.name) + + # Bigquery, unlike other backends, had existing support for specifying + # table hierarchy in the table name, e.g. con.table("dataset.table_name") + # so here we have an extra layer of disambiguation to handle. 
+ + # Default `catalog` to None unless we've parsed it out of the database/schema kwargs + # Raise if there are path specifications in both the name and as a kwarg + catalog = table_loc.args["catalog"] # args access will return None, not '' + if table.catalog: + if table_loc.catalog: + raise com.IbisInputError( + "Cannot specify catalog both in the table name and as an argument" + ) + else: + catalog = table.catalog + + # Default `db` to None unless we've parsed it out of the database/schema kwargs + db = table_loc.args["db"] # args access will return None, not '' + if table.db: + if table_loc.db: + raise com.IbisInputError( + "Cannot specify database both in the table name and as an argument" + ) + else: + db = table.db + + database = ( + sg.table(None, db=db, catalog=catalog, quoted=False).sql(dialect=self.name) + or None + ) + + project, dataset = self._parse_project_and_dataset(database) + + bq_table = self.client.get_table( + bq.TableReference( + bq.DatasetReference(project=project, dataset_id=dataset), + table.name, + ) + ) + + node = ops.DatabaseTable( + table.name, + # https://cloud.google.com/bigquery/docs/querying-wildcard-tables#filtering_selected_tables_using_table_suffix + schema=schema_from_bigquery_table(bq_table, wildcard=table.name[-1] == "*"), + source=self, + namespace=ops.Namespace(database=dataset, catalog=project), + ) + table_expr = node.to_expr() + return rename_partitioned_column(table_expr, bq_table, self.partition_column) + + def _make_session(self) -> tuple[str, str]: + if (client := getattr(self, "client", None)) is not None: + job_config = bq.QueryJobConfig(use_query_cache=False) + query = client.query( + "SELECT 1", job_config=job_config, project=self.billing_project + ) + query.result() + + return bq.DatasetReference( + project=query.destination.project, + dataset_id=query.destination.dataset_id, + ) + return None + + def _get_schema_using_query(self, query: str) -> sch.Schema: + job = self.client.query( + query, + job_config=bq.QueryJobConfig(dry_run=True, use_query_cache=False), + project=self.billing_project, + ) + return BigQuerySchema.to_ibis(job.schema) + + def raw_sql(self, query: str, params=None, page_size: int | None = None): + query_parameters = [ + bigquery_param( + param.type(), + value, + ( + param.get_name() + if not isinstance(op := param.op(), ops.Alias) + else op.arg.name + ), + ) + for param, value in (params or {}).items() + ] + with contextlib.suppress(AttributeError): + query = query.sql(self.dialect) + + job_config = bq.job.QueryJobConfig(query_parameters=query_parameters or []) + return self.client.query_and_wait( + query, + job_config=job_config, + project=self.billing_project, + page_size=page_size, + ) + + @property + def current_catalog(self) -> str: + return self.data_project + + @property + def current_database(self) -> str | None: + return self.dataset + + def compile( + self, + expr: ir.Expr, + limit: str | None = None, + params=None, + pretty: bool = True, + **kwargs: Any, + ): + """Compile an Ibis expression to a SQL string.""" + session_dataset = self._session_dataset + query = self.compiler.to_sqlglot( + expr, + limit=limit, + params=params, + session_dataset_id=getattr(session_dataset, "dataset_id", None), + session_project=getattr(session_dataset, "project", None), + **kwargs, + ) + queries = query if isinstance(query, list) else [query] + sql = ";\n".join(query.sql(self.dialect, pretty=pretty) for query in queries) + self._log(sql) + return sql + + def execute(self, expr, params=None, limit="default", **kwargs): + 
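`_get_schema_using_query` above infers a result schema without executing anything by submitting the query as a dry run. The same trick in plain `google-cloud-bigquery` terms (the project id is a placeholder):

```python
import google.cloud.bigquery as bq

client = bq.Client(project="my-project")  # placeholder project

job = client.query(
    "SELECT 1 AS x, 'a' AS y",
    job_config=bq.QueryJobConfig(dry_run=True, use_query_cache=False),
)

# A dry-run job never executes, but it still reports the output schema
# (and the bytes a real run would process).
for field in job.schema:
    print(field.name, field.field_type)
print(job.total_bytes_processed)
```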
"""Compile and execute the given Ibis expression. + + Compile and execute Ibis expression using this backend client + interface, returning results in-memory in the appropriate object type + + Parameters + ---------- + expr + Ibis expression to execute + limit + Retrieve at most this number of values/rows. Overrides any limit + already set on the expression. + params + Query parameters + kwargs + Extra arguments specific to the backend + + Returns + ------- + pd.DataFrame | pd.Series | scalar + Output from execution + + """ + from ibis.backends.bigquery.converter import BigQueryPandasData + + self._run_pre_execute_hooks(expr) + + schema = expr.as_table().schema() - ibis.schema({"_TABLE_SUFFIX": "string"}) + + sql = self.compile(expr, limit=limit, params=params, **kwargs) + self._log(sql) + query = self.raw_sql(sql, params=params, **kwargs) + + arrow_t = query.to_arrow( + progress_bar_type=None, bqstorage_client=self.storage_client + ) + + result = BigQueryPandasData.convert_table( + arrow_t.to_pandas(timestamp_as_object=True), schema + ) + + return expr.__pandas_result__(result, schema=schema) + + def insert( + self, + table_name: str, + obj: pd.DataFrame | ir.Table | list | dict, + schema: str | None = None, + database: str | None = None, + overwrite: bool = False, + ): + """Insert data into a table. + + Parameters + ---------- + table_name + The name of the table to which data needs will be inserted + obj + The source data or expression to insert + schema + The name of the schema that the table is located in + database + Name of the attached database that the table is located in. + overwrite + If `True` then replace existing contents of table + + """ + table_loc = self._warn_and_create_table_loc(database, schema) + catalog, db = self._to_catalog_db_tuple(table_loc) + if catalog is None: + catalog = self.current_catalog + if db is None: + db = self.current_database + + return super().insert( + table_name, + obj, + database=(catalog, db), + overwrite=overwrite, + ) + + def to_pyarrow( + self, + expr: ir.Expr, + *, + params: Mapping[ir.Scalar, Any] | None = None, + limit: int | str | None = None, + **kwargs: Any, + ) -> pa.Table: + self._import_pyarrow() + self._register_in_memory_tables(expr) + sql = self.compile(expr, limit=limit, params=params, **kwargs) + self._log(sql) + query = self.raw_sql(sql, params=params, **kwargs) + table = query.to_arrow( + progress_bar_type=None, bqstorage_client=self.storage_client + ) + table = table.rename_columns(list(expr.as_table().schema().names)) + return expr.__pyarrow_result__(table) + + def to_pyarrow_batches( + self, + expr: ir.Expr, + *, + params: Mapping[ir.Scalar, Any] | None = None, + limit: int | str | None = None, + chunk_size: int = 1_000_000, + **kwargs: Any, + ): + pa = self._import_pyarrow() + + schema = expr.as_table().schema() + + self._register_in_memory_tables(expr) + sql = self.compile(expr, limit=limit, params=params, **kwargs) + self._log(sql) + query = self.raw_sql(sql, params=params, page_size=chunk_size, **kwargs) + batch_iter = query.to_arrow_iterable(bqstorage_client=self.storage_client) + return pa.ipc.RecordBatchReader.from_batches(schema.to_pyarrow(), batch_iter) + + def _gen_udf_name(self, name: str, schema: Optional[str]) -> str: + func = ".".join(filter(None, (schema, name))) + if "." 
in func: + return ".".join(f"`{part}`" for part in func.split(".")) + return func + + def get_schema( + self, + name, + *, + catalog: str | None = None, + database: str | None = None, + ): + table_ref = bq.TableReference( + bq.DatasetReference( + project=catalog or self.data_project, + dataset_id=database or self.current_database, + ), + name, + ) + return schema_from_bigquery_table( + self.client.get_table(table_ref), + # https://cloud.google.com/bigquery/docs/querying-wildcard-tables#filtering_selected_tables_using_table_suffix + wildcard=name[-1] == "*", + ) + + def list_databases( + self, like: str | None = None, catalog: str | None = None + ) -> list[str]: + results = [ + dataset.dataset_id + for dataset in self.client.list_datasets( + project=catalog if catalog is not None else self.data_project + ) + ] + return self._filter_with_like(results, like) + + def list_tables( + self, + like: str | None = None, + database: tuple[str, str] | str | None = None, + schema: str | None = None, + ) -> list[str]: + """List the tables in the database. + + Parameters + ---------- + like + A pattern to use for listing tables. + database + The database location to perform the list against. + + By default uses the current `dataset` (`self.current_database`) and + `project` (`self.current_catalog`). + + To specify a table in a separate BigQuery dataset, you can pass in the + dataset and project as a string `"dataset.project"`, or as a tuple of + strings `("dataset", "project")`. + + ::: {.callout-note} + ## Ibis does not use the word `schema` to refer to database hierarchy. + + A collection of tables is referred to as a `database`. + A collection of `database` is referred to as a `catalog`. + + These terms are mapped onto the corresponding features in each + backend (where available), regardless of whether the backend itself + uses the same terminology. + ::: + schema + [deprecated] The schema (dataset) inside `database` to perform the list against. + """ + table_loc = self._warn_and_create_table_loc(database, schema) + + project, dataset = self._parse_project_and_dataset(table_loc) + dataset_ref = bq.DatasetReference(project, dataset) + result = [table.table_id for table in self.client.list_tables(dataset_ref)] + return self._filter_with_like(result, like) + + def set_database(self, name): + self.data_project, self.dataset = self._parse_project_and_dataset(name) + + @property + def version(self): + return bq.__version__ + + def create_table( + self, + name: str, + obj: ir.Table + | pd.DataFrame + | pa.Table + | pl.DataFrame + | pl.LazyFrame + | None = None, + *, + schema: sch.SchemaLike | None = None, + database: str | None = None, + temp: bool = False, + overwrite: bool = False, + default_collate: str | None = None, + partition_by: str | None = None, + cluster_by: Iterable[str] | None = None, + options: Mapping[str, Any] | None = None, + ) -> ir.Table: + """Create a table in BigQuery. + + Parameters + ---------- + name + Name of the table to create + obj + The data with which to populate the table; optional, but one of `obj` + or `schema` must be specified + schema + The schema of the table to create; optional, but one of `obj` or + `schema` must be specified + database + The BigQuery *dataset* in which to create the table; optional + temp + Whether the table is temporary + overwrite + If `True`, replace the table if it already exists, otherwise fail if + the table exists + default_collate + Default collation for string columns. 
See BigQuery's documentation + for more details: https://cloud.google.com/bigquery/docs/reference/standard-sql/collation-concepts + partition_by + Partition the table by the given expression. See BigQuery's documentation + for more details: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#partition_expression + cluster_by + List of columns to cluster the table by. See BigQuery's documentation + for more details: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#clustering_column_list + options + BigQuery-specific table options; see the BigQuery documentation for + details: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#table_option_list + + Returns + ------- + Table + The table that was just created + + """ + if obj is None and schema is None: + raise com.IbisError("One of the `schema` or `obj` parameter is required") + if schema is not None: + schema = ibis.schema(schema) + + if isinstance(obj, ir.Table) and schema is not None: + if not schema.equals(obj.schema()): + raise com.IbisTypeError( + "Provided schema and Ibis table schema are incompatible. Please " + "align the two schemas, or provide only one of the two arguments." + ) + + project_id, dataset = self._parse_project_and_dataset(database) + + properties = [] + + if default_collate is not None: + properties.append( + sge.CollateProperty(this=sge.convert(default_collate), default=True) + ) + + if partition_by is not None: + properties.append( + sge.PartitionedByProperty( + this=sge.Tuple( + expressions=list(map(sg.to_identifier, partition_by)) + ) + ) + ) + + if cluster_by is not None: + properties.append( + sge.Cluster(expressions=list(map(sg.to_identifier, cluster_by))) + ) + + properties.extend( + sge.Property(this=sg.to_identifier(name), value=sge.convert(value)) + for name, value in (options or {}).items() + ) + + if obj is not None and not isinstance(obj, ir.Table): + obj = ibis.memtable(obj, schema=schema) + + if obj is not None: + self._register_in_memory_tables(obj) + + if temp: + dataset = self._session_dataset.dataset_id + if database is not None: + raise com.IbisInputError("Cannot specify database for temporary table") + database = self._session_dataset.project + else: + dataset = database or self.current_database + + try: + table = sg.parse_one(name, into=sge.Table, read="bigquery") + except sg.ParseError: + table = sg.table( + name, + db=dataset, + catalog=project_id, + quoted=self.compiler.quoted, + ) + else: + if table.args["db"] is None: + table.args["db"] = dataset + + if table.args["catalog"] is None: + table.args["catalog"] = project_id + + table = _force_quote_table(table) + + column_defs = [ + sge.ColumnDef( + this=sg.to_identifier(name, quoted=self.compiler.quoted), + kind=BigQueryType.from_ibis(typ), + constraints=( + None + if typ.nullable or typ.is_array() + else [sge.ColumnConstraint(kind=sge.NotNullColumnConstraint())] + ), + ) + for name, typ in (schema or {}).items() + ] + + stmt = sge.Create( + kind="TABLE", + this=sge.Schema(this=table, expressions=column_defs or None), + replace=overwrite, + properties=sge.Properties(expressions=properties), + expression=None if obj is None else self.compile(obj), + ) + + sql = stmt.sql(self.name) + + self.raw_sql(sql) + return self.table(table.name, database=(table.catalog, table.db)) + + def drop_table( + self, + name: str, + *, + schema: str | None = None, + database: tuple[str | str] | str | None = None, + force: bool = False, + ) -> None: + 
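`create_table` above accepts either a schema or an in-memory object; a DataFrame is routed through an Ibis memtable and loaded into the session dataset before the `CREATE TABLE` statement references it. A hedged usage sketch via the module-level `connect` helper defined later in this file (project and dataset ids are placeholders):

```python
import pandas as pd

from bigframes_vendored.ibis.backends.bigquery import backend as bq_backend

con = bq_backend.connect(project_id="my-project", dataset_id="my_dataset")

df = pd.DataFrame({"city": ["Seattle", "New York"], "rank": [1, 2]})

# One of `obj` or `schema` is required; here the schema is inferred from the
# DataFrame, and `overwrite=True` issues CREATE OR REPLACE TABLE.
t = con.create_table("city_ranks", obj=df, overwrite=True)
print(t.schema())
```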
table_loc = self._warn_and_create_table_loc(database, schema) + catalog, db = self._to_catalog_db_tuple(table_loc) + stmt = sge.Drop( + kind="TABLE", + this=sg.table( + name, + db=db or self.current_database, + catalog=catalog or self.billing_project, + ), + exists=force, + ) + self.raw_sql(stmt.sql(self.name)) + + def create_view( + self, + name: str, + obj: ir.Table, + *, + schema: str | None = None, + database: str | None = None, + overwrite: bool = False, + ) -> ir.Table: + table_loc = self._warn_and_create_table_loc(database, schema) + catalog, db = self._to_catalog_db_tuple(table_loc) + + stmt = sge.Create( + kind="VIEW", + this=sg.table( + name, + db=db or self.current_database, + catalog=catalog or self.billing_project, + ), + expression=self.compile(obj), + replace=overwrite, + ) + self._register_in_memory_tables(obj) + self.raw_sql(stmt.sql(self.name)) + return self.table(name, database=(catalog, database)) + + def drop_view( + self, + name: str, + *, + schema: str | None = None, + database: str | None = None, + force: bool = False, + ) -> None: + table_loc = self._warn_and_create_table_loc(database, schema) + catalog, db = self._to_catalog_db_tuple(table_loc) + + stmt = sge.Drop( + kind="VIEW", + this=sg.table( + name, + db=db or self.current_database, + catalog=catalog or self.billing_project, + ), + exists=force, + ) + self.raw_sql(stmt.sql(self.name)) + + def _drop_cached_table(self, name): + self.drop_table( + name, + database=(self._session_dataset.project, self._session_dataset.dataset_id), + force=True, + ) + + def _register_udfs(self, expr: ir.Expr) -> None: + """No op because UDFs made with CREATE TEMPORARY FUNCTION must be followed by a query.""" + + @contextlib.contextmanager + def _safe_raw_sql(self, *args, **kwargs): + yield self.raw_sql(*args, **kwargs) + + # TODO: remove when the schema kwarg is removed + def _warn_and_create_table_loc(self, database=None, schema=None): + if schema is not None: + self._warn_schema() + if database is not None and schema is not None: + if isinstance(database, str): + table_loc = f"{database}.{schema}" + elif isinstance(database, tuple): + table_loc = database + schema + elif schema is not None: + table_loc = schema + elif database is not None: + table_loc = database + else: + table_loc = None + + table_loc = self._to_sqlglot_table(table_loc) + + if table_loc is not None: + if (sg_cat := table_loc.args["catalog"]) is not None: + sg_cat.args["quoted"] = False + if (sg_db := table_loc.args["db"]) is not None: + sg_db.args["quoted"] = False + + return table_loc + + +def compile(expr, params=None, **kwargs): + """Compile an expression for BigQuery.""" + backend = Backend() + return backend.compile(expr, params=params, **kwargs) + + +def connect( + project_id: str | None = None, + dataset_id: str = "", + credentials: google.auth.credentials.Credentials | None = None, + application_name: str | None = None, + auth_local_webserver: bool = False, + auth_external_data: bool = False, + auth_cache: str = "default", + partition_column: str | None = "PARTITIONTIME", +) -> Backend: + """Create a :class:`Backend` for use with Ibis. + + Parameters + ---------- + project_id + A BigQuery project id. + dataset_id + A dataset id that lives inside of the project indicated by + `project_id`. + credentials + Optional credentials. + application_name + A string identifying your application to Google API endpoints. + auth_local_webserver + Use a local webserver for the user authentication. 
Binds a + webserver to an open port on localhost between 8080 and 8089, + inclusive, to receive authentication token. If not set, defaults + to False, which requests a token via the console. + auth_external_data + Authenticate using additional scopes required to `query external + data sources + `_, + such as Google Sheets, files in Google Cloud Storage, or files in + Google Drive. If not set, defaults to False, which requests the + default BigQuery scopes. + auth_cache + Selects the behavior of the credentials cache. + + `'default'`` + Reads credentials from disk if available, otherwise + authenticates and caches credentials to disk. + + `'reauth'`` + Authenticates and caches credentials to disk. + + `'none'`` + Authenticates and does **not** cache credentials. + + Defaults to `'default'`. + partition_column + Identifier to use instead of default `_PARTITIONTIME` partition + column. Defaults to `'PARTITIONTIME'`. + + Returns + ------- + Backend + An instance of the BigQuery backend + + """ + backend = Backend() + return backend.connect( + project_id=project_id, + dataset_id=dataset_id, + credentials=credentials, + application_name=application_name, + auth_local_webserver=auth_local_webserver, + auth_external_data=auth_external_data, + auth_cache=auth_cache, + partition_column=partition_column, + ) + + +__all__ = [ + "Backend", + "compile", + "connect", +] diff --git a/third_party/bigframes_vendored/ibis/backends/sql/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/__init__.py new file mode 100644 index 0000000000..b8a477dd4d --- /dev/null +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/__init__.py @@ -0,0 +1,7 @@ +# Contains code from https://github.com/ibis-project/ibis/blob/main/ibis/backends/sql/compilers/__init__.py + +import bigframes_vendored.ibis.backends.sql.compilers.bigquery as bigquery + +__all__ = [ + "bigquery", +] diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py new file mode 100644 index 0000000000..c74de82099 --- /dev/null +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py @@ -0,0 +1,1660 @@ +# Contains code from https://github.com/ibis-project/ibis/blob/main/ibis/backends/sql/compilers/base.py + +from __future__ import annotations + +import abc +import calendar +from functools import partial, reduce +import itertools +import math +import operator +import string +from typing import Any, ClassVar, TYPE_CHECKING + +from bigframes_vendored.ibis.backends.sql.rewrites import ( + add_one_to_nth_value_input, + add_order_by_to_empty_ranking_window_functions, + empty_in_values_right_side, + FirstValue, + LastValue, + lower_bucket, + lower_capitalize, + lower_sample, + one_to_zero_index, + sqlize, +) +from bigframes_vendored.ibis.expr.rewrites import lower_stringslice +import ibis.common.exceptions as com +import ibis.common.patterns as pats +from ibis.config import options +import ibis.expr.datatypes as dt +import ibis.expr.operations as ops +from ibis.expr.operations.udf import InputType +from public import public +import sqlglot as sg +import sqlglot.expressions as sge + +try: + from sqlglot.expressions import Alter +except ImportError: + from sqlglot.expressions import AlterTable +else: + + def AlterTable(*args, kind="TABLE", 
**kwargs): + return Alter(*args, kind=kind, **kwargs) + + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Mapping + + from bigframes_vendored.ibis.backends.bigquery.datatypes import SqlglotType + import ibis.expr.schema as sch + import ibis.expr.types as ir + + +def get_leaf_classes(op): + for child_class in op.__subclasses__(): + if not child_class.__subclasses__(): + yield child_class + else: + yield from get_leaf_classes(child_class) + + +ALL_OPERATIONS = frozenset(get_leaf_classes(ops.Node)) + + +class AggGen: + """A descriptor for compiling aggregate functions. + + Common cases can be handled by setting configuration flags, + special cases should override the `aggregate` method directly. + + Parameters + ---------- + supports_filter + Whether the backend supports a FILTER clause in the aggregate. + Defaults to False. + supports_order_by + Whether the backend supports an ORDER BY clause in (relevant) + aggregates. Defaults to False. + """ + + class _Accessor: + """An internal type to handle getattr/getitem access.""" + + __slots__ = ("handler", "compiler") + + def __init__(self, handler: Callable, compiler: SQLGlotCompiler): + self.handler = handler + self.compiler = compiler + + def __getattr__(self, name: str) -> Callable: + return partial(self.handler, self.compiler, name) + + __getitem__ = __getattr__ + + __slots__ = ("supports_filter", "supports_order_by") + + def __init__( + self, *, supports_filter: bool = False, supports_order_by: bool = False + ): + self.supports_filter = supports_filter + self.supports_order_by = supports_order_by + + def __get__(self, instance, owner=None): + if instance is None: + return self + + return AggGen._Accessor(self.aggregate, instance) + + def aggregate( + self, + compiler: SQLGlotCompiler, + name: str, + *args: Any, + where: Any = None, + order_by: tuple = (), + ): + """Compile the specified aggregate. + + Parameters + ---------- + compiler + The backend's compiler. + name + The aggregate name (e.g. `"sum"`). + args + Any arguments to pass to the aggregate. + where + An optional column filter to apply before performing the aggregate. + order_by + Optional ordering keys to use to order the rows before performing + the aggregate. 
+ """ + func = compiler.f[name] + + if order_by and not self.supports_order_by: + raise com.UnsupportedOperationError( + "ordering of order-sensitive aggregations via `order_by` is " + f"not supported for the {compiler.dialect} backend" + ) + + if where is not None and not self.supports_filter: + args = tuple(compiler.if_(where, arg, NULL) for arg in args) + + if order_by and self.supports_order_by: + *rest, last = args + out = func(*rest, sge.Order(this=last, expressions=order_by)) + else: + out = func(*args) + + if where is not None and self.supports_filter: + out = sge.Filter(this=out, expression=sge.Where(this=where)) + + return out + + +class VarGen: + __slots__ = () + + def __getattr__(self, name: str) -> sge.Var: + return sge.Var(this=name) + + def __getitem__(self, key: str) -> sge.Var: + return sge.Var(this=key) + + +class AnonymousFuncGen: + __slots__ = () + + def __getattr__(self, name: str) -> Callable[..., sge.Anonymous]: + return lambda *args: sge.Anonymous( + this=name, expressions=list(map(sge.convert, args)) + ) + + def __getitem__(self, key: str) -> Callable[..., sge.Anonymous]: + return getattr(self, key) + + +class FuncGen: + __slots__ = ("namespace", "anon", "copy") + + def __init__(self, namespace: str | None = None, copy: bool = False) -> None: + self.namespace = namespace + self.anon = AnonymousFuncGen() + self.copy = copy + + def __getattr__(self, name: str) -> Callable[..., sge.Func]: + name = ".".join(filter(None, (self.namespace, name))) + return lambda *args, **kwargs: sg.func( + name, *map(sge.convert, args), **kwargs, copy=self.copy + ) + + def __getitem__(self, key: str) -> Callable[..., sge.Func]: + return getattr(self, key) + + def array(self, *args: Any) -> sge.Array: + if not args: + return sge.Array(expressions=[]) + + first, *rest = args + + if isinstance(first, sge.Select): + assert ( + not rest + ), "only one argument allowed when `first` is a select statement" + + return sge.Array(expressions=list(map(sge.convert, (first, *rest)))) + + def tuple(self, *args: Any) -> sge.Anonymous: + return self.anon.tuple(*args) + + def exists(self, query: sge.Expression) -> sge.Exists: + return sge.Exists(this=query) + + def concat(self, *args: Any) -> sge.Concat: + return sge.Concat(expressions=list(map(sge.convert, args))) + + def map(self, keys: Iterable, values: Iterable) -> sge.Map: + return sge.Map(keys=keys, values=values) + + +class ColGen: + __slots__ = ("table",) + + def __init__(self, table: str | None = None) -> None: + self.table = table + + def __getattr__(self, name: str) -> sge.Column: + return sg.column(name, table=self.table, copy=False) + + def __getitem__(self, key: str) -> sge.Column: + return sg.column(key, table=self.table, copy=False) + + +C = ColGen() +F = FuncGen() +NULL = sge.Null() +FALSE = sge.false() +TRUE = sge.true() +STAR = sge.Star() + + +def parenthesize_inputs(f): + """Decorate a translation rule to parenthesize inputs.""" + + def wrapper(self, op, *, left, right): + return f( + self, + op, + left=self._add_parens(op.left, left), + right=self._add_parens(op.right, right), + ) + + return wrapper + + +@public +class SQLGlotCompiler(abc.ABC): + __slots__ = "f", "v" + + agg = AggGen() + """A generator for handling aggregate functions""" + + rewrites: tuple[type[pats.Replace], ...] 
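`FuncGen` and `ColGen` above (surfaced as the `F` and `C` singletons) are thin attribute-access wrappers over sqlglot's builders: they turn calls like `F.coalesce(...)` and attributes like `C.city` into sqlglot expression nodes. A standalone sketch of the equivalent raw calls:

```python
import sqlglot as sg
import sqlglot.expressions as sge

# Roughly what F.coalesce(C.city, "unknown") builds under the hood.
expr = sg.func("coalesce", sg.column("city", table="t"), sge.convert("unknown"))
print(expr.sql("bigquery"))  # e.g. COALESCE(t.city, 'unknown')

# And the bare column reference produced by ColGen.
print(sg.column("city", table="t").sql("bigquery"))  # e.g. t.city
```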
= ( + empty_in_values_right_side, + add_order_by_to_empty_ranking_window_functions, + one_to_zero_index, + add_one_to_nth_value_input, + ) + """A sequence of rewrites to apply to the expression tree before SQL-specific transforms.""" + + post_rewrites: tuple[type[pats.Replace], ...] = () + """A sequence of rewrites to apply to the expression tree after SQL-specific transforms.""" + + no_limit_value: sge.Null | None = None + """The value to use to indicate no limit.""" + + quoted: bool = True + """Whether to always quote identifiers.""" + + copy_func_args: bool = False + """Whether to copy function arguments when generating SQL.""" + + supports_qualify: bool = False + """Whether the backend supports the QUALIFY clause.""" + + NAN: ClassVar[sge.Expression] = sge.Cast( + this=sge.convert("NaN"), to=sge.DataType(this=sge.DataType.Type.DOUBLE) + ) + """Backend's NaN literal.""" + + POS_INF: ClassVar[sge.Expression] = sge.Cast( + this=sge.convert("Inf"), to=sge.DataType(this=sge.DataType.Type.DOUBLE) + ) + """Backend's positive infinity literal.""" + + NEG_INF: ClassVar[sge.Expression] = sge.Cast( + this=sge.convert("-Inf"), to=sge.DataType(this=sge.DataType.Type.DOUBLE) + ) + """Backend's negative infinity literal.""" + + EXTRA_SUPPORTED_OPS: tuple[type[ops.Node], ...] = ( + ops.Project, + ops.Filter, + ops.Sort, + ops.WindowFunction, + ) + """A tuple of ops classes that are supported, but don't have explicit + `visit_*` methods (usually due to being handled by rewrite rules). Used by + `has_operation`""" + + UNSUPPORTED_OPS: tuple[type[ops.Node], ...] = () + """Tuple of operations the backend doesn't support.""" + + LOWERED_OPS: dict[type[ops.Node], pats.Replace | None] = { + ops.Bucket: lower_bucket, + ops.Capitalize: lower_capitalize, + ops.Sample: lower_sample, + ops.StringSlice: lower_stringslice, + } + """A mapping from an operation class to either a rewrite rule for rewriting that + operation to one composed of lower-level operations ("lowering"), or `None` to + remove an existing rewrite rule for that operation added in a base class""" + + SIMPLE_OPS = { + ops.Abs: "abs", + ops.Acos: "acos", + ops.All: "bool_and", + ops.Any: "bool_or", + ops.ApproxCountDistinct: "approx_distinct", + ops.ArgMax: "max_by", + ops.ArgMin: "min_by", + ops.ArrayContains: "array_contains", + ops.ArrayFlatten: "flatten", + ops.ArrayLength: "array_size", + ops.ArraySort: "array_sort", + ops.ArrayStringJoin: "array_to_string", + ops.Asin: "asin", + ops.Atan2: "atan2", + ops.Atan: "atan", + ops.Cos: "cos", + ops.Cot: "cot", + ops.Count: "count", + ops.CumeDist: "cume_dist", + ops.Date: "date", + ops.DateFromYMD: "datefromparts", + ops.Degrees: "degrees", + ops.DenseRank: "dense_rank", + ops.Exp: "exp", + FirstValue: "first_value", + ops.GroupConcat: "group_concat", + ops.IfElse: "if", + ops.IsInf: "isinf", + ops.IsNan: "isnan", + ops.JSONGetItem: "json_extract", + ops.LPad: "lpad", + LastValue: "last_value", + ops.Levenshtein: "levenshtein", + ops.Ln: "ln", + ops.Log10: "log", + ops.Log2: "log2", + ops.Lowercase: "lower", + ops.Map: "map", + ops.Median: "median", + ops.MinRank: "rank", + ops.NTile: "ntile", + ops.NthValue: "nth_value", + ops.NullIf: "nullif", + ops.PercentRank: "percent_rank", + ops.Pi: "pi", + ops.Power: "pow", + ops.RPad: "rpad", + ops.Radians: "radians", + ops.RegexSearch: "regexp_like", + ops.RegexSplit: "regexp_split", + ops.Repeat: "repeat", + ops.Reverse: "reverse", + ops.RowNumber: "row_number", + ops.Sign: "sign", + ops.Sin: "sin", + ops.Sqrt: "sqrt", + ops.StartsWith: "starts_with", + 
ops.StrRight: "right", + ops.StringAscii: "ascii", + ops.StringContains: "contains", + ops.StringLength: "length", + ops.StringReplace: "replace", + ops.StringSplit: "split", + ops.StringToDate: "str_to_date", + ops.StringToTimestamp: "str_to_time", + ops.Tan: "tan", + ops.Translate: "translate", + ops.Unnest: "explode", + ops.Uppercase: "upper", + } + + BINARY_INFIX_OPS = ( + # Binary operations + ops.Add, + ops.Subtract, + ops.Multiply, + ops.Divide, + ops.Modulus, + ops.Power, + # Comparisons + ops.GreaterEqual, + ops.Greater, + ops.LessEqual, + ops.Less, + ops.Equals, + ops.NotEquals, + # Boolean comparisons + ops.And, + ops.Or, + ops.Xor, + # Bitwise business + ops.BitwiseLeftShift, + ops.BitwiseRightShift, + ops.BitwiseAnd, + ops.BitwiseOr, + ops.BitwiseXor, + # Time arithmetic + ops.DateAdd, + ops.DateSub, + ops.DateDiff, + ops.TimestampAdd, + ops.TimestampSub, + ops.TimestampDiff, + # Interval Marginalia + ops.IntervalAdd, + ops.IntervalMultiply, + ops.IntervalSubtract, + ) + + NEEDS_PARENS = BINARY_INFIX_OPS + (ops.IsNull,) + + # Constructed dynamically in `__init_subclass__` from their respective + # UPPERCASE values to handle inheritance, do not modify directly here. + extra_supported_ops: ClassVar[frozenset[type[ops.Node]]] = frozenset() + lowered_ops: ClassVar[dict[type[ops.Node], pats.Replace]] = {} + + def __init__(self) -> None: + self.f = FuncGen(copy=self.__class__.copy_func_args) + self.v = VarGen() + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + + def methodname(op: type) -> str: + assert isinstance(type(op), type), type(op) + return f"visit_{op.__name__}" + + def make_impl(op, target_name): + assert isinstance(type(op), type), type(op) + + if issubclass(op, ops.Reduction): + + def impl( + self, _, *, _name: str = target_name, where, order_by=(), **kw + ): + return self.agg[_name](*kw.values(), where=where, order_by=order_by) + + else: + + def impl(self, _, *, _name: str = target_name, **kw): + return self.f[_name](*kw.values()) + + return impl + + for op, target_name in cls.SIMPLE_OPS.items(): + setattr(cls, methodname(op), make_impl(op, target_name)) + + # unconditionally raise an exception for unsupported operations + # + # these *must* be defined after SIMPLE_OPS to handle compilers that + # subclass other compilers + for op in cls.UNSUPPORTED_OPS: + # change to visit_Unsupported in a follow up + # TODO: handle geoespatial ops as a separate case? + setattr(cls, methodname(op), cls.visit_Undefined) + + # raise on any remaining unsupported operations + for op in ALL_OPERATIONS: + name = methodname(op) + if not hasattr(cls, name): + setattr(cls, name, cls.visit_Undefined) + + # Amend `lowered_ops` and `extra_supported_ops` using their + # respective UPPERCASE classvar values. 
+ extra_supported_ops = set(cls.extra_supported_ops) + lowered_ops = dict(cls.lowered_ops) + extra_supported_ops.update(cls.EXTRA_SUPPORTED_OPS) + for op_cls, rewrite in cls.LOWERED_OPS.items(): + if rewrite is not None: + lowered_ops[op_cls] = rewrite + extra_supported_ops.add(op_cls) + else: + lowered_ops.pop(op_cls, None) + extra_supported_ops.discard(op_cls) + cls.lowered_ops = lowered_ops + cls.extra_supported_ops = frozenset(extra_supported_ops) + + @property + @abc.abstractmethod + def dialect(self) -> str: + """Backend dialect.""" + + @property + @abc.abstractmethod + def type_mapper(self) -> type[SqlglotType]: + """The type mapper for the backend.""" + + def _compile_builtin_udf(self, udf_node: ops.ScalarUDF) -> None: # noqa: B027 + """No-op.""" + + def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> None: + raise NotImplementedError( + f"Python UDFs are not supported in the {self.dialect} backend" + ) + + def _compile_pyarrow_udf(self, udf_node: ops.ScalarUDF) -> None: + raise NotImplementedError( + f"PyArrow UDFs are not supported in the {self.dialect} backend" + ) + + def _compile_pandas_udf(self, udf_node: ops.ScalarUDF) -> str: + raise NotImplementedError( + f"pandas UDFs are not supported in the {self.dialect} backend" + ) + + # Concrete API + + def if_(self, condition, true, false: sge.Expression | None = None) -> sge.If: + return sge.If( + this=sge.convert(condition), + true=sge.convert(true), + false=None if false is None else sge.convert(false), + ) + + def cast(self, arg, to: dt.DataType) -> sge.Cast: + return sge.Cast( + this=sge.convert(arg), to=self.type_mapper.from_ibis(to), copy=False + ) + + def _prepare_params(self, params): + result = {} + for param, value in params.items(): + node = param.op() + if isinstance(node, ops.Alias): + node = node.arg + result[node] = value + return result + + def to_sqlglot( + self, + expr: ir.Expr, + *, + limit: str | None = None, + params: Mapping[ir.Expr, Any] | None = None, + ): + import ibis + + table_expr = expr.as_table() + + if limit == "default": + limit = ibis.options.sql.default_limit + if limit is not None: + table_expr = table_expr.limit(limit) + + if params is None: + params = {} + + sql = self.translate(table_expr.op(), params=params) + assert not isinstance(sql, sge.Subquery) + + if isinstance(sql, sge.Table): + sql = sg.select(STAR, copy=False).from_(sql, copy=False) + + assert not isinstance(sql, sge.Subquery) + return sql + + def translate(self, op, *, params: Mapping[ir.Value, Any]) -> sge.Expression: + """Translate an ibis operation to a sqlglot expression. 
+ + Parameters + ---------- + op + An ibis operation + params + A mapping of expressions to concrete values + compiler + An instance of SQLGlotCompiler + translate_rel + Relation node translator + translate_val + Value node translator + + Returns + ------- + sqlglot.expressions.Expression + A sqlglot expression + + """ + # substitute parameters immediately to avoid having to define a + # ScalarParameter translation rule + params = self._prepare_params(params) + if self.lowered_ops: + op = op.replace(reduce(operator.or_, self.lowered_ops.values())) + op, ctes = sqlize( + op, + params=params, + rewrites=self.rewrites, + post_rewrites=self.post_rewrites, + fuse_selects=options.sql.fuse_selects, + ) + + aliases = {} + counter = itertools.count() + + def fn(node, _, **kwargs): + result = self.visit_node(node, **kwargs) + + # if it's not a relation then we don't need to do anything special + if node is op or not isinstance(node, ops.Relation): + return result + + # alias ops.Views to their explicitly assigned name otherwise generate + alias = node.name if isinstance(node, ops.View) else f"t{next(counter)}" + aliases[node] = alias + + alias = sg.to_identifier(alias, quoted=self.quoted) + if isinstance(result, sge.Subquery): + return result.as_(alias, quoted=self.quoted) + else: + try: + return result.subquery(alias, copy=False) + except AttributeError: + return result.as_(alias, quoted=self.quoted) + + # apply translate rules in topological order + results = op.map(fn) + + # get the root node as a sqlglot select statement + out = results[op] + if isinstance(out, sge.Table): + out = sg.select(STAR, copy=False).from_(out, copy=False) + elif isinstance(out, sge.Subquery): + out = out.this + + # add cte definitions to the select statement + for cte in ctes: + alias = sg.to_identifier(aliases[cte], quoted=self.quoted) + out = out.with_( + alias, as_=results[cte].this, dialect=self.dialect, copy=False + ) + + return out + + def visit_node(self, op: ops.Node, **kwargs): + if isinstance(op, ops.ScalarUDF): + return self.visit_ScalarUDF(op, **kwargs) + elif isinstance(op, ops.AggUDF): + return self.visit_AggUDF(op, **kwargs) + else: + method = getattr(self, f"visit_{type(op).__name__}", None) + if method is not None: + return method(op, **kwargs) + else: + raise com.OperationNotDefinedError( + f"No translation rule for {type(op).__name__}" + ) + + def visit_Field(self, op, *, rel, name): + return sg.column( + self._gen_valid_name(name), table=rel.alias_or_name, quoted=self.quoted + ) + + def visit_Cast(self, op, *, arg, to): + from_ = op.arg.dtype + + if from_.is_integer() and to.is_interval(): + return self._make_interval(arg, to.unit) + + return self.cast(arg, to) + + def visit_ScalarSubquery(self, op, *, rel): + return rel.this.subquery(copy=False) + + def visit_Alias(self, op, *, arg, name): + return arg + + def visit_Literal(self, op, *, value, dtype): + """Compile a literal value. + + This is the default implementation for compiling literal values. + + Most backends should not need to override this method unless they want + to handle NULL literals as well as every other type of non-null literal + including integers, floating point numbers, decimals, strings, etc. + + The logic here is: + + 1. If the value is None and the type is nullable, return NULL + 1. If the value is None and the type is not nullable, raise an error + 1. Call `visit_NonNullLiteral` method. + 1. If the previous returns `None`, call `visit_DefaultLiteral` method + else return the result of the previous step. 
+ """ + if value is None: + if dtype.nullable: + return NULL if dtype.is_null() else self.cast(NULL, dtype) + raise com.UnsupportedOperationError( + f"Unsupported NULL for non-nullable type: {dtype!r}" + ) + else: + result = self.visit_NonNullLiteral(op, value=value, dtype=dtype) + if result is None: + return self.visit_DefaultLiteral(op, value=value, dtype=dtype) + return result + + def visit_NonNullLiteral(self, op, *, value, dtype): + """Compile a non-null literal differently than the default implementation. + + Most backends should implement this, but only when they need to handle + some non-null literal differently than the default implementation + (`visit_DefaultLiteral`). + + Return `None` from an override of this method to fall back to + `visit_DefaultLiteral`. + """ + return self.visit_DefaultLiteral(op, value=value, dtype=dtype) + + def visit_DefaultLiteral(self, op, *, value, dtype): + """Compile a literal with a non-null value. + + This is the default implementation for compiling non-null literals. + + Most backends should not need to override this method unless they want + to handle compiling every kind of non-null literal value. + """ + if dtype.is_integer(): + return sge.convert(value) + elif dtype.is_floating(): + if math.isnan(value): + return self.NAN + elif math.isinf(value): + return self.POS_INF if value > 0 else self.NEG_INF + return sge.convert(value) + elif dtype.is_decimal(): + return self.cast(str(value), dtype) + elif dtype.is_interval(): + return sge.Interval( + this=sge.convert(str(value)), + unit=sge.Var(this=dtype.resolution.upper()), + ) + elif dtype.is_boolean(): + return sge.Boolean(this=bool(value)) + elif dtype.is_string(): + return sge.convert(value) + elif dtype.is_inet() or dtype.is_macaddr(): + return sge.convert(str(value)) + elif dtype.is_timestamp() or dtype.is_time(): + return self.cast(value.isoformat(), dtype) + elif dtype.is_date(): + return self.f.datefromparts(value.year, value.month, value.day) + elif dtype.is_array(): + value_type = dtype.value_type + return self.f.array( + *( + self.visit_Literal( + ops.Literal(v, value_type), value=v, dtype=value_type + ) + for v in value + ) + ) + elif dtype.is_map(): + key_type = dtype.key_type + keys = self.f.array( + *( + self.visit_Literal( + ops.Literal(k, key_type), value=k, dtype=key_type + ) + for k in value.keys() + ) + ) + + value_type = dtype.value_type + values = self.f.array( + *( + self.visit_Literal( + ops.Literal(v, value_type), value=v, dtype=value_type + ) + for v in value.values() + ) + ) + + return self.f.map(keys, values) + elif dtype.is_struct(): + items = [ + self.visit_Literal( + ops.Literal(v, field_dtype), value=v, dtype=field_dtype + ).as_(k, quoted=self.quoted) + for field_dtype, (k, v) in zip(dtype.types, value.items()) + ] + return sge.Struct.from_arg_list(items) + elif dtype.is_uuid(): + return self.cast(str(value), dtype) + elif dtype.is_geospatial(): + args = [value.wkt] + if (srid := dtype.srid) is not None: + args.append(srid) + return self.f.st_geomfromtext(*args) + + raise NotImplementedError(f"Unsupported type: {dtype!r}") + + def visit_BitwiseNot(self, op, *, arg): + return sge.BitwiseNot(this=arg) + + ### Mathematical Calisthenics + + def visit_E(self, op): + return self.f.exp(1) + + def visit_Log(self, op, *, arg, base): + if base is None: + return self.f.ln(arg) + elif str(base) in ("2", "10"): + return self.f[f"log{base}"](arg) + else: + return self.f.ln(arg) / self.f.ln(base) + + def visit_Clip(self, op, *, arg, lower, upper): + if upper is not None: + arg = 
self.if_(arg.is_(NULL), arg, self.f.least(upper, arg)) + + if lower is not None: + arg = self.if_(arg.is_(NULL), arg, self.f.greatest(lower, arg)) + + return arg + + def visit_FloorDivide(self, op, *, left, right): + return self.cast(self.f.floor(left / right), op.dtype) + + def visit_Ceil(self, op, *, arg): + return self.cast(self.f.ceil(arg), op.dtype) + + def visit_Floor(self, op, *, arg): + return self.cast(self.f.floor(arg), op.dtype) + + def visit_Round(self, op, *, arg, digits): + if digits is not None: + return sge.Round(this=arg, decimals=digits) + return sge.Round(this=arg) + + ### Random Noise + + def visit_RandomScalar(self, op, **kwargs): + return self.f.rand() + + def visit_RandomUUID(self, op, **kwargs): + return self.f.uuid() + + ### Dtype Dysmorphia + + def visit_TryCast(self, op, *, arg, to): + return sge.TryCast(this=arg, to=self.type_mapper.from_ibis(to)) + + ### Comparator Conundrums + + def visit_Between(self, op, *, arg, lower_bound, upper_bound): + return sge.Between(this=arg, low=lower_bound, high=upper_bound) + + def visit_Negate(self, op, *, arg): + return -sge.paren(arg, copy=False) + + def visit_Not(self, op, *, arg): + if isinstance(arg, sge.Filter): + return sge.Filter( + this=sg.not_(arg.this, copy=False), expression=arg.expression + ) + return sg.not_(sge.paren(arg, copy=False)) + + ### Timey McTimeFace + + def visit_Time(self, op, *, arg): + return self.cast(arg, to=dt.time) + + def visit_TimestampNow(self, op): + return sge.CurrentTimestamp() + + def visit_DateNow(self, op): + return sge.CurrentDate() + + def visit_Strftime(self, op, *, arg, format_str): + return sge.TimeToStr(this=arg, format=format_str) + + def visit_ExtractEpochSeconds(self, op, *, arg): + return self.f.epoch(self.cast(arg, dt.timestamp)) + + def visit_ExtractYear(self, op, *, arg): + return self.f.extract(self.v.year, arg) + + def visit_ExtractMonth(self, op, *, arg): + return self.f.extract(self.v.month, arg) + + def visit_ExtractDay(self, op, *, arg): + return self.f.extract(self.v.day, arg) + + def visit_ExtractDayOfYear(self, op, *, arg): + return self.f.extract(self.v.dayofyear, arg) + + def visit_ExtractQuarter(self, op, *, arg): + return self.f.extract(self.v.quarter, arg) + + def visit_ExtractWeekOfYear(self, op, *, arg): + return self.f.extract(self.v.week, arg) + + def visit_ExtractHour(self, op, *, arg): + return self.f.extract(self.v.hour, arg) + + def visit_ExtractMinute(self, op, *, arg): + return self.f.extract(self.v.minute, arg) + + def visit_ExtractSecond(self, op, *, arg): + return self.f.extract(self.v.second, arg) + + def visit_TimestampTruncate(self, op, *, arg, unit): + unit_mapping = { + "Y": "year", + "Q": "quarter", + "M": "month", + "W": "week", + "D": "day", + "h": "hour", + "m": "minute", + "s": "second", + "ms": "ms", + "us": "us", + } + + if (raw_unit := unit_mapping.get(unit.short)) is None: + raise com.UnsupportedOperationError( + f"Unsupported truncate unit {unit.short!r}" + ) + + return self.f.date_trunc(raw_unit, arg) + + def visit_DateTruncate(self, op, *, arg, unit): + return self.visit_TimestampTruncate(op, arg=arg, unit=unit) + + def visit_TimeTruncate(self, op, *, arg, unit): + return self.visit_TimestampTruncate(op, arg=arg, unit=unit) + + def visit_DayOfWeekIndex(self, op, *, arg): + return (self.f.dayofweek(arg) + 6) % 7 + + def visit_DayOfWeekName(self, op, *, arg): + # day of week number is 0-indexed + # Sunday == 0 + # Saturday == 6 + return sge.Case( + this=(self.f.dayofweek(arg) + 6) % 7, + ifs=list(itertools.starmap(self.if_, 
enumerate(calendar.day_name))), + ) + + def _make_interval(self, arg, unit): + return sge.Interval(this=arg, unit=self.v[unit.singular]) + + def visit_IntervalFromInteger(self, op, *, arg, unit): + return self._make_interval(arg, unit) + + ### String Instruments + def visit_Strip(self, op, *, arg): + return self.f.trim(arg, string.whitespace) + + def visit_RStrip(self, op, *, arg): + return self.f.rtrim(arg, string.whitespace) + + def visit_LStrip(self, op, *, arg): + return self.f.ltrim(arg, string.whitespace) + + def visit_Substring(self, op, *, arg, start, length): + if isinstance(op.length, ops.Literal) and (value := op.length.value) < 0: + raise com.IbisInputError( + f"Length parameter must be a non-negative value; got {value}" + ) + start += 1 + start = self.if_(start >= 1, start, start + self.f.length(arg)) + if length is None: + return self.f.substring(arg, start) + return self.f.substring(arg, start, length) + + def visit_StringFind(self, op, *, arg, substr, start, end): + if end is not None: + raise com.UnsupportedOperationError( + "String find doesn't support `end` argument" + ) + + if start is not None: + arg = self.f.substr(arg, start + 1) + pos = self.f.strpos(arg, substr) + return self.if_(pos > 0, pos + start, 0) + + return self.f.strpos(arg, substr) + + def visit_RegexReplace(self, op, *, arg, pattern, replacement): + return self.f.regexp_replace(arg, pattern, replacement, "g") + + def visit_StringConcat(self, op, *, arg): + return self.f.concat(*arg) + + def visit_StringJoin(self, op, *, sep, arg): + return self.f.concat_ws(sep, *arg) + + def visit_StringSQLLike(self, op, *, arg, pattern, escape): + return arg.like(pattern) + + def visit_StringSQLILike(self, op, *, arg, pattern, escape): + return arg.ilike(pattern) + + ### NULL PLAYER CHARACTER + def visit_IsNull(self, op, *, arg): + return arg.is_(NULL) + + def visit_NotNull(self, op, *, arg): + return arg.is_(sg.not_(NULL, copy=False)) + + def visit_InValues(self, op, *, value, options): + return value.isin(*options) + + ### Counting + + def visit_CountDistinct(self, op, *, arg, where): + return self.agg.count(sge.Distinct(expressions=[arg]), where=where) + + def visit_CountDistinctStar(self, op, *, arg, where): + return self.agg.count(sge.Distinct(expressions=[STAR]), where=where) + + def visit_CountStar(self, op, *, arg, where): + return self.agg.count(STAR, where=where) + + def visit_Sum(self, op, *, arg, where): + if op.arg.dtype.is_boolean(): + arg = self.cast(arg, dt.int32) + return self.agg.sum(arg, where=where) + + def visit_Mean(self, op, *, arg, where): + if op.arg.dtype.is_boolean(): + arg = self.cast(arg, dt.int32) + return self.agg.avg(arg, where=where) + + def visit_Min(self, op, *, arg, where): + if op.arg.dtype.is_boolean(): + return self.agg.bool_and(arg, where=where) + return self.agg.min(arg, where=where) + + def visit_Max(self, op, *, arg, where): + if op.arg.dtype.is_boolean(): + return self.agg.bool_or(arg, where=where) + return self.agg.max(arg, where=where) + + ### Stats + + def visit_VarianceStandardDevCovariance(self, op, *, how, where, **kw): + hows = {"sample": "samp", "pop": "pop"} + funcs = { + ops.Variance: "var", + ops.StandardDev: "stddev", + ops.Covariance: "covar", + } + + args = [] + + for oparg, arg in zip(op.args, kw.values()): + if (arg_dtype := oparg.dtype).is_boolean(): + arg = self.cast(arg, dt.Int32(nullable=arg_dtype.nullable)) + args.append(arg) + + funcname = f"{funcs[type(op)]}_{hows[how]}" + return self.agg[funcname](*args, where=where) + + visit_Variance = ( + 
visit_StandardDev + ) = visit_Covariance = visit_VarianceStandardDevCovariance + + def visit_SimpleCase(self, op, *, base=None, cases, results, default): + return sge.Case( + this=base, ifs=list(map(self.if_, cases, results)), default=default + ) + + visit_SearchedCase = visit_SimpleCase + + def visit_ExistsSubquery(self, op, *, rel): + select = rel.this.select(1, append=False) + return self.f.exists(select) + + def visit_InSubquery(self, op, *, rel, needle): + query = rel.this + if not isinstance(query, sge.Select): + query = sg.select(STAR).from_(query) + return needle.isin(query=query) + + def visit_Array(self, op, *, exprs): + return self.f.array(*exprs) + + def visit_StructColumn(self, op, *, names, values): + return sge.Struct.from_arg_list( + [value.as_(name, quoted=self.quoted) for name, value in zip(names, values)] + ) + + def visit_StructField(self, op, *, arg, field): + return sge.Dot(this=arg, expression=sg.to_identifier(field, quoted=self.quoted)) + + def visit_IdenticalTo(self, op, *, left, right): + return sge.NullSafeEQ(this=left, expression=right) + + def visit_Greatest(self, op, *, arg): + return self.f.greatest(*arg) + + def visit_Least(self, op, *, arg): + return self.f.least(*arg) + + def visit_Coalesce(self, op, *, arg): + return self.f.coalesce(*arg) + + ### Ordering and window functions + + def visit_SortKey(self, op, *, expr, ascending: bool, nulls_first: bool): + return sge.Ordered(this=expr, desc=not ascending, nulls_first=nulls_first) + + def visit_ApproxMedian(self, op, *, arg, where): + return self.agg.approx_quantile(arg, 0.5, where=where) + + def visit_WindowBoundary(self, op, *, value, preceding): + # TODO: bit of a hack to return a dict, but there's no sqlglot expression + # that corresponds to _only_ this information + return {"value": value, "side": "preceding" if preceding else "following"} + + def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by): + if start is None: + start = {} + if end is None: + end = {} + + start_value = start.get("value", "UNBOUNDED") + start_side = start.get("side", "PRECEDING") + end_value = end.get("value", "UNBOUNDED") + end_side = end.get("side", "FOLLOWING") + + if getattr(start_value, "this", None) == "0": + start_value = "CURRENT ROW" + start_side = None + + if getattr(end_value, "this", None) == "0": + end_value = "CURRENT ROW" + end_side = None + + spec = sge.WindowSpec( + kind=how.upper(), + start=start_value, + start_side=start_side, + end=end_value, + end_side=end_side, + over="OVER", + ) + order = sge.Order(expressions=order_by) if order_by else None + + spec = self._minimize_spec(op.start, op.end, spec) + + return sge.Window(this=func, partition_by=group_by, order=order, spec=spec) + + @staticmethod + def _minimize_spec(start, end, spec): + return spec + + def visit_LagLead(self, op, *, arg, offset, default): + args = [arg] + + if default is not None: + if offset is None: + offset = 1 + + args.append(offset) + args.append(default) + elif offset is not None: + args.append(offset) + + return self.f[type(op).__name__.lower()](*args) + + visit_Lag = visit_Lead = visit_LagLead + + def visit_Argument(self, op, *, name: str, shape, dtype): + return sg.to_identifier(op.param) + + def visit_RowID(self, op, *, table): + return sg.column( + op.name, table=table.alias_or_name, quoted=self.quoted, copy=False + ) + + # TODO(kszucs): this should be renamed to something UDF related + def __sql_name__(self, op: ops.ScalarUDF | ops.AggUDF) -> str: + # for builtin functions use the exact function name, 
otherwise use the + # generated name to handle the case of redefinition + funcname = ( + op.__func_name__ + if op.__input_type__ == InputType.BUILTIN + else type(op).__name__ + ) + + # not actually a table, but easier to quote individual namespace + # components this way + namespace = op.__udf_namespace__ + return sg.table(funcname, db=namespace.database, catalog=namespace.catalog).sql( + self.dialect + ) + + def visit_ScalarUDF(self, op, **kw): + return self.f[self.__sql_name__(op)](*kw.values()) + + def visit_AggUDF(self, op, *, where, **kw): + return self.agg[self.__sql_name__(op)](*kw.values(), where=where) + + def visit_TimestampDelta(self, op, *, part, left, right): + # dialect is necessary due to sqlglot's default behavior + # of `part` coming last + return sge.DateDiff( + this=left, expression=right, unit=part, dialect=self.dialect + ) + + visit_TimeDelta = visit_DateDelta = visit_TimestampDelta + + def visit_TimestampBucket(self, op, *, arg, interval, offset): + origin = self.f.cast("epoch", self.type_mapper.from_ibis(dt.timestamp)) + if offset is not None: + origin += offset + return self.f.time_bucket(interval, arg, origin) + + def visit_ArrayConcat(self, op, *, arg): + return sge.ArrayConcat(this=arg[0], expressions=list(arg[1:])) + + ## relations + + @staticmethod + def _gen_valid_name(name: str) -> str: + """Generate a valid name for a value expression. + + Override this method if the dialect has restrictions on valid + identifiers even when quoted. + + See the BigQuery backend's implementation for an example. + """ + return name + + def _cleanup_names(self, exprs: Mapping[str, sge.Expression]): + """Compose `_gen_valid_name` and `_dedup_name` to clean up names in projections.""" + + for name, value in exprs.items(): + name = self._gen_valid_name(name) + if isinstance(value, sge.Column) and name == value.name: + # don't alias columns that are already named the same as their alias + yield value + else: + yield value.as_(name, quoted=self.quoted, copy=False) + + def visit_Select( + self, op, *, parent, selections, predicates, qualified, sort_keys, distinct + ): + # if we've constructed a useless projection return the parent relation + if not (selections or predicates or qualified or sort_keys or distinct): + return parent + + result = parent + + if selections: + # if there are `qualify` predicates then sqlglot adds a hidden + # column to implement the functionality if the dialect doesn't + # support it + # + # using STAR in that case would lead to an extra column, so in that + # case we have to spell out the columns + if op.is_star_selection() and (not qualified or self.supports_qualify): + fields = [STAR] + else: + fields = self._cleanup_names(selections) + result = sg.select(*fields, copy=False).from_(result, copy=False) + + if predicates: + result = result.where(*predicates, copy=False) + + if qualified: + result = result.qualify(*qualified, copy=False) + + if sort_keys: + result = result.order_by(*sort_keys, copy=False) + + if distinct: + result = result.distinct() + + return result + + def visit_DummyTable(self, op, *, values): + return sg.select(*self._cleanup_names(values), copy=False) + + def visit_UnboundTable( + self, op, *, name: str, schema: sch.Schema, namespace: ops.Namespace + ) -> sg.Table: + return sg.table( + name, db=namespace.database, catalog=namespace.catalog, quoted=self.quoted + ) + + def visit_InMemoryTable( + self, op, *, name: str, schema: sch.Schema, data + ) -> sg.Table: + return sg.table(name, quoted=self.quoted) + + def visit_DatabaseTable( + 
self, + op, + *, + name: str, + schema: sch.Schema, + source: Any, + namespace: ops.Namespace, + ) -> sg.Table: + return sg.table( + name, db=namespace.database, catalog=namespace.catalog, quoted=self.quoted + ) + + def visit_SelfReference(self, op, *, parent, identifier): + return parent + + visit_JoinReference = visit_SelfReference + + def visit_JoinChain(self, op, *, first, rest, values): + result = sg.select(*self._cleanup_names(values), copy=False).from_( + first, copy=False + ) + + for link in rest: + if isinstance(link, sge.Alias): + link = link.this + result = result.join(link, copy=False) + return result + + def visit_JoinLink(self, op, *, how, table, predicates): + sides = { + "inner": None, + "left": "left", + "right": "right", + "semi": "left", + "anti": "left", + "cross": None, + "outer": "full", + "asof": "asof", + "any_left": "left", + "any_inner": None, + "positional": None, + } + kinds = { + "any_left": "any", + "any_inner": "any", + "asof": "left", + "inner": "inner", + "left": "outer", + "right": "outer", + "semi": "semi", + "anti": "anti", + "cross": "cross", + "outer": "outer", + "positional": "positional", + } + assert predicates or how in { + "cross", + "positional", + }, "expected non-empty predicates when not a cross join" + on = sg.and_(*predicates) if predicates else None + return sge.Join(this=table, side=sides[how], kind=kinds[how], on=on) + + @staticmethod + def _generate_groups(groups): + return map(sge.convert, range(1, len(groups) + 1)) + + def visit_Aggregate(self, op, *, parent, groups, metrics): + sel = sg.select( + *self._cleanup_names(groups), *self._cleanup_names(metrics), copy=False + ).from_(parent, copy=False) + + if groups: + sel = sel.group_by(*self._generate_groups(groups.values()), copy=False) + + return sel + + @classmethod + def _add_parens(cls, op, sg_expr): + if isinstance(op, cls.NEEDS_PARENS): + return sge.paren(sg_expr, copy=False) + return sg_expr + + def visit_Union(self, op, *, left, right, distinct): + if isinstance(left, (sge.Table, sge.Subquery)): + left = sg.select(STAR, copy=False).from_(left, copy=False) + + if isinstance(right, (sge.Table, sge.Subquery)): + right = sg.select(STAR, copy=False).from_(right, copy=False) + + return sg.union( + left.args.get("this", left), + right.args.get("this", right), + distinct=distinct, + copy=False, + ) + + def visit_Intersection(self, op, *, left, right, distinct): + if isinstance(left, (sge.Table, sge.Subquery)): + left = sg.select(STAR, copy=False).from_(left, copy=False) + + if isinstance(right, (sge.Table, sge.Subquery)): + right = sg.select(STAR, copy=False).from_(right, copy=False) + + return sg.intersect( + left.args.get("this", left), + right.args.get("this", right), + distinct=distinct, + copy=False, + ) + + def visit_Difference(self, op, *, left, right, distinct): + if isinstance(left, (sge.Table, sge.Subquery)): + left = sg.select(STAR, copy=False).from_(left, copy=False) + + if isinstance(right, (sge.Table, sge.Subquery)): + right = sg.select(STAR, copy=False).from_(right, copy=False) + + return sg.except_( + left.args.get("this", left), + right.args.get("this", right), + distinct=distinct, + copy=False, + ) + + def visit_Limit(self, op, *, parent, n, offset): + # push limit/offset into subqueries + if isinstance(parent, sge.Subquery) and parent.this.args.get("limit") is None: + result = parent.this.copy() + alias = parent.alias + else: + result = sg.select(STAR, copy=False).from_(parent, copy=False) + alias = None + + if isinstance(n, int): + result = result.limit(n, copy=False) 
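A rough standalone sketch of the sqlglot builder calls that `visit_Limit` and the set-operation visitors lean on (the table name `events` is invented for the example; this is not part of the patch):

import sqlglot as sg
import sqlglot.expressions as sge

# Build SELECT * FROM events LIMIT 10 OFFSET 5 with the same fluent calls
# the compiler uses (sg.select / .from_ / .limit / .offset), then render
# the expression for a concrete dialect with .sql().
query = (
    sg.select(sge.Star())
    .from_(sg.table("events"))
    .limit(10)
    .offset(5)
)
print(query.sql(dialect="bigquery"))
# expected to print something like: SELECT * FROM events LIMIT 10 OFFSET 5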
+ elif n is not None: + result = result.limit( + sg.select(n, copy=False).from_(parent, copy=False).subquery(copy=False), + copy=False, + ) + else: + assert n is None, n + if self.no_limit_value is not None: + result = result.limit(self.no_limit_value, copy=False) + + assert offset is not None, "offset is None" + + if not isinstance(offset, int): + skip = offset + skip = ( + sg.select(skip, copy=False) + .from_(parent, copy=False) + .subquery(copy=False) + ) + elif not offset: + if alias is not None: + return result.subquery(alias, copy=False) + return result + else: + skip = offset + + result = result.offset(skip, copy=False) + if alias is not None: + return result.subquery(alias, copy=False) + return result + + def visit_CTE(self, op, *, parent): + return sg.table(parent.alias_or_name, quoted=self.quoted) + + def visit_View(self, op, *, child, name: str): + if isinstance(child, sge.Table): + child = sg.select(STAR, copy=False).from_(child, copy=False) + else: + child = child.copy() + + if isinstance(child, sge.Subquery): + return child.as_(name, quoted=self.quoted) + else: + try: + return child.subquery(name, copy=False) + except AttributeError: + return child.as_(name, quoted=self.quoted) + + def visit_SQLStringView(self, op, *, query: str, child, schema): + return sg.parse_one(query, read=self.dialect) + + def visit_SQLQueryResult(self, op, *, query, schema, source): + return sg.parse_one(query, dialect=self.dialect).subquery(copy=False) + + def visit_RegexExtract(self, op, *, arg, pattern, index): + return self.f.regexp_extract(arg, pattern, index, dialect=self.dialect) + + @parenthesize_inputs + def visit_Add(self, op, *, left, right): + return sge.Add(this=left, expression=right) + + visit_DateAdd = visit_TimestampAdd = visit_IntervalAdd = visit_Add + + @parenthesize_inputs + def visit_Subtract(self, op, *, left, right): + return sge.Sub(this=left, expression=right) + + visit_DateSub = ( + visit_DateDiff + ) = ( + visit_TimestampSub + ) = visit_TimestampDiff = visit_IntervalSubtract = visit_Subtract + + @parenthesize_inputs + def visit_Multiply(self, op, *, left, right): + return sge.Mul(this=left, expression=right) + + visit_IntervalMultiply = visit_Multiply + + @parenthesize_inputs + def visit_Divide(self, op, *, left, right): + return sge.Div(this=left, expression=right) + + @parenthesize_inputs + def visit_Modulus(self, op, *, left, right): + return sge.Mod(this=left, expression=right) + + @parenthesize_inputs + def visit_Power(self, op, *, left, right): + return sge.Pow(this=left, expression=right) + + @parenthesize_inputs + def visit_GreaterEqual(self, op, *, left, right): + return sge.GTE(this=left, expression=right) + + @parenthesize_inputs + def visit_Greater(self, op, *, left, right): + return sge.GT(this=left, expression=right) + + @parenthesize_inputs + def visit_LessEqual(self, op, *, left, right): + return sge.LTE(this=left, expression=right) + + @parenthesize_inputs + def visit_Less(self, op, *, left, right): + return sge.LT(this=left, expression=right) + + @parenthesize_inputs + def visit_Equals(self, op, *, left, right): + return sge.EQ(this=left, expression=right) + + @parenthesize_inputs + def visit_NotEquals(self, op, *, left, right): + return sge.NEQ(this=left, expression=right) + + @parenthesize_inputs + def visit_And(self, op, *, left, right): + return sge.And(this=left, expression=right) + + @parenthesize_inputs + def visit_Or(self, op, *, left, right): + return sge.Or(this=left, expression=right) + + @parenthesize_inputs + def visit_Xor(self, op, *, left, 
right): + return sge.Xor(this=left, expression=right) + + @parenthesize_inputs + def visit_BitwiseLeftShift(self, op, *, left, right): + return sge.BitwiseLeftShift(this=left, expression=right) + + @parenthesize_inputs + def visit_BitwiseRightShift(self, op, *, left, right): + return sge.BitwiseRightShift(this=left, expression=right) + + @parenthesize_inputs + def visit_BitwiseAnd(self, op, *, left, right): + return sge.BitwiseAnd(this=left, expression=right) + + @parenthesize_inputs + def visit_BitwiseOr(self, op, *, left, right): + return sge.BitwiseOr(this=left, expression=right) + + @parenthesize_inputs + def visit_BitwiseXor(self, op, *, left, right): + return sge.BitwiseXor(this=left, expression=right) + + def visit_Undefined(self, op, **_): + raise com.OperationNotDefinedError( + f"Compilation rule for {type(op).__name__!r} operation is not defined" + ) + + def visit_Unsupported(self, op, **_): + raise com.UnsupportedOperationError( + f"{type(op).__name__!r} operation is not supported in the {self.dialect} backend" + ) + + def visit_DropColumns(self, op, *, parent, columns_to_drop): + # the generated query will be huge for wide tables + # + # TODO: figure out a way to produce an IR that only contains exactly + # what is used + parent_alias = parent.alias_or_name + quoted = self.quoted + columns_to_keep = ( + sg.column(column, table=parent_alias, quoted=quoted) + for column in op.schema.names + ) + return sg.select(*columns_to_keep).from_(parent) + + def add_query_to_expr(self, *, name: str, table: ir.Table, query: str) -> str: + dialect = self.dialect + + compiled_ibis_expr = self.to_sqlglot(table) + + # pull existing CTEs from the compiled Ibis expression and combine them + # with the new query + parsed = reduce( + lambda parsed, cte: parsed.with_(cte.args["alias"], as_=cte.args["this"]), + compiled_ibis_expr.ctes, + sg.parse_one(query, read=dialect), + ) + + # remove all ctes from the compiled expression, since they're now in + # our larger expression + compiled_ibis_expr.args.pop("with", None) + + # add the new str query as a CTE + parsed = parsed.with_( + sg.to_identifier(name, quoted=self.quoted), as_=compiled_ibis_expr + ) + + # generate the SQL string + return parsed.sql(dialect) + + def _make_sample_backwards_compatible(self, *, sample, parent): + # sample was changed to be owned by the table being sampled in 25.17.0 + # + # this is a small workaround for backwards compatibility + if "this" in sample.__class__.arg_types: + sample.args["this"] = parent + else: + parent.args["sample"] = sample + return sg.select(STAR).from_(parent) + + +# `__init_subclass__` is uncalled for subclasses - we manually call it here to +# autogenerate the base class implementations as well. 
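A minimal sketch of that pattern with made-up names (not the vendored classes), assuming only the standard library:

# visit_* methods are generated from a class-level table in
# __init_subclass__; Python only runs that hook for subclasses, so the
# base class triggers it by hand, mirroring the call below.
class MiniCompiler:
    SIMPLE_OPS = {"Abs": "abs", "Ln": "ln"}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        for op_name, func_name in cls.SIMPLE_OPS.items():
            def impl(self, *args, _name=func_name):
                # Render a simple function-call string for the operation.
                return f"{_name}({', '.join(map(str, args))})"
            setattr(cls, f"visit_{op_name}", impl)

MiniCompiler.__init_subclass__()  # manual call, as done below

assert MiniCompiler().visit_Abs(1) == "abs(1)"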
+SQLGlotCompiler.__init_subclass__() diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py new file mode 100644 index 0000000000..fc8d93a433 --- /dev/null +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py @@ -0,0 +1,1114 @@ +# Contains code from https://github.com/ibis-project/ibis/blob/main/ibis/backends/sql/compilers/bigquery/__init__.py + +"""Module to convert from Ibis expression to SQL string.""" + +from __future__ import annotations + +import decimal +import math +import re +from typing import Any, TYPE_CHECKING + +from bigframes_vendored.ibis.backends.bigquery.datatypes import ( + BigQueryType, + BigQueryUDFType, +) +from bigframes_vendored.ibis.backends.sql.compilers.base import ( + AggGen, + NULL, + SQLGlotCompiler, + STAR, +) +from bigframes_vendored.ibis.backends.sql.rewrites import ( + exclude_unsupported_window_frame_from_ops, + exclude_unsupported_window_frame_from_rank, + exclude_unsupported_window_frame_from_row_number, + split_select_distinct_with_order_by, +) +from ibis import util +from ibis.backends.sql.compilers.bigquery.udf.core import PythonToJavaScriptTranslator +import ibis.common.exceptions as com +from ibis.common.temporal import DateUnit, IntervalUnit, TimestampUnit, TimeUnit +import ibis.expr.datatypes as dt +import ibis.expr.operations as ops +import sqlglot as sg +from sqlglot.dialects import BigQuery +import sqlglot.expressions as sge + +if TYPE_CHECKING: + from collections.abc import Mapping + + import ibis.expr.types as ir + +_NAME_REGEX = re.compile(r'[^!"$()*,./;?@[\\\]^`{}~\n]+') + + +_MEMTABLE_PATTERN = re.compile( + r"^_?ibis_(?:[A-Za-z_][A-Za-z_0-9]*)_memtable_[a-z0-9]{26}$" +) + + +def _qualify_memtable( + node: sge.Expression, *, dataset: str | None, project: str | None +) -> sge.Expression: + """Add a BigQuery dataset and project to memtable references.""" + if isinstance(node, sge.Table) and _MEMTABLE_PATTERN.match(node.name) is not None: + node.args["db"] = dataset + node.args["catalog"] = project + # make sure to quote table location + node = _force_quote_table(node) + return node + + +def _remove_null_ordering_from_unsupported_window( + node: sge.Expression, +) -> sge.Expression: + """Remove null ordering in window frame clauses not supported by BigQuery. + + BigQuery has only partial support for NULL FIRST/LAST in RANGE windows so + we remove it from any window frame clause that doesn't support it. + + Here's the support matrix: + + ✅ sum(x) over (order by y desc nulls last) + 🚫 sum(x) over (order by y asc nulls last) + ✅ sum(x) over (order by y asc nulls first) + 🚫 sum(x) over (order by y desc nulls first) + """ + if isinstance(node, sge.Window): + order = node.args.get("order") + if order is not None: + for key in order.args["expressions"]: + kargs = key.args + if kargs.get("desc") is True and kargs.get("nulls_first", False): + kargs["nulls_first"] = False + elif kargs.get("desc") is False and not kargs.setdefault( + "nulls_first", True + ): + kargs["nulls_first"] = True + return node + + +def _force_quote_table(table: sge.Table) -> sge.Table: + """Force quote all the parts of a bigquery path. + + The BigQuery identifier quoting semantics are bonkers + https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#identifiers + + my-table is OK, but not mydataset.my-table + + mytable-287 is OK, but not mytable-287a + + Just quote everything. 
+ """ + for key in ("this", "db", "catalog"): + if (val := table.args[key]) is not None: + if isinstance(val, sg.exp.Identifier) and not val.quoted: + val.args["quoted"] = True + else: + table.args[key] = sg.to_identifier(val, quoted=True) + return table + + +class BigQueryCompiler(SQLGlotCompiler): + dialect = BigQuery + type_mapper = BigQueryType + udf_type_mapper = BigQueryUDFType + + agg = AggGen(supports_order_by=True) + + rewrites = ( + exclude_unsupported_window_frame_from_ops, + exclude_unsupported_window_frame_from_row_number, + exclude_unsupported_window_frame_from_rank, + *SQLGlotCompiler.rewrites, + ) + post_rewrites = (split_select_distinct_with_order_by,) + + supports_qualify = True + + UNSUPPORTED_OPS = ( + ops.DateDiff, + ops.ExtractAuthority, + ops.ExtractUserInfo, + ops.FindInSet, + ops.Median, + ops.RegexSplit, + ops.RowID, + ops.TimestampDiff, + ) + + NAN = sge.Cast( + this=sge.convert("NaN"), to=sge.DataType(this=sge.DataType.Type.DOUBLE) + ) + POS_INF = sge.Cast( + this=sge.convert("Infinity"), to=sge.DataType(this=sge.DataType.Type.DOUBLE) + ) + NEG_INF = sge.Cast( + this=sge.convert("-Infinity"), to=sge.DataType(this=sge.DataType.Type.DOUBLE) + ) + + SIMPLE_OPS = { + ops.Arbitrary: "any_value", + ops.StringAscii: "ascii", + ops.BitAnd: "bit_and", + ops.BitOr: "bit_or", + ops.BitXor: "bit_xor", + ops.DateFromYMD: "date", + ops.Divide: "ieee_divide", + ops.EndsWith: "ends_with", + ops.GeoArea: "st_area", + ops.GeoAsBinary: "st_asbinary", + ops.GeoAsText: "st_astext", + ops.GeoAzimuth: "st_azimuth", + ops.GeoBuffer: "st_buffer", + ops.GeoCentroid: "st_centroid", + ops.GeoContains: "st_contains", + ops.GeoCoveredBy: "st_coveredby", + ops.GeoCovers: "st_covers", + ops.GeoDWithin: "st_dwithin", + ops.GeoDifference: "st_difference", + ops.GeoDisjoint: "st_disjoint", + ops.GeoDistance: "st_distance", + ops.GeoEndPoint: "st_endpoint", + ops.GeoEquals: "st_equals", + ops.GeoGeometryType: "st_geometrytype", + ops.GeoIntersection: "st_intersection", + ops.GeoIntersects: "st_intersects", + ops.GeoLength: "st_length", + ops.GeoMaxDistance: "st_maxdistance", + ops.GeoNPoints: "st_numpoints", + ops.GeoPerimeter: "st_perimeter", + ops.GeoPoint: "st_geogpoint", + ops.GeoPointN: "st_pointn", + ops.GeoStartPoint: "st_startpoint", + ops.GeoTouches: "st_touches", + ops.GeoUnaryUnion: "st_union_agg", + ops.GeoUnion: "st_union", + ops.GeoWithin: "st_within", + ops.GeoX: "st_x", + ops.GeoY: "st_y", + ops.Hash: "farm_fingerprint", + ops.IsInf: "is_inf", + ops.IsNan: "is_nan", + ops.Log10: "log10", + ops.LPad: "lpad", + ops.RPad: "rpad", + ops.Levenshtein: "edit_distance", + ops.Modulus: "mod", + ops.RegexReplace: "regexp_replace", + ops.RegexSearch: "regexp_contains", + ops.Time: "time", + ops.TimeFromHMS: "time_from_parts", + ops.TimestampNow: "current_timestamp", + ops.ExtractHost: "net.host", + } + + def to_sqlglot( + self, + expr: ir.Expr, + *, + limit: str | None = None, + params: Mapping[ir.Expr, Any] | None = None, + session_dataset_id: str | None = None, + session_project: str | None = None, + ) -> Any: + """Compile an Ibis expression. + + Parameters + ---------- + expr + Ibis expression + limit + For expressions yielding result sets; retrieve at most this number + of values/rows. Overrides any limit already set on the expression. + params + Named unbound parameters + session_dataset_id + Optional dataset ID to qualify memtable references. + session_project + Optional project ID to qualify memtable references. + + Returns + ------- + Any + The output of compilation. 
The type of this value depends on the + backend. + + """ + sql = super().to_sqlglot(expr, limit=limit, params=params) + + table_expr = expr.as_table() + geocols = table_expr.schema().geospatial + + result = sql.transform( + _qualify_memtable, + dataset=session_dataset_id, + project=session_project, + ).transform(_remove_null_ordering_from_unsupported_window) + + if geocols: + # if there are any geospatial columns, we have to convert them to WKB, + # so interactive mode knows how to display them + # + # by default bigquery returns data to python as WKT, and there's really + # no point in supporting both if we don't need to. + quoted = self.quoted + result = sg.select( + sge.Star( + replace=[ + self.f.st_asbinary(sg.column(col, quoted=quoted)).as_( + col, quoted=quoted + ) + for col in geocols + ] + ) + ).from_(result.subquery()) + + sources = [] + + for udf_node in table_expr.op().find(ops.ScalarUDF): + compile_func = getattr( + self, f"_compile_{udf_node.__input_type__.name.lower()}_udf" + ) + if sql := compile_func(udf_node): + sources.append(sql) + + if not sources: + return result + + sources.append(result) + return sources + + def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> sge.Create: + name = type(udf_node).__name__ + type_mapper = self.udf_type_mapper + + body = PythonToJavaScriptTranslator(udf_node.__func__).compile() + config = udf_node.__config__ + libraries = config.get("libraries", []) + + signature = [ + sge.ColumnDef( + this=sg.to_identifier(name, quoted=self.quoted), + kind=type_mapper.from_ibis(param.annotation.pattern.dtype), + ) + for name, param in udf_node.__signature__.parameters.items() + ] + + lines = ['"""'] + + if config.get("strict", True): + lines.append('"use strict";') + + lines += [ + body, + "", + f"return {udf_node.__func_name__}({', '.join(udf_node.argnames)});", + '"""', + ] + + func = sge.Create( + kind="FUNCTION", + this=sge.UserDefinedFunction( + this=sg.to_identifier(name), expressions=signature, wrapped=True + ), + # not exactly what I had in mind, but it works + # + # quoting is too simplistic to handle multiline strings + expression=sge.Var(this="\n".join(lines)), + exists=False, + properties=sge.Properties( + expressions=[ + sge.TemporaryProperty(), + sge.ReturnsProperty(this=type_mapper.from_ibis(udf_node.dtype)), + sge.StabilityProperty( + this="IMMUTABLE" if config.get("determinism") else "VOLATILE" + ), + sge.LanguageProperty(this=sg.to_identifier("js")), + ] + + [ + sge.Property( + this=sg.to_identifier("library"), value=self.f.array(*libraries) + ) + ] + * bool(libraries) + ), + ) + + return func + + @staticmethod + def _minimize_spec(start, end, spec): + if ( + start is None + and isinstance(getattr(end, "value", None), ops.Literal) + and end.value.value == 0 + and end.following + ): + return None + return spec + + def visit_BoundingBox(self, op, *, arg): + name = type(op).__name__[len("Geo") :].lower() + return sge.Dot( + this=self.f.st_boundingbox(arg), expression=sg.to_identifier(name) + ) + + visit_GeoXMax = visit_GeoXMin = visit_GeoYMax = visit_GeoYMin = visit_BoundingBox + + def visit_GeoSimplify(self, op, *, arg, tolerance, preserve_collapsed): + if ( + not isinstance(op.preserve_collapsed, ops.Literal) + or op.preserve_collapsed.value + ): + raise com.UnsupportedOperationError( + "BigQuery simplify does not support preserving collapsed geometries, " + "pass preserve_collapsed=False" + ) + return self.f.st_simplify(arg, tolerance) + + def visit_ApproxMedian(self, op, *, arg, where): + return self.agg.approx_quantiles(arg, 2, 
where=where)[self.f.offset(1)] + + def visit_Pi(self, op): + return self.f.acos(-1) + + def visit_E(self, op): + return self.f.exp(1) + + def visit_TimeDelta(self, op, *, left, right, part): + return self.f.time_diff(left, right, part, dialect=self.dialect) + + def visit_DateDelta(self, op, *, left, right, part): + return self.f.date_diff(left, right, part, dialect=self.dialect) + + def visit_TimestampDelta(self, op, *, left, right, part): + left_tz = op.left.dtype.timezone + right_tz = op.right.dtype.timezone + + if left_tz is None and right_tz is None: + return self.f.datetime_diff(left, right, part) + elif left_tz is not None and right_tz is not None: + return self.f.timestamp_diff(left, right, part) + + raise com.UnsupportedOperationError( + "timestamp difference with mixed timezone/timezoneless values is not implemented" + ) + + def visit_GroupConcat(self, op, *, arg, sep, where, order_by): + if where is not None: + arg = self.if_(where, arg, NULL) + + if order_by: + sep = sge.Order(this=sep, expressions=order_by) + + return sge.GroupConcat(this=arg, separator=sep) + + def visit_ApproxQuantile(self, op, *, arg, quantile, where): + if not isinstance(op.quantile, ops.Literal): + raise com.UnsupportedOperationError( + "quantile must be a literal in BigQuery" + ) + + # BigQuery syntax is `APPROX_QUANTILES(col, resolution)` to return + # `resolution + 1` quantiles array. To handle this, we compute the + # resolution ourselves then restructure the output array as needed. + # To avoid excessive resolution we arbitrarily cap it at 100,000 - + # since these are approximate quantiles anyway this seems fine. + quantiles = util.promote_list(op.quantile.value) + fracs = [decimal.Decimal(str(q)).as_integer_ratio() for q in quantiles] + resolution = min(math.lcm(*(den for _, den in fracs)), 100_000) + indices = [(num * resolution) // den for num, den in fracs] + + if where is not None: + arg = self.if_(where, arg, NULL) + + if not op.arg.dtype.is_floating(): + arg = self.cast(arg, dt.float64) + + array = self.f.approx_quantiles( + arg, sge.IgnoreNulls(this=sge.convert(resolution)) + ) + if isinstance(op, ops.ApproxQuantile): + return array[indices[0]] + + if indices == list(range(resolution + 1)): + return array + else: + return sge.Array(expressions=[array[i] for i in indices]) + + visit_ApproxMultiQuantile = visit_ApproxQuantile + + def visit_FloorDivide(self, op, *, left, right): + return self.cast(self.f.floor(self.f.ieee_divide(left, right)), op.dtype) + + def visit_Log2(self, op, *, arg): + return self.f.log(arg, 2, dialect=self.dialect) + + def visit_Log(self, op, *, arg, base): + if base is None: + return self.f.ln(arg) + return self.f.log(arg, base, dialect=self.dialect) + + def visit_ArrayRepeat(self, op, *, arg, times): + start = step = 1 + array_length = self.f.array_length(arg) + stop = self.f.greatest(times, 0) * array_length + i = sg.to_identifier("i") + idx = self.f.coalesce( + self.f.nullif(self.f.mod(i, array_length), 0), array_length + ) + series = self.f.generate_array(start, stop, step) + return self.f.array( + sg.select(arg[self.f.safe_ordinal(idx)]).from_(self._unnest(series, as_=i)) + ) + + def visit_NthValue(self, op, *, arg, nth): + if not isinstance(op.nth, ops.Literal): + raise com.UnsupportedOperationError( + f"BigQuery `nth` must be a literal; got {type(op.nth)}" + ) + return self.f.nth_value(arg, nth) + + def visit_StrRight(self, op, *, arg, nchars): + return self.f.substr(arg, -self.f.least(self.f.length(arg), nchars)) + + def visit_StringJoin(self, op, *, arg, sep): 
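To make the resolution arithmetic in `visit_ApproxQuantile` above concrete, the same computation can be run on its own (the quantile values 0.25 and 0.9 are chosen purely for illustration):

import decimal
import math

# 0.25 and 0.9 become the fractions 1/4 and 9/10; the resolution is
# lcm(4, 10) = 20 (capped at 100,000), and the requested quantiles map to
# offsets 5 and 18 in the APPROX_QUANTILES result array.
quantiles = [0.25, 0.9]
fracs = [decimal.Decimal(str(q)).as_integer_ratio() for q in quantiles]
resolution = min(math.lcm(*(den for _, den in fracs)), 100_000)
indices = [(num * resolution) // den for num, den in fracs]
assert (resolution, indices) == (20, [5, 18])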
+        return self.f.array_to_string(self.f.array(*arg), sep)
+
+    def visit_DayOfWeekIndex(self, op, *, arg):
+        return self.f.mod(self.f.extract(self.v.dayofweek, arg) + 5, 7)
+
+    def visit_DayOfWeekName(self, op, *, arg):
+        return self.f.initcap(sge.Cast(this=arg, to="STRING FORMAT 'DAY'"))
+
+    def visit_StringToTimestamp(self, op, *, arg, format_str):
+        if (timezone := op.dtype.timezone) is not None:
+            return self.f.parse_timestamp(format_str, arg, timezone)
+        return self.f.parse_datetime(format_str, arg)
+
+    def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
+        if where is not None and include_null:
+            raise com.UnsupportedOperationError(
+                "Combining `include_null=True` and `where` is not supported "
+                "by bigquery"
+            )
+        out = self.agg.array_agg(arg, where=where, order_by=order_by)
+        if not include_null:
+            out = sge.IgnoreNulls(this=out)
+        return out
+
+    def _neg_idx_to_pos(self, arg, idx):
+        return self.if_(idx < 0, self.f.array_length(arg) + idx, idx)
+
+    def visit_ArraySlice(self, op, *, arg, start, stop):
+        index = sg.to_identifier("bq_arr_slice")
+        cond = [index >= self._neg_idx_to_pos(arg, start)]
+
+        if stop is not None:
+            cond.append(index < self._neg_idx_to_pos(arg, stop))
+
+        el = sg.to_identifier("el")
+        return self.f.array(
+            sg.select(el).from_(self._unnest(arg, as_=el, offset=index)).where(*cond)
+        )
+
+    def visit_ArrayIndex(self, op, *, arg, index):
+        return arg[self.f.safe_offset(index)]
+
+    def visit_ArrayContains(self, op, *, arg, other):
+        name = sg.to_identifier(util.gen_name("bq_arr_contains"))
+        return sge.Exists(
+            this=sg.select(sge.convert(1))
+            .from_(self._unnest(arg, as_=name))
+            .where(name.eq(other))
+        )
+
+    def visit_StringContains(self, op, *, haystack, needle):
+        return self.f.strpos(haystack, needle) > 0
+
+    def visit_StringFind(self, op, *, arg, substr, start, end):
+        if start is not None:
+            raise NotImplementedError(
+                "`start` not implemented for BigQuery string find"
+            )
+        if end is not None:
+            raise NotImplementedError("`end` not implemented for BigQuery string find")
+        return self.f.strpos(arg, substr)
+
+    def visit_TimestampFromYMDHMS(
+        self, op, *, year, month, day, hours, minutes, seconds
+    ):
+        return self.f.anon.DATETIME(year, month, day, hours, minutes, seconds)
+
+    def visit_NonNullLiteral(self, op, *, value, dtype):
+        if dtype.is_inet() or dtype.is_macaddr():
+            return sge.convert(str(value))
+        elif dtype.is_timestamp():
+            funcname = "DATETIME" if dtype.timezone is None else "TIMESTAMP"
+            return self.f.anon[funcname](value.isoformat())
+        elif dtype.is_date():
+            return self.f.date_from_parts(value.year, value.month, value.day)
+        elif dtype.is_time():
+            time = self.f.time_from_parts(value.hour, value.minute, value.second)
+            if micros := value.microsecond:
+                # bigquery doesn't support `time(12, 34, 56.789101)`, AKA a
+                # float seconds specifier, so add any non-zero micros to the
+                # time value
+                return sge.TimeAdd(
+                    this=time, expression=sge.convert(micros), unit=self.v.MICROSECOND
+                )
+            return time
+        elif dtype.is_binary():
+            return sge.Cast(
+                this=sge.convert(value.hex()),
+                to=sge.DataType(this=sge.DataType.Type.BINARY),
+                format=sge.convert("HEX"),
+            )
+        elif dtype.is_interval():
+            if dtype.unit == IntervalUnit.NANOSECOND:
+                raise com.UnsupportedOperationError(
+                    "BigQuery does not support nanosecond intervals"
+                )
+        elif dtype.is_uuid():
+            return sge.convert(str(value))
+        return None
+
+    def visit_IntervalFromInteger(self, op, *, arg, unit):
+        if unit == IntervalUnit.NANOSECOND:
+            raise com.UnsupportedOperationError(
"BigQuery does not support nanosecond intervals" + ) + return sge.Interval(this=arg, unit=self.v[unit.singular]) + + def visit_Strftime(self, op, *, arg, format_str): + arg_dtype = op.arg.dtype + if arg_dtype.is_timestamp(): + if (timezone := arg_dtype.timezone) is None: + return self.f.format_datetime(format_str, arg) + else: + return self.f.format_timestamp(format_str, arg, timezone) + elif arg_dtype.is_date(): + return self.f.format_date(format_str, arg) + else: + assert arg_dtype.is_time(), arg_dtype + return self.f.format_time(format_str, arg) + + def visit_IntervalMultiply(self, op, *, left, right): + unit = self.v[op.left.dtype.resolution.upper()] + return sge.Interval(this=self.f.extract(unit, left) * right, unit=unit) + + def visit_TimestampFromUNIX(self, op, *, arg, unit): + unit = op.unit + if unit == TimestampUnit.SECOND: + return self.f.timestamp_seconds(arg) + elif unit == TimestampUnit.MILLISECOND: + return self.f.timestamp_millis(arg) + elif unit == TimestampUnit.MICROSECOND: + return self.f.timestamp_micros(arg) + elif unit == TimestampUnit.NANOSECOND: + return self.f.timestamp_micros( + self.cast(self.f.round(arg / 1_000), dt.int64) + ) + else: + raise com.UnsupportedOperationError(f"Unit not supported: {unit}") + + def visit_Cast(self, op, *, arg, to): + from_ = op.arg.dtype + if from_.is_timestamp() and to.is_integer(): + return self.f.unix_micros(arg) + elif from_.is_integer() and to.is_timestamp(): + return self.f.timestamp_seconds(arg) + elif from_.is_interval() and to.is_integer(): + if from_.unit in { + IntervalUnit.WEEK, + IntervalUnit.QUARTER, + IntervalUnit.NANOSECOND, + }: + raise com.UnsupportedOperationError( + f"BigQuery does not allow extracting date part `{from_.unit}` from intervals" + ) + return self.f.extract(self.v[to.resolution.upper()], arg) + elif from_.is_floating() and to.is_integer(): + return self.cast(self.f.trunc(arg), dt.int64) + return super().visit_Cast(op, arg=arg, to=to) + + def visit_JSONGetItem(self, op, *, arg, index): + return arg[index] + + def visit_UnwrapJSONString(self, op, *, arg): + return self.f.anon["safe.string"](arg) + + def visit_UnwrapJSONInt64(self, op, *, arg): + return self.f.anon["safe.int64"](arg) + + def visit_UnwrapJSONFloat64(self, op, *, arg): + return self.f.anon["safe.float64"](arg) + + def visit_UnwrapJSONBoolean(self, op, *, arg): + return self.f.anon["safe.bool"](arg) + + def visit_ExtractEpochSeconds(self, op, *, arg): + return self.f.unix_seconds(arg) + + def visit_ExtractWeekOfYear(self, op, *, arg): + return self.f.extract(self.v.isoweek, arg) + + def visit_ExtractIsoYear(self, op, *, arg): + return self.f.extract(self.v.isoyear, arg) + + def visit_ExtractMillisecond(self, op, *, arg): + return self.f.extract(self.v.millisecond, arg) + + def visit_ExtractMicrosecond(self, op, *, arg): + return self.f.extract(self.v.microsecond, arg) + + def visit_TimestampTruncate(self, op, *, arg, unit): + if unit == IntervalUnit.NANOSECOND: + raise com.UnsupportedOperationError( + f"BigQuery does not support truncating {op.arg.dtype} values to unit {unit!r}" + ) + elif unit == IntervalUnit.WEEK: + unit = "WEEK(MONDAY)" + else: + unit = unit.name + return self.f.timestamp_trunc(arg, self.v[unit], dialect=self.dialect) + + def visit_DateTruncate(self, op, *, arg, unit): + if unit == DateUnit.WEEK: + unit = "WEEK(MONDAY)" + else: + unit = unit.name + return self.f.date_trunc(arg, self.v[unit], dialect=self.dialect) + + def visit_TimeTruncate(self, op, *, arg, unit): + if unit == TimeUnit.NANOSECOND: + raise 
com.UnsupportedOperationError( + f"BigQuery does not support truncating {op.arg.dtype} values to unit {unit!r}" + ) + else: + unit = unit.name + return self.f.time_trunc(arg, self.v[unit], dialect=self.dialect) + + def _nullifzero(self, step, zero, step_dtype): + if step_dtype.is_interval(): + return self.if_(step.eq(zero), NULL, step) + return self.f.nullif(step, zero) + + def _zero(self, dtype): + if dtype.is_interval(): + return self.f.make_interval() + return sge.convert(0) + + def _sign(self, value, dtype): + if dtype.is_interval(): + zero = self._zero(dtype) + return sge.Case( + ifs=[ + self.if_(value < zero, -1), + self.if_(value.eq(zero), 0), + self.if_(value > zero, 1), + ], + default=NULL, + ) + return self.f.sign(value) + + def _make_range(self, func, start, stop, step, step_dtype): + step_sign = self._sign(step, step_dtype) + delta_sign = self._sign(stop - start, step_dtype) + zero = self._zero(step_dtype) + nullifzero = self._nullifzero(step, zero, step_dtype) + condition = sg.and_(sg.not_(nullifzero.is_(NULL)), step_sign.eq(delta_sign)) + gen_array = func(start, stop, step) + name = sg.to_identifier(util.gen_name("bq_arr_range")) + inner = ( + sg.select(name) + .from_(self._unnest(gen_array, as_=name)) + .where(name.neq(stop)) + ) + return self.if_(condition, self.f.array(inner), self.f.array()) + + def visit_IntegerRange(self, op, *, start, stop, step): + return self._make_range(self.f.generate_array, start, stop, step, op.step.dtype) + + def visit_TimestampRange(self, op, *, start, stop, step): + if op.start.dtype.timezone is None or op.stop.dtype.timezone is None: + raise com.IbisTypeError( + "Timestamps without timezone values are not supported when generating timestamp ranges" + ) + return self._make_range( + self.f.generate_timestamp_array, start, stop, step, op.step.dtype + ) + + def visit_First(self, op, *, arg, where, order_by, include_null): + if where is not None: + arg = self.if_(where, arg, NULL) + if include_null: + raise com.UnsupportedOperationError( + "Combining `include_null=True` and `where` is not supported " + "by bigquery" + ) + + if order_by: + arg = sge.Order(this=arg, expressions=order_by) + + if not include_null: + arg = sge.IgnoreNulls(this=arg) + + array = self.f.array_agg(sge.Limit(this=arg, expression=sge.convert(1))) + return array[self.f.safe_offset(0)] + + def visit_Last(self, op, *, arg, where, order_by, include_null): + if where is not None: + arg = self.if_(where, arg, NULL) + if include_null: + raise com.UnsupportedOperationError( + "Combining `include_null=True` and `where` is not supported " + "by bigquery" + ) + + if order_by: + arg = sge.Order(this=arg, expressions=order_by) + + if not include_null: + arg = sge.IgnoreNulls(this=arg) + + array = self.f.array_reverse(self.f.array_agg(arg)) + return array[self.f.safe_offset(0)] + + def visit_ArrayFilter(self, op, *, arg, body, param): + return self.f.array( + sg.select(param).from_(self._unnest(arg, as_=param)).where(body) + ) + + def visit_ArrayMap(self, op, *, arg, body, param): + return self.f.array(sg.select(body).from_(self._unnest(arg, as_=param))) + + def visit_ArrayZip(self, op, *, arg): + lengths = [self.f.array_length(arr) - 1 for arr in arg] + idx = sg.to_identifier(util.gen_name("bq_arr_idx")) + indices = self._unnest( + self.f.generate_array(0, self.f.greatest(*lengths)), as_=idx + ) + struct_fields = [ + arr[self.f.safe_offset(idx)].as_(name) + for name, arr in zip(op.dtype.value_type.names, arg) + ] + return self.f.array( + sge.Select(kind="STRUCT", 
expressions=struct_fields).from_(indices) + ) + + def visit_ArrayPosition(self, op, *, arg, other): + name = sg.to_identifier(util.gen_name("bq_arr")) + idx = sg.to_identifier(util.gen_name("bq_arr_idx")) + unnest = self._unnest(arg, as_=name, offset=idx) + return self.f.coalesce( + sg.select(idx + 1).from_(unnest).where(name.eq(other)).limit(1).subquery(), + 0, + ) + + def _unnest(self, expression, *, as_, offset=None): + alias = sge.TableAlias(columns=[sg.to_identifier(as_)]) + return sge.Unnest(expressions=[expression], alias=alias, offset=offset) + + def visit_ArrayRemove(self, op, *, arg, other): + name = sg.to_identifier(util.gen_name("bq_arr")) + unnest = self._unnest(arg, as_=name) + both_null = sg.and_(name.is_(NULL), other.is_(NULL)) + cond = sg.or_(name.neq(other), both_null) + return self.f.array(sg.select(name).from_(unnest).where(cond)) + + def visit_ArrayDistinct(self, op, *, arg): + name = util.gen_name("bq_arr") + return self.f.array( + sg.select(name).distinct().from_(self._unnest(arg, as_=name)) + ) + + def visit_ArraySort(self, op, *, arg): + name = util.gen_name("bq_arr") + return self.f.array( + sg.select(name).from_(self._unnest(arg, as_=name)).order_by(name) + ) + + def visit_ArrayUnion(self, op, *, left, right): + lname = util.gen_name("bq_arr_left") + rname = util.gen_name("bq_arr_right") + lhs = sg.select(lname).from_(self._unnest(left, as_=lname)) + rhs = sg.select(rname).from_(self._unnest(right, as_=rname)) + return self.f.array(sg.union(lhs, rhs, distinct=True)) + + def visit_ArrayIntersect(self, op, *, left, right): + lname = util.gen_name("bq_arr_left") + rname = util.gen_name("bq_arr_right") + lhs = sg.select(lname).from_(self._unnest(left, as_=lname)) + rhs = sg.select(rname).from_(self._unnest(right, as_=rname)) + return self.f.array(sg.intersect(lhs, rhs, distinct=True)) + + def visit_RegexExtract(self, op, *, arg, pattern, index): + matches = self.f.regexp_contains(arg, pattern) + nonzero_index_replace = self.f.regexp_replace( + arg, + self.f.concat(".*?", pattern, ".*"), + self.f.concat("\\", self.cast(index, dt.string)), + ) + zero_index_replace = self.f.regexp_replace( + arg, self.f.concat(".*?", self.f.concat("(", pattern, ")"), ".*"), "\\1" + ) + extract = self.if_(index.eq(0), zero_index_replace, nonzero_index_replace) + return self.if_(matches, extract, NULL) + + def visit_TimestampAddSub(self, op, *, left, right): + if not isinstance(right, sge.Interval): + raise com.OperationNotDefinedError( + "BigQuery does not support non-literals on the right side of timestamp add/subtract" + ) + if (unit := op.right.dtype.unit) == IntervalUnit.NANOSECOND: + raise com.UnsupportedOperationError( + f"BigQuery does not allow binary operation {type(op).__name__} with " + f"INTERVAL offset {unit}" + ) + + opname = type(op).__name__[len("Timestamp") :] + funcname = f"TIMESTAMP_{opname.upper()}" + return self.f.anon[funcname](left, right) + + visit_TimestampAdd = visit_TimestampSub = visit_TimestampAddSub + + def visit_DateAddSub(self, op, *, left, right): + if not isinstance(right, sge.Interval): + raise com.OperationNotDefinedError( + "BigQuery does not support non-literals on the right side of date add/subtract" + ) + if not (unit := op.right.dtype.unit).is_date(): + raise com.UnsupportedOperationError( + f"BigQuery does not allow binary operation {type(op).__name__} with " + f"INTERVAL offset {unit}" + ) + opname = type(op).__name__[len("Date") :] + funcname = f"DATE_{opname.upper()}" + return self.f.anon[funcname](left, right) + + visit_DateAdd = 
visit_DateSub = visit_DateAddSub + + def visit_Covariance(self, op, *, left, right, how, where): + if where is not None: + left = self.if_(where, left, NULL) + right = self.if_(where, right, NULL) + + if op.left.dtype.is_boolean(): + left = self.cast(left, dt.int64) + + if op.right.dtype.is_boolean(): + right = self.cast(right, dt.int64) + + how = op.how[:4].upper() + assert how in ("POP", "SAMP"), 'how not in ("POP", "SAMP")' + return self.agg[f"COVAR_{how}"](left, right, where=where) + + def visit_Correlation(self, op, *, left, right, how, where): + if how == "sample": + raise ValueError(f"Correlation with how={how!r} is not supported.") + + if where is not None: + left = self.if_(where, left, NULL) + right = self.if_(where, right, NULL) + + if op.left.dtype.is_boolean(): + left = self.cast(left, dt.int64) + + if op.right.dtype.is_boolean(): + right = self.cast(right, dt.int64) + + return self.agg.corr(left, right, where=where) + + def visit_TypeOf(self, op, *, arg): + return self._pudf("typeof", arg) + + def visit_Xor(self, op, *, left, right): + return sg.or_(sg.and_(left, sg.not_(right)), sg.and_(sg.not_(left), right)) + + def visit_HashBytes(self, op, *, arg, how): + if how not in ("md5", "sha1", "sha256", "sha512"): + raise NotImplementedError(how) + return self.f[how](arg) + + @staticmethod + def _gen_valid_name(name: str) -> str: + candidate = "_".join(map(str.strip, _NAME_REGEX.findall(name))) or "tmp" + # column names cannot be longer than 300 characters + # + # https://cloud.google.com/bigquery/docs/schemas#column_names + # + # it's easy to rename columns, so raise an exception telling the user + # to do so + # + # we could potentially relax this and support arbitrary-length columns + # by compressing the information using hashing, but there's no reason + # to solve that problem until someone encounters this error and cannot + # rename their columns + limit = 300 + if len(candidate) > limit: + raise com.IbisError( + f"BigQuery does not allow column names longer than {limit:d} characters. " + "Please rename your columns to have fewer characters." 
+ ) + return candidate + + def visit_CountStar(self, op, *, arg, where): + if where is not None: + return self.f.countif(where) + return self.f.count(STAR) + + def visit_CountDistinctStar(self, op, *, where, arg): + # Bigquery does not support count(distinct a,b,c) or count(distinct (a, b, c)) + # as expressions must be "groupable": + # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#group_by_grouping_item + # + # Instead, convert the entire expression to a string + # SELECT COUNT(DISTINCT concat(to_json_string(a), to_json_string(b))) + # This works with an array of datatypes which generates a unique string + # https://cloud.google.com/bigquery/docs/reference/standard-sql/json_functions#json_encodings + row = sge.Concat( + expressions=[ + self.f.to_json_string(sg.column(x, quoted=self.quoted)) + for x in op.arg.schema.keys() + ] + ) + if where is not None: + row = self.if_(where, row, NULL) + return self.f.count(sge.Distinct(expressions=[row])) + + def visit_Degrees(self, op, *, arg): + return self._pudf("degrees", arg) + + def visit_Radians(self, op, *, arg): + return self._pudf("radians", arg) + + def visit_CountDistinct(self, op, *, arg, where): + if where is not None: + arg = self.if_(where, arg, NULL) + return self.f.count(sge.Distinct(expressions=[arg])) + + def visit_RandomUUID(self, op, **kwargs): + return self.f.generate_uuid() + + def visit_ExtractFile(self, op, *, arg): + return self._pudf("cw_url_extract_file", arg) + + def visit_ExtractFragment(self, op, *, arg): + return self._pudf("cw_url_extract_fragment", arg) + + def visit_ExtractPath(self, op, *, arg): + return self._pudf("cw_url_extract_path", arg) + + def visit_ExtractProtocol(self, op, *, arg): + return self._pudf("cw_url_extract_protocol", arg) + + def visit_ExtractQuery(self, op, *, arg, key): + if key is not None: + return self._pudf("cw_url_extract_parameter", arg, key) + else: + return self._pudf("cw_url_extract_query", arg) + + def _pudf(self, name, *args): + name = sg.table(name, db="persistent_udfs", catalog="bigquery-public-data").sql( + self.dialect + ) + return self.f[name](*args) + + def visit_DropColumns(self, op, *, parent, columns_to_drop): + quoted = self.quoted + excludes = [sg.column(column, quoted=quoted) for column in columns_to_drop] + star = sge.Star(**{"except": excludes}) + table = sg.to_identifier(parent.alias_or_name, quoted=quoted) + column = sge.Column(this=star, table=table) + return sg.select(column).from_(parent) + + def visit_TableUnnest( + self, op, *, parent, column, offset: str | None, keep_empty: bool + ): + quoted = self.quoted + + column_alias = sg.to_identifier( + util.gen_name("table_unnest_column"), quoted=quoted + ) + + selcols = [] + + table = sg.to_identifier(parent.alias_or_name, quoted=quoted) + + opname = op.column.name + overlaps_with_parent = opname in op.parent.schema + computed_column = column_alias.as_(opname, quoted=quoted) + + # replace the existing column if the unnested column hasn't been + # renamed + # + # e.g., table.unnest("x") + if overlaps_with_parent: + selcols.append( + sge.Column(this=sge.Star(replace=[computed_column]), table=table) + ) + else: + selcols.append(sge.Column(this=STAR, table=table)) + selcols.append(computed_column) + + if offset is not None: + offset = sg.to_identifier(offset, quoted=quoted) + selcols.append(offset) + + unnest = sge.Unnest( + expressions=[column], + alias=sge.TableAlias(columns=[column_alias]), + offset=offset, + ) + return ( + sg.select(*selcols) + .from_(parent) + .join(unnest, 
join_type="CROSS" if not keep_empty else "LEFT") + ) + + def visit_TimestampBucket(self, op, *, arg, interval, offset): + arg_dtype = op.arg.dtype + if arg_dtype.timezone is not None: + funcname = "timestamp" + else: + funcname = "datetime" + + func = self.f[f"{funcname}_bucket"] + + origin = sge.convert("1970-01-01") + if offset is not None: + origin = self.f.anon[f"{funcname}_add"](origin, offset) + + return func(arg, interval, origin) + + def _array_reduction(self, *, arg, reduction): + name = sg.to_identifier(util.gen_name(f"bq_arr_{reduction}")) + return ( + sg.select(self.f[reduction](name)) + .from_(self._unnest(arg, as_=name)) + .subquery() + ) + + def visit_ArrayMin(self, op, *, arg): + return self._array_reduction(arg=arg, reduction="min") + + def visit_ArrayMax(self, op, *, arg): + return self._array_reduction(arg=arg, reduction="max") + + def visit_ArraySum(self, op, *, arg): + return self._array_reduction(arg=arg, reduction="sum") + + def visit_ArrayMean(self, op, *, arg): + return self._array_reduction(arg=arg, reduction="avg") + + def visit_ArrayAny(self, op, *, arg): + return self._array_reduction(arg=arg, reduction="logical_or") + + def visit_ArrayAll(self, op, *, arg): + return self._array_reduction(arg=arg, reduction="logical_and") + + +compiler = BigQueryCompiler() diff --git a/third_party/bigframes_vendored/ibis/backends/sql/rewrites.py b/third_party/bigframes_vendored/ibis/backends/sql/rewrites.py new file mode 100644 index 0000000000..1f67902395 --- /dev/null +++ b/third_party/bigframes_vendored/ibis/backends/sql/rewrites.py @@ -0,0 +1,367 @@ +# Contains code from https://github.com/ibis-project/ibis/blob/main/ibis/backends/sql/rewrites.py + +"""Some common rewrite functions to be shared between backends.""" + +from __future__ import annotations + +from collections import defaultdict + +from ibis.common.collections import FrozenDict # noqa: TCH001 +from ibis.common.deferred import _, deferred, Item, var +from ibis.common.exceptions import ExpressionError, IbisInputError +from ibis.common.graph import Node as Traversable +from ibis.common.graph import traverse +from ibis.common.grounds import Concrete +from ibis.common.patterns import Check, pattern, replace +from ibis.common.typing import VarTuple # noqa: TCH001 +import ibis.expr.operations as ops +from ibis.util import Namespace, promote_list +import toolz + +p = Namespace(pattern, module=ops) +d = Namespace(deferred, module=ops) + + +x = var("x") +y = var("y") +name = var("name") + + +class DerefMap(Concrete, Traversable): + """Trace and replace fields from earlier relations in the hierarchy. + In order to provide a nice user experience, we need to allow expressions + from earlier relations in the hierarchy. Consider the following example: + t = ibis.table([('a', 'int64'), ('b', 'string')], name='t') + t1 = t.select([t.a, t.b]) + t2 = t1.filter(t.a > 0) # note that not t1.a is referenced here + t3 = t2.select(t.a) # note that not t2.a is referenced here + However the relational operations in the IR are strictly enforcing that + the expressions are referencing the immediate parent only. So we need to + track fields upwards the hierarchy to replace `t.a` with `t1.a` and `t2.a` + in the example above. This is called dereferencing. + Whether we can treat or not a field of a relation semantically equivalent + with a field of an earlier relation in the hierarchy depends on the + `.values` mapping of the relation. 
Leaf relations, like `t` in the example + above, have an empty `.values` mapping, so we cannot dereference fields + from them. On the other hand a projection, like `t1` in the example above, + has a `.values` mapping like `{'a': t.a, 'b': t.b}`, so we can deduce that + `t1.a` is semantically equivalent with `t.a` and so on. + """ + + """The relations we want the values to point to.""" + rels: VarTuple[ops.Relation] + + """Substitution mapping from values of earlier relations to the fields of `rels`.""" + subs: FrozenDict[ops.Value, ops.Field] + + """Ambiguous field references.""" + ambigs: FrozenDict[ops.Value, VarTuple[ops.Value]] + + @classmethod + def from_targets(cls, rels, extra=None): + """Create a dereference map from a list of target relations. + Usually a single relation is passed except for joins where multiple + relations are involved. + Parameters + ---------- + rels : list of ops.Relation + The target relations to dereference to. + extra : dict, optional + Extra substitutions to be added to the dereference map. + Returns + ------- + DerefMap + """ + rels = promote_list(rels) + mapping = defaultdict(dict) + for rel in rels: + for field in rel.fields.values(): + for value, distance in cls.backtrack(field): + mapping[value][field] = distance + + subs, ambigs = {}, {} + for from_, to in mapping.items(): + mindist = min(to.values()) + minkeys = [k for k, v in to.items() if v == mindist] + # if all the closest fields are from the same relation, then we + # can safely substitute them and we pick the first one arbitrarily + if all(minkeys[0].relations == k.relations for k in minkeys): + subs[from_] = minkeys[0] + else: + ambigs[from_] = minkeys + + if extra is not None: + subs.update(extra) + + return cls(rels, subs, ambigs) + + @classmethod + def backtrack(cls, value): + """Backtrack the field in the relation hierarchy. + The field is traced back until no modification is made, so only follow + ops.Field nodes not arbitrary values. + Parameters + ---------- + value : ops.Value + The value to backtrack. + Yields + ------ + tuple[ops.Field, int] + The value node and the distance from the original value. + """ + distance = 0 + # track down the field in the hierarchy until no modification + # is made so only follow ops.Field nodes not arbitrary values; + while isinstance(value, ops.Field): + yield value, distance + value = value.rel.values.get(value.name) + distance += 1 + if ( + value is not None + and value.relations + and not value.find(ops.Impure, filter=ops.Value) + ): + yield value, distance + + def dereference(self, value): + """Dereference a value to the target relations. + Also check for ambiguous field references. If a field reference is found + which is marked as ambiguous, then raise an error. + Parameters + ---------- + value : ops.Value + The value to dereference. + Returns + ------- + ops.Value + The dereferenced value. + """ + ambigs = value.find(lambda x: x in self.ambigs, filter=ops.Value) + if ambigs: + raise IbisInputError( + f"Ambiguous field reference {ambigs!r} in expression {value!r}" + ) + return value.replace(self.subs, filter=ops.Value) + + +def flatten_predicates(node): + """Yield the expressions corresponding to the `And` nodes of a predicate. 
+ Examples + -------- + >>> import ibis + >>> t = ibis.table([("a", "int64"), ("b", "string")], name="t") + >>> filt = (t.a == 1) & (t.b == "foo") + >>> predicates = flatten_predicates(filt.op()) + >>> len(predicates) + 2 + >>> predicates[0].to_expr().name("left") + r0 := UnboundTable: t + a int64 + b string + left: r0.a == 1 + >>> predicates[1].to_expr().name("right") + r0 := UnboundTable: t + a int64 + b string + right: r0.b == 'foo' + """ + + def predicate(node): + if isinstance(node, ops.And): + # proceed and don't yield the node + return True, None + else: + # halt and yield the node + return False, node + + return list(traverse(predicate, node)) + + +@replace(p.Field(p.JoinChain)) +def peel_join_field(_): + return _.rel.values[_.name] + + +@replace(p.ScalarParameter) +def replace_parameter(_, params, **kwargs): + """Replace scalar parameters with their values.""" + return ops.Literal(value=params[_], dtype=_.dtype) + + +@replace(p.StringSlice) +def lower_stringslice(_, **kwargs): + """Rewrite StringSlice in terms of Substring.""" + if _.end is None: + return ops.Substring(_.arg, start=_.start) + if _.start is None: + return ops.Substring(_.arg, start=0, length=_.end) + if ( + isinstance(_.start, ops.Literal) + and isinstance(_.start.value, int) + and isinstance(_.end, ops.Literal) + and isinstance(_.end.value, int) + ): + # optimization for constant values + length = _.end.value - _.start.value + else: + length = ops.Subtract(_.end, _.start) + return ops.Substring(_.arg, start=_.start, length=length) + + +@replace(p.Analytic) +def wrap_analytic(_, **__): + # Wrap analytic functions in a window function + return ops.WindowFunction(_) + + +@replace(p.Reduction) +def project_wrap_reduction(_, rel): + # Query all the tables that the reduction depends on + if _.relations == {rel}: + # The reduction is fully originating from the `rel`, so turn + # it into a window function of `rel` + return ops.WindowFunction(_) + else: + # 1. The reduction doesn't depend on any table, constructed from + # scalar values, so turn it into a scalar subquery. + # 2. The reduction is originating from `rel` and other tables, + # so this is a correlated scalar subquery. + # 3. The reduction is originating entirely from other tables, + # so this is an uncorrelated scalar subquery. + return ops.ScalarSubquery(_.to_expr().as_table()) + + +def rewrite_project_input(value, relation): + # we need to detect reductions which are either turned into window functions + # or scalar subqueries depending on whether they are originating from the + # relation + return value.replace( + wrap_analytic | project_wrap_reduction, + filter=p.Value & ~p.WindowFunction, + context={"rel": relation}, + ) + + +ReductionLike = p.Reduction | p.Field(p.Aggregate(groups={})) + + +@replace(ReductionLike) +def filter_wrap_reduction(_): + # Wrap reductions or fields referencing an aggregation without a group by - + # which are scalar fields - in a scalar subquery. In the latter case we + # use the reduction value from the aggregation. + if isinstance(_, ops.Field): + value = _.rel.values[_.name] + else: + value = _ + return ops.ScalarSubquery(value.to_expr().as_table()) + + +def rewrite_filter_input(value): + return value.replace( + wrap_analytic | filter_wrap_reduction, filter=p.Value & ~p.WindowFunction + ) + + +@replace(p.Analytic | p.Reduction) +def window_wrap_reduction(_, window): + # Wrap analytic and reduction functions in a window function. Used in the + # value.over() API. 
+ return ops.WindowFunction( + _, + how=window.how, + start=window.start, + end=window.end, + group_by=window.groupings, + order_by=window.orderings, + ) + + +@replace(p.WindowFunction) +def window_merge_frames(_, window): + # Merge window frames, used in the value.over() and groupby.select() APIs. + if _.how != window.how: + raise ExpressionError( + f"Unable to merge {_.how} window with {window.how} window" + ) + elif _.start and window.start and _.start != window.start: + raise ExpressionError( + "Unable to merge windows with conflicting `start` boundary" + ) + elif _.end and window.end and _.end != window.end: + raise ExpressionError("Unable to merge windows with conflicting `end` boundary") + + start = _.start or window.start + end = _.end or window.end + group_by = tuple(toolz.unique(_.group_by + window.groupings)) + + order_keys = {} + for sort_key in window.orderings + _.order_by: + order_keys[sort_key.expr] = sort_key.ascending, sort_key.nulls_first + + order_by = ( + ops.SortKey(expr, ascending=ascending, nulls_first=nulls_first) + for expr, (ascending, nulls_first) in order_keys.items() + ) + return _.copy(start=start, end=end, group_by=group_by, order_by=order_by) + + +def rewrite_window_input(value, window): + context = {"window": window} + # if self is a reduction or analytic function, wrap it in a window function + node = value.replace( + window_wrap_reduction, + filter=p.Value & ~p.WindowFunction, + context=context, + ) + # if self is already a window function, merge the existing window frame + # with the requested window frame + return node.replace(window_merge_frames, filter=p.Value, context=context) + + +# TODO(kszucs): schema comparison should be updated to not distinguish between +# different column order +@replace(p.Project(y @ p.Relation) & Check(_.schema == y.schema)) +def complete_reprojection(_, y): + # TODO(kszucs): this could be moved to the pattern itself but not sure how + # to express it, especially in a shorter way then the following check + for name in _.schema: + if _.values[name] != ops.Field(y, name): + return _ + return y + + +@replace(p.Project(y @ p.Project)) +def subsequent_projects(_, y): + rule = p.Field(y, name) >> Item(y.values, name) + values = {k: v.replace(rule, filter=ops.Value) for k, v in _.values.items()} + return ops.Project(y.parent, values) + + +@replace(p.Filter(y @ p.Filter)) +def subsequent_filters(_, y): + rule = p.Field(y, name) >> d.Field(y.parent, name) + preds = tuple(v.replace(rule, filter=ops.Value) for v in _.predicates) + return ops.Filter(y.parent, y.predicates + preds) + + +@replace(p.Filter(y @ p.Project)) +def reorder_filter_project(_, y): + rule = p.Field(y, name) >> Item(y.values, name) + preds = tuple(v.replace(rule, filter=ops.Value) for v in _.predicates) + + inner = ops.Filter(y.parent, preds) + rule = p.Field(y.parent, name) >> d.Field(inner, name) + projs = {k: v.replace(rule, filter=ops.Value) for k, v in y.values.items()} + + return ops.Project(inner, projs) + + +def simplify(node): + # TODO(kszucs): add a utility to the graph module to do rewrites in multiple + # passes after each other + node = node.replace(reorder_filter_project) + node = node.replace(reorder_filter_project) + node = node.replace(subsequent_projects | subsequent_filters) + node = node.replace(complete_reprojection) + return node diff --git a/third_party/bigframes_vendored/ibis/expr/rewrites.py b/third_party/bigframes_vendored/ibis/expr/rewrites.py new file mode 100644 index 0000000000..0583d2b87e --- /dev/null +++ 
b/third_party/bigframes_vendored/ibis/expr/rewrites.py @@ -0,0 +1,380 @@ +# Contains code from https://github.com/ibis-project/ibis/blob/main/ibis/expr/rewrites.py + +"""Some common rewrite functions to be shared between backends.""" + +from __future__ import annotations + +from collections import defaultdict + +from ibis.common.collections import FrozenDict # noqa: TCH001 +from ibis.common.deferred import _, deferred, Item, var +from ibis.common.exceptions import ExpressionError, IbisInputError +from ibis.common.graph import Node as Traversable +from ibis.common.graph import traverse +from ibis.common.grounds import Concrete +from ibis.common.patterns import Check, pattern, replace +from ibis.common.typing import VarTuple # noqa: TCH001 +import ibis.expr.operations as ops +from ibis.util import Namespace, promote_list +import toolz + +p = Namespace(pattern, module=ops) +d = Namespace(deferred, module=ops) + + +x = var("x") +y = var("y") +name = var("name") + + +class DerefMap(Concrete, Traversable): + """Trace and replace fields from earlier relations in the hierarchy. + + In order to provide a nice user experience, we need to allow expressions + from earlier relations in the hierarchy. Consider the following example: + + t = ibis.table([('a', 'int64'), ('b', 'string')], name='t') + t1 = t.select([t.a, t.b]) + t2 = t1.filter(t.a > 0) # note that not t1.a is referenced here + t3 = t2.select(t.a) # note that not t2.a is referenced here + + However the relational operations in the IR are strictly enforcing that + the expressions are referencing the immediate parent only. So we need to + track fields upwards the hierarchy to replace `t.a` with `t1.a` and `t2.a` + in the example above. This is called dereferencing. + + Whether we can treat or not a field of a relation semantically equivalent + with a field of an earlier relation in the hierarchy depends on the + `.values` mapping of the relation. Leaf relations, like `t` in the example + above, have an empty `.values` mapping, so we cannot dereference fields + from them. On the other hand a projection, like `t1` in the example above, + has a `.values` mapping like `{'a': t.a, 'b': t.b}`, so we can deduce that + `t1.a` is semantically equivalent with `t.a` and so on. + """ + + """The relations we want the values to point to.""" + rels: VarTuple[ops.Relation] + + """Substitution mapping from values of earlier relations to the fields of `rels`.""" + subs: FrozenDict[ops.Value, ops.Field] + + """Ambiguous field references.""" + ambigs: FrozenDict[ops.Value, VarTuple[ops.Value]] + + @classmethod + def from_targets(cls, rels, extra=None): + """Create a dereference map from a list of target relations. + + Usually a single relation is passed except for joins where multiple + relations are involved. + + Parameters + ---------- + rels : list of ops.Relation + The target relations to dereference to. + extra : dict, optional + Extra substitutions to be added to the dereference map. 
+ + Returns + ------- + DerefMap + """ + rels = promote_list(rels) + mapping = defaultdict(dict) + for rel in rels: + for field in rel.fields.values(): + for value, distance in cls.backtrack(field): + mapping[value][field] = distance + + subs, ambigs = {}, {} + for from_, to in mapping.items(): + mindist = min(to.values()) + minkeys = [k for k, v in to.items() if v == mindist] + # if all the closest fields are from the same relation, then we + # can safely substitute them and we pick the first one arbitrarily + if all(minkeys[0].relations == k.relations for k in minkeys): + subs[from_] = minkeys[0] + else: + ambigs[from_] = minkeys + + if extra is not None: + subs.update(extra) + + return cls(rels, subs, ambigs) + + @classmethod + def backtrack(cls, value): + """Backtrack the field in the relation hierarchy. + + The field is traced back until no modification is made, so only follow + ops.Field nodes not arbitrary values. + + Parameters + ---------- + value : ops.Value + The value to backtrack. + + Yields + ------ + tuple[ops.Field, int] + The value node and the distance from the original value. + """ + distance = 0 + # track down the field in the hierarchy until no modification + # is made so only follow ops.Field nodes not arbitrary values; + while isinstance(value, ops.Field): + yield value, distance + value = value.rel.values.get(value.name) + distance += 1 + if ( + value is not None + and value.relations + and not value.find(ops.Impure, filter=ops.Value) + ): + yield value, distance + + def dereference(self, value): + """Dereference a value to the target relations. + + Also check for ambiguous field references. If a field reference is found + which is marked as ambiguous, then raise an error. + + Parameters + ---------- + value : ops.Value + The value to dereference. + + Returns + ------- + ops.Value + The dereferenced value. + """ + ambigs = value.find(lambda x: x in self.ambigs, filter=ops.Value) + if ambigs: + raise IbisInputError( + f"Ambiguous field reference {ambigs!r} in expression {value!r}" + ) + return value.replace(self.subs, filter=ops.Value) + + +def flatten_predicates(node): + """Yield the expressions corresponding to the `And` nodes of a predicate. 
+ + Examples + -------- + >>> import ibis + >>> t = ibis.table([("a", "int64"), ("b", "string")], name="t") + >>> filt = (t.a == 1) & (t.b == "foo") + >>> predicates = flatten_predicates(filt.op()) + >>> len(predicates) + 2 + >>> predicates[0].to_expr().name("left") + r0 := UnboundTable: t + a int64 + b string + left: r0.a == 1 + >>> predicates[1].to_expr().name("right") + r0 := UnboundTable: t + a int64 + b string + right: r0.b == 'foo' + + """ + + def predicate(node): + if isinstance(node, ops.And): + # proceed and don't yield the node + return True, None + else: + # halt and yield the node + return False, node + + return list(traverse(predicate, node)) + + +@replace(p.Field(p.JoinChain)) +def peel_join_field(_): + return _.rel.values[_.name] + + +@replace(p.ScalarParameter) +def replace_parameter(_, params, **kwargs): + """Replace scalar parameters with their values.""" + return ops.Literal(value=params[_], dtype=_.dtype) + + +@replace(p.StringSlice) +def lower_stringslice(_, **kwargs): + """Rewrite StringSlice in terms of Substring.""" + if _.end is None: + return ops.Substring(_.arg, start=_.start) + if _.start is None: + return ops.Substring(_.arg, start=0, length=_.end) + if ( + isinstance(_.start, ops.Literal) + and isinstance(_.start.value, int) + and isinstance(_.end, ops.Literal) + and isinstance(_.end.value, int) + ): + # optimization for constant values + length = _.end.value - _.start.value + else: + length = ops.Subtract(_.end, _.start) + return ops.Substring(_.arg, start=_.start, length=length) + + +@replace(p.Analytic) +def project_wrap_analytic(_, rel): + # Wrap analytic functions in a window function + return ops.WindowFunction(_) + + +@replace(p.Reduction) +def project_wrap_reduction(_, rel): + # Query all the tables that the reduction depends on + if _.relations == {rel}: + # The reduction is fully originating from the `rel`, so turn + # it into a window function of `rel` + return ops.WindowFunction(_) + else: + # 1. The reduction doesn't depend on any table, constructed from + # scalar values, so turn it into a scalar subquery. + # 2. The reduction is originating from `rel` and other tables, + # so this is a correlated scalar subquery. + # 3. The reduction is originating entirely from other tables, + # so this is an uncorrelated scalar subquery. + return ops.ScalarSubquery(_.to_expr().as_table()) + + +def rewrite_project_input(value, relation): + # we need to detect reductions which are either turned into window functions + # or scalar subqueries depending on whether they are originating from the + # relation + return value.replace( + project_wrap_analytic | project_wrap_reduction, + filter=p.Value & ~p.WindowFunction, + context={"rel": relation}, + ) + + +ReductionLike = p.Reduction | p.Field(p.Aggregate(groups={})) + + +@replace(ReductionLike) +def filter_wrap_reduction(_): + # Wrap reductions or fields referencing an aggregation without a group by - + # which are scalar fields - in a scalar subquery. In the latter case we + # use the reduction value from the aggregation. + if isinstance(_, ops.Field): + value = _.rel.values[_.name] + else: + value = _ + return ops.ScalarSubquery(value.to_expr().as_table()) + + +def rewrite_filter_input(value): + return value.replace(filter_wrap_reduction, filter=p.Value & ~p.WindowFunction) + + +@replace(p.Analytic | p.Reduction) +def window_wrap_reduction(_, window): + # Wrap analytic and reduction functions in a window function. Used in the + # value.over() API. 
+ return ops.WindowFunction( + _, + how=window.how, + start=window.start, + end=window.end, + group_by=window.groupings, + order_by=window.orderings, + ) + + +@replace(p.WindowFunction) +def window_merge_frames(_, window): + # Merge window frames, used in the value.over() and groupby.select() APIs. + if _.how != window.how: + raise ExpressionError( + f"Unable to merge {_.how} window with {window.how} window" + ) + elif _.start and window.start and _.start != window.start: + raise ExpressionError( + "Unable to merge windows with conflicting `start` boundary" + ) + elif _.end and window.end and _.end != window.end: + raise ExpressionError("Unable to merge windows with conflicting `end` boundary") + + start = _.start or window.start + end = _.end or window.end + group_by = tuple(toolz.unique(_.group_by + window.groupings)) + + order_keys = {} + for sort_key in window.orderings + _.order_by: + order_keys[sort_key.expr] = sort_key.ascending, sort_key.nulls_first + + order_by = ( + ops.SortKey(expr, ascending=ascending, nulls_first=nulls_first) + for expr, (ascending, nulls_first) in order_keys.items() + ) + return _.copy(start=start, end=end, group_by=group_by, order_by=order_by) + + +def rewrite_window_input(value, window): + context = {"window": window} + # if self is a reduction or analytic function, wrap it in a window function + node = value.replace( + window_wrap_reduction, + filter=p.Value & ~p.WindowFunction, + context=context, + ) + # if self is already a window function, merge the existing window frame + # with the requested window frame + return node.replace(window_merge_frames, filter=p.Value, context=context) + + +# TODO(kszucs): schema comparison should be updated to not distinguish between +# different column order +@replace(p.Project(y @ p.Relation) & Check(_.schema == y.schema)) +def complete_reprojection(_, y): + # TODO(kszucs): this could be moved to the pattern itself but not sure how + # to express it, especially in a shorter way then the following check + for name in _.schema: + if _.values[name] != ops.Field(y, name): + return _ + return y + + +@replace(p.Project(y @ p.Project)) +def subsequent_projects(_, y): + rule = p.Field(y, name) >> Item(y.values, name) + values = {k: v.replace(rule, filter=ops.Value) for k, v in _.values.items()} + return ops.Project(y.parent, values) + + +@replace(p.Filter(y @ p.Filter)) +def subsequent_filters(_, y): + rule = p.Field(y, name) >> d.Field(y.parent, name) + preds = tuple(v.replace(rule, filter=ops.Value) for v in _.predicates) + return ops.Filter(y.parent, y.predicates + preds) + + +@replace(p.Filter(y @ p.Project)) +def reorder_filter_project(_, y): + rule = p.Field(y, name) >> Item(y.values, name) + preds = tuple(v.replace(rule, filter=ops.Value) for v in _.predicates) + + inner = ops.Filter(y.parent, preds) + rule = p.Field(y.parent, name) >> d.Field(inner, name) + projs = {k: v.replace(rule, filter=ops.Value) for k, v in y.values.items()} + + return ops.Project(inner, projs) + + +def simplify(node): + # TODO(kszucs): add a utility to the graph module to do rewrites in multiple + # passes after each other + node = node.replace(reorder_filter_project) + node = node.replace(reorder_filter_project) + node = node.replace(subsequent_projects | subsequent_filters) + node = node.replace(complete_reprojection) + return node diff --git a/third_party/bigframes_vendored/pandas/core/arrays/arrow/accessors.py b/third_party/bigframes_vendored/pandas/core/arrays/arrow/accessors.py index ab199d53bd..771146250a 100644 --- 
a/third_party/bigframes_vendored/pandas/core/arrays/arrow/accessors.py +++ b/third_party/bigframes_vendored/pandas/core/arrays/arrow/accessors.py @@ -6,6 +6,71 @@ from bigframes import constants +class ListAccessor: + """Accessor object for list data properties of the Series values.""" + + def len(self): + """Compute the length of each list in the Series. + + **See Also:** + + - :func:`StringMethods.len` : Compute the length of each element in the Series/Index. + + **Examples:** + + >>> import bigframes.pandas as bpd + >>> import pyarrow as pa + >>> bpd.options.display.progress_bar = None + >>> s = bpd.Series( + ... [ + ... [1, 2, 3], + ... [3], + ... ], + ... dtype=bpd.ArrowDtype(pa.list_(pa.int64())), + ... ) + >>> s.list.len() + 0 3 + 1 1 + dtype: Int64 + + Returns: + bigframes.series.Series: A Series or Index of integer values indicating + the length of each element in the Series or Index. + """ + raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE) + + def __getitem__(self, key: int | slice): + """Index or slice lists in the Series. + + **Examples:** + + >>> import bigframes.pandas as bpd + >>> import pyarrow as pa + >>> bpd.options.display.progress_bar = None + >>> s = bpd.Series( + ... [ + ... [1, 2, 3], + ... [3], + ... ], + ... dtype=bpd.ArrowDtype(pa.list_(pa.int64())), + ... ) + >>> s.list[0] + 0 1 + 1 3 + dtype: Int64 + + Args: + key (int | slice): Index or slice of indices to access from each list. + For integer indices, only non-negative values are accepted. For + slices, you must use a non-negative start, a non-negative end, and + a step of 1. + + Returns: + bigframes.series.Series: The list at requested index. + """ + raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE) + + class StructAccessor: """ Accessor object for structured data properties of the Series values. diff --git a/third_party/bigframes_vendored/pandas/core/frame.py b/third_party/bigframes_vendored/pandas/core/frame.py index 10565a2552..fe1c8a12ff 100644 --- a/third_party/bigframes_vendored/pandas/core/frame.py +++ b/third_party/bigframes_vendored/pandas/core/frame.py @@ -390,6 +390,7 @@ def to_gbq( index: bool = True, ordering_id: Optional[str] = None, clustering_columns: Union[pd.Index, Iterable[Hashable]] = (), + labels: dict[str, str] = {}, ) -> str: """Write a DataFrame to a BigQuery table. @@ -467,6 +468,9 @@ def to_gbq( clustering order within the Index/DataFrame columns follows the order specified in `clustering_columns`. + labels (dict[str, str], default None): + Specifies table labels within BigQuery + Returns: str: The fully-qualified ID for the written table, in the form diff --git a/third_party/bigframes_vendored/pandas/io/parsers/readers.py b/third_party/bigframes_vendored/pandas/io/parsers/readers.py index 248cf8e0fe..35b2a1982a 100644 --- a/third_party/bigframes_vendored/pandas/io/parsers/readers.py +++ b/third_party/bigframes_vendored/pandas/io/parsers/readers.py @@ -51,8 +51,7 @@ def read_csv( encoding: Optional[str] = None, **kwargs, ): - """Loads DataFrame from comma-separated values (csv) file locally or from - Cloud Storage. + """Loads data from a comma-separated values (csv) file into a DataFrame. The CSV file data will be persisted as a temporary BigQuery table, which can be automatically recycled after the Session is closed. @@ -60,7 +59,8 @@ def read_csv( .. note:: using `engine="bigquery"` will not guarantee the same ordering as the file. Instead, set a serialized index column as the index and sort by - that in the resulting DataFrame. 
+ that in the resulting DataFrame. Only files stored on your local machine + or in Google Cloud Storage are supported. .. note:: For non-bigquery engine, data is inlined in the query SQL if it is diff --git a/third_party/bigframes_vendored/tpch/queries/q7.py b/third_party/bigframes_vendored/tpch/queries/q7.py index 4ea5e6b238..d922efd1e2 100644 --- a/third_party/bigframes_vendored/tpch/queries/q7.py +++ b/third_party/bigframes_vendored/tpch/queries/q7.py @@ -56,14 +56,6 @@ def q(dataset_id: str, session: bigframes.Session): total = bpd.concat([df1, df2]) - # TODO(huanc): TEMPORARY CODE to force a fresh start. Currently, - # combining everything into a single query seems to trigger a bug - # causing incorrect results. This workaround involves writing to and - # then reading from BigQuery. Remove this once b/355714291 is - # resolved. - dest = total.to_gbq() - total = bpd.read_gbq(dest) - total = total[(total["L_SHIPDATE"] >= var3) & (total["L_SHIPDATE"] <= var4)] total["VOLUME"] = total["L_EXTENDEDPRICE"] * (1.0 - total["L_DISCOUNT"]) total["L_YEAR"] = total["L_SHIPDATE"].dt.year
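For reference, the day-of-week translations in the vendored BigQuery compiler above (visit_DayOfWeekIndex and visit_DayOfWeekName) target SQL of the following shape. This is an illustrative sketch only; ts_col is a placeholder column name, and the exact text rendered by sqlglot may differ.

    # Illustrative target SQL only (assumed rendering, not emitted verbatim by the compiler):
    # BigQuery EXTRACT(DAYOFWEEK ...) yields 1 (Sunday) through 7 (Saturday);
    # adding 5 and taking MOD 7 re-bases it to 0 (Monday) through 6 (Sunday).
    day_of_week_index_sql = "MOD(EXTRACT(DAYOFWEEK FROM ts_col) + 5, 7)"
    # CAST ... FORMAT 'DAY' produces the upper-case day name; INITCAP turns it into 'Monday', 'Tuesday', ...
    day_of_week_name_sql = "INITCAP(CAST(ts_col AS STRING FORMAT 'DAY'))"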
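The rewrite rules added in the vendored rewrites modules above (reorder_filter_project, subsequent_projects, subsequent_filters, complete_reprojection) are composed by simplify. A minimal sketch of driving such a pass, assuming the vendored module is importable as bigframes_vendored.ibis.expr.rewrites and a compatible ibis version is installed; the table and column names are placeholders.

    import ibis
    # Import path assumed from the file location added in this diff.
    from bigframes_vendored.ibis.expr.rewrites import simplify

    t = ibis.table({"a": "int64", "b": "string"}, name="t")
    # Chained project -> filter -> project; the filter references t.a, which DerefMap resolves.
    expr = t.select("a", "b").filter(t.a > 0).select("a")

    # simplify() fuses consecutive projections/filters and reorders filters below projections.
    simplified_node = simplify(expr.op())
    print(simplified_node.to_expr())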
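The new labels argument documented for DataFrame.to_gbq above attaches BigQuery table labels at write time. A short usage sketch; the destination table and label values are placeholders.

    import bigframes.pandas as bpd

    df = bpd.DataFrame({"x": [1, 2, 3]})
    # Writes the DataFrame and applies the given labels to the destination table.
    table_id = df.to_gbq(
        "my-project.my_dataset.my_table",  # placeholder destination
        if_exists="replace",
        labels={"env": "dev", "team": "analytics"},  # placeholder labels
    )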
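Per the updated read_csv docstring above, input files can live on the local machine or in Google Cloud Storage. A sketch with a placeholder bucket path; engine="bigquery" loads the file through BigQuery instead of inlining the data in query SQL.

    import bigframes.pandas as bpd

    # Placeholder GCS path; a local path such as "data/my_file.csv" also works.
    df = bpd.read_csv("gs://my-bucket/my_file.csv", engine="bigquery")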