Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 9f42fe1

Browse files
feat(bigframes): update ai.score to match its SQL version (#16919)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent d37a953 commit 9f42fe1
Copy full SHA for 9f42fe1

7 files changed

+56 -13 Lines changed: 56 additions & 13 deletions

File tree

Expand file tree / Collapse file tree
Open diff view settings
Filter options
Expand file tree / Collapse file tree
Open diff view settings
Collapse file

‎packages/bigframes/bigframes/bigquery/_operations/ai.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/bigquery/_operations/ai.py
+12-7Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -941,6 +941,8 @@ def score(
941941
prompt: PROMPT_TYPE,
942942
*,
943943
connection_id: str | None = None,
944+
endpoint: str | None = None,
945+
max_error_ratio: float | None = None,
944946
) -> series.Series:
945947
"""
946948
Computes a score based on rubrics described in natural language. It will return a double value.
@@ -958,20 +960,21 @@ def score(
958960
2 3.0
959961
dtype: Float64
960962
961-
.. note::
962-
963-
This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
964-
Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is"
965-
and might have limited support. For more information, see the launch stage descriptions
966-
(https://cloud.google.com/products#product-launch-stages).
967-
968963
Args:
969964
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
970965
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
971966
or pandas Series.
972967
connection_id (str, optional):
973968
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
974969
If not provided, the query uses your end-user credential.
970+
endpoint (str, optional):
971+
Specifies the Vertex AI endpoint to use for the model. For example `"gemini-2.5-flash"`. You can specify any
972+
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
973+
uses the full endpoint of the model. If you don't specify an endpoint value, BigQuery ML dynamically chooses a model
974+
based on your query to have the best cost to quality tradeoff for the task.
975+
max_error_ratio (float, optional):
976+
A value between `0.0` and `1.0` that contains the maximum acceptable ratio of row-level inference failures to
977+
rows processed on this function. If this value is exceeded, then the query fails.
975978
976979
Returns:
977980
bigframes.series.Series: A new series of double (float) values.
@@ -983,6 +986,8 @@ def score(
983986
operator = ai_ops.AIScore(
984987
prompt_context=tuple(prompt_context),
985988
connection_id=connection_id,
989+
endpoint=endpoint,
990+
max_error_ratio=max_error_ratio,
986991
)
987992

988993
return series_list[0]._apply_nary_op(operator, series_list[1:])
Collapse file

‎packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py
+2Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2005,6 +2005,8 @@ def ai_score(*values: ibis_types.Value, op: ops.AIScore) -> ibis_types.StructVal
20052005
return ai_ops.AIScore(
20062006
_construct_prompt(values, op.prompt_context), # type: ignore
20072007
op.connection_id, # type: ignore
2008+
op.endpoint, # type: ignore
2009+
op.max_error_ratio, # type: ignore
20082010
).to_expr()
20092011

20102012

Collapse file

‎packages/bigframes/bigframes/operations/ai_ops.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/operations/ai_ops.py
+2Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ class AIScore(base_ops.NaryOp):
172172

173173
prompt_context: Tuple[str | None, ...]
174174
connection_id: str | None
175+
endpoint: str | None
176+
max_error_ratio: float | None
175177

176178
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
177179
return dtypes.FLOAT_DTYPE
Collapse file

‎packages/bigframes/bigframes/pandas/io/api.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/pandas/io/api.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,8 +654,8 @@ def from_glob_path(
654654
def _get_bqclient_and_project() -> Tuple[bigquery.Client, str]:
655655
# Address circular imports in doctest due to bigframes/session/__init__.py
656656
# containing a lot of logic and samples.
657-
from bigframes.session import clients
658657
import bigframes._config.auth
658+
from bigframes.session import clients
659659

660660
credentials, project = bigframes._config.auth.resolve_credentials_and_project(
661661
config.options.bigquery
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
SELECT
2+
AI.SCORE(
3+
prompt => (`string_col`, ' is the same as ', `string_col`),
4+
endpoint => 'gemini-2.5-flash',
5+
max_error_ratio => 0.5
6+
) AS `result`
7+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Collapse file

‎packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py
+21Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,27 @@ def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot, connection_id)
407407
op = ops.AIScore(
408408
prompt_context=(None, " is the same as ", None),
409409
connection_id=connection_id,
410+
endpoint=None,
411+
max_error_ratio=None,
412+
)
413+
414+
sql = utils._apply_ops_to_sql(
415+
scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
416+
)
417+
418+
snapshot.assert_match(sql, "out.sql")
419+
420+
421+
def test_ai_score_with_endpoint_and_max_error_ratio(
422+
scalar_types_df: dataframe.DataFrame, snapshot
423+
):
424+
col_name = "string_col"
425+
426+
op = ops.AIScore(
427+
prompt_context=(None, " is the same as ", None),
428+
connection_id=None,
429+
endpoint="gemini-2.5-flash",
430+
max_error_ratio=0.5,
410431
)
411432

412433
sql = utils._apply_ops_to_sql(
Collapse file

‎packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py
+11-5Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,9 @@ class AIIf(Value):
138138

139139
prompt: Value
140140
connection_id: Optional[Value[dt.String]]
141-
endpoint: Optional[Value[dt.String]] = None
142-
optimization_mode: Optional[Value[dt.String]] = None
143-
max_error_ratio: Optional[Value[dt.Float64]] = None
141+
endpoint: Optional[Value[dt.String]]
142+
optimization_mode: Optional[Value[dt.String]]
143+
max_error_ratio: Optional[Value[dt.Float64]]
144144

145145
shape = rlz.shape_like("prompt")
146146

@@ -151,7 +151,7 @@ def dtype(self) -> dt.Struct:
151151

152152
@public
153153
class AIClassify(Value):
154-
"""Generate True/False based on the prompt"""
154+
"""Generate categories based on the prompt"""
155155

156156
input: Value
157157
categories: Value[dt.Array[dt.String]]
@@ -166,13 +166,19 @@ def dtype(self) -> dt.Struct:
166166

167167
@public
168168
class AIScore(Value):
169-
"""Generate doubles based on the prompt"""
169+
"""Generate scores based on the prompt"""
170170

171171
prompt: Value
172172
connection_id: Optional[Value[dt.String]]
173+
endpoint: Optional[Value[dt.String]]
174+
max_error_ratio: Optional[Value[dt.Float64]]
173175

174176
shape = rlz.shape_like("prompt")
175177

178+
@attribute
179+
def dtype(self) -> dt.DataType:
180+
return dt.float64
181+
176182

177183
@public
178184
class AISimilarity(Value):

0 commit comments

Comments
0 (0)
Morty Proxy: This is a proxified and sanitized view of the page; visit the original site.