Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit f3cb4ad

Browse filesBrowse files
feat(bigframes): update ai.if_() params to match the SQL version (#16857)
Reference: https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-ai-if I left out the `embeddings` param to keep things simple. It will be introduced later. --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 8fa0f81 commit f3cb4ad
Copy full SHA for f3cb4ad

10 files changed

+71-10Lines changed: 71 additions & 10 deletions

File tree

Expand file treeCollapse file tree
Open diff view settings
Filter options
Expand file treeCollapse file tree
Open diff view settings
Collapse file

‎packages/bigframes/bigframes/bigquery/_operations/ai.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/bigquery/_operations/ai.py
+19-7Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,9 @@ def if_(
817817
prompt: PROMPT_TYPE,
818818
*,
819819
connection_id: str | None = None,
820+
endpoint: str | None = None,
821+
optimization_mode: Literal["minimize_cost", "maximize_quality"] = "minimize_cost",
822+
max_error_ratio: float | None = None,
820823
) -> series.Series:
821824
"""
822825
Evaluates the prompt to True or False. Compared to `ai.generate_bool()`, this function
@@ -838,20 +841,26 @@ def if_(
838841
1 Illinois
839842
dtype: string
840843
841-
.. note::
842-
843-
This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
844-
Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is"
845-
and might have limited support. For more information, see the launch stage descriptions
846-
(https://cloud.google.com/products#product-launch-stages).
847-
848844
Args:
849845
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
850846
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
851847
or pandas Series.
852848
connection_id (str, optional):
853849
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
854850
If not provided, the query uses your end-user credential.
851+
endpoint (str, optional):
852+
Specifies the Vertex AI endpoint to use for the model. For example `"gemini-2.5-flash"`. You can specify any
853+
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
854+
uses the full endpoint of the model. If you don't specify an ENDPOINT value, BigQuery ML dynamically chooses a model based on your query to have the
855+
best cost to quality tradeoff for the task.
856+
optimization_mode (Literal["minimize_cost", "maximize_quality"]):
857+
Specifies the optimization strategy to use. Supported values are:
858+
* "minimize_cost" (default): uses a local, distilled model to process the majority of rows, reducing latency and cost.
859+
* "maximize_quality": always uses the remote LLM for inference.
860+
max_error_ratio (float):
861+
A float value between 0.0 and 1.0 that contains the maximum acceptable ratio of row-level inference failures to
862+
rows processed on this function. If this value is exceeded, then the query fails. The default value is 1.0.
863+
This argument isn't supported when `optimization_mode` is set to "minimize_cost".
855864
856865
Returns:
857866
bigframes.series.Series: A new series of bools.
@@ -863,6 +872,9 @@ def if_(
863872
operator = ai_ops.AIIf(
864873
prompt_context=tuple(prompt_context),
865874
connection_id=connection_id,
875+
endpoint=endpoint,
876+
optimization_mode=optimization_mode,
877+
max_error_ratio=max_error_ratio,
866878
)
867879

868880
return series_list[0]._apply_nary_op(operator, series_list[1:])
Collapse file

‎packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py
+3Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1983,6 +1983,9 @@ def ai_if(*values: ibis_types.Value, op: ops.AIIf) -> ibis_types.StructValue:
19831983
return ai_ops.AIIf(
19841984
_construct_prompt(values, op.prompt_context), # type: ignore
19851985
op.connection_id, # type: ignore
1986+
op.endpoint, # type: ignore
1987+
op.optimization_mode.upper() if op.optimization_mode is not None else None, # type: ignore
1988+
op.max_error_ratio, # type: ignore
19861989
).to_expr()
19871990

19881991

Collapse file

‎packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py
+6Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,12 @@ def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]:
139139
expression=sge.JSON(this=sge.Literal.string(value)),
140140
)
141141
)
142+
elif field == "optimization_mode":
143+
args.append(
144+
sge.Kwarg(this=field, expression=sge.Literal.string(value.upper()))
145+
)
146+
elif field == "max_error_ratio":
147+
args.append(sge.Kwarg(this=field, expression=sge.Literal.number(value)))
142148
elif field == "request_type":
143149
args.append(
144150
sge.Kwarg(this=field, expression=sge.Literal.string(value.upper()))
Collapse file

‎packages/bigframes/bigframes/operations/ai_ops.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/operations/ai_ops.py
+3Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ class AIIf(base_ops.NaryOp):
146146

147147
prompt_context: Tuple[str | None, ...]
148148
connection_id: str | None
149+
endpoint: str | None = None
150+
optimization_mode: str | None = None
151+
max_error_ratio: float | None = None
149152

150153
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
151154
return dtypes.BOOL_DTYPE
Collapse file

‎packages/bigframes/tests/system/small/bigquery/test_ai.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/tests/system/small/bigquery/test_ai.py
+5-1Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,11 @@ def test_ai_if(session):
323323
s2 = bpd.Series(["fruit", "tree"], session=session)
324324
prompt = (s1, " is a ", s2)
325325

326-
result = bbq.ai.if_(prompt)
326+
result = bbq.ai.if_(
327+
prompt,
328+
optimization_mode="maximize_quality",
329+
max_error_ratio=0.5,
330+
)
327331

328332
assert _contains_no_nulls(result)
329333
assert result.dtype == dtypes.BOOL_DTYPE
Collapse file
+5-1Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
11
SELECT
2-
AI.IF(prompt => (`string_col`, ' is the same as ', `string_col`)) AS `result`
2+
AI.IF(
3+
prompt => (`string_col`, ' is the same as ', `string_col`),
4+
optimization_mode => 'MINIMIZE_COST',
5+
max_error_ratio => 0.5
6+
) AS `result`
37
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
SELECT
22
AI.IF(
33
prompt => (`string_col`, ' is the same as ', `string_col`),
4-
connection_id => 'bigframes-dev.us.bigframes-default-connection'
4+
connection_id => 'bigframes-dev.us.bigframes-default-connection',
5+
optimization_mode => 'MINIMIZE_COST',
6+
max_error_ratio => 0.5
57
) AS `result`
68
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Collapse file
+6Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
SELECT
2+
AI.IF(
3+
prompt => (`string_col`, ' is the same as ', `string_col`),
4+
endpoint => 'gemini-2.5-flash'
5+
) AS `result`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Collapse file

‎packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py
+18Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,24 @@ def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot, connection_id):
358358
op = ops.AIIf(
359359
prompt_context=(None, " is the same as ", None),
360360
connection_id=connection_id,
361+
optimization_mode="minimize_cost",
362+
max_error_ratio=0.5,
363+
)
364+
365+
sql = utils._apply_ops_to_sql(
366+
scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
367+
)
368+
369+
snapshot.assert_match(sql, "out.sql")
370+
371+
372+
def test_ai_if_with_endpoint(scalar_types_df: dataframe.DataFrame, snapshot):
373+
col_name = "string_col"
374+
375+
op = ops.AIIf(
376+
prompt_context=(None, " is the same as ", None),
377+
connection_id=None,
378+
endpoint="gemini-2.5-flash",
361379
)
362380

363381
sql = utils._apply_ops_to_sql(
Collapse file

‎packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py
+3Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ class AIIf(Value):
138138

139139
prompt: Value
140140
connection_id: Optional[Value[dt.String]]
141+
endpoint: Optional[Value[dt.String]] = None
142+
optimization_mode: Optional[Value[dt.String]] = None
143+
max_error_ratio: Optional[Value[dt.Float64]] = None
141144

142145
shape = rlz.shape_like("prompt")
143146

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.