Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit e9c52b1

Browse files
authored
feat(bigframes): add more params to ai.classify (#16990)
1 parent cef659d commit e9c52b1
Copy full SHA for e9c52b1

8 files changed

+109 -17 · Lines changed: 109 additions & 17 deletions

File tree

Expand file treeCollapse file tree
Open diff view settings
Filter options
Expand file treeCollapse file tree
Open diff view settings
Collapse file

‎packages/bigframes/bigframes/bigquery/_operations/ai.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/bigquery/_operations/ai.py
+26-16Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
import pandas as pd
2525

26-
from bigframes import clients, dataframe, dtypes, series, session
26+
from bigframes import dataframe, dtypes, series, session
2727
from bigframes import pandas as bpd
2828
from bigframes.bigquery._operations import utils as bq_utils
2929
from bigframes.core import convert
@@ -885,7 +885,11 @@ def classify(
885885
input: PROMPT_TYPE,
886886
categories: tuple[str, ...] | list[str],
887887
*,
888+
examples: list[tuple[str, str]] | None = None,
888889
connection_id: str | None = None,
890+
endpoint: str | None = None,
891+
optimization_mode: Literal["minimize_cost", "maximize_quality"] | None = None,
892+
max_error_ratio: float | None = None,
889893
) -> series.Series:
890894
"""
891895
Classifies a given input into one of the specified categories. It will always return one of the provided categories best fit the prompt input.
@@ -903,22 +907,30 @@ def classify(
903907
<BLANKLINE>
904908
[2 rows x 2 columns]
905909
906-
.. note::
907-
908-
This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
909-
Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is"
910-
and might have limited support. For more information, see the launch stage descriptions
911-
(https://cloud.google.com/products#product-launch-stages).
912-
913910
Args:
914911
input (str | Series | List[str|Series] | Tuple[str|Series, ...]):
915912
A mixture of Series and string literals that specifies the input to send to the model. The Series can be BigFrames Series
916913
or pandas Series.
917914
categories (tuple[str, ...] | list[str]):
918915
Categories to classify the input into.
916+
examples (list[tuple[str, str]], optional):
917+
An array that contains representative examples of input strings and the output category
918+
that you expect. You can provide examples to help the model understand your
919+
intended threshold for a condition with nuanced or subjective logic. We recommend providing at most 5 examples.
919920
connection_id (str, optional):
920921
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
921922
If not provided, the query uses your end-user credential.
923+
endpoint (str, optional):
924+
A STRING value that specifies the Vertex AI endpoint to use for the model. You can specify any
925+
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically
926+
identifies and uses the full endpoint of the model.
927+
optimization_mode (Literal["minimize_cost", "maximize_quality"], optional):
928+
A STRING value that specifies the optimization strategy to use. Supported values are ``minimize_cost``
929+
and ``maximize_quality``.
930+
max_error_ratio (float, optional):
931+
A value between ``0.0`` and ``1.0`` that contains the maximum acceptable ratio of row-level
932+
inference failures to rows processed on this function. The default value is 1.0.
933+
This argument isn't supported when ``optimization_mode`` is set to ``minimize_cost``.
922934
923935
Returns:
924936
bigframes.series.Series: A new series of strings.
@@ -927,10 +939,16 @@ def classify(
927939
prompt_context, series_list = _separate_context_and_series(input)
928940
assert len(series_list) > 0
929941

942+
example_tuples = tuple(examples) if examples is not None else None
943+
930944
operator = ai_ops.AIClassify(
931945
prompt_context=tuple(prompt_context),
932946
categories=tuple(categories),
947+
examples=example_tuples,
933948
connection_id=connection_id,
949+
endpoint=endpoint,
950+
optimization_mode=optimization_mode,
951+
max_error_ratio=max_error_ratio,
934952
)
935953

936954
return series_list[0]._apply_nary_op(operator, series_list[1:])
@@ -1249,14 +1267,6 @@ def _convert_series(
12491267
return result
12501268

12511269

1252-
def _resolve_connection_id(series: series.Series, connection_id: str | None):
1253-
return clients.get_canonical_bq_connection_id(
1254-
connection_id or series._session.bq_connection,
1255-
series._session._project,
1256-
series._session._location,
1257-
)
1258-
1259-
12601270
def _to_dataframe(
12611271
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
12621272
series_rename: str,
Collapse file

‎packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py
+24Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1996,7 +1996,11 @@ def ai_classify(
19961996
return ai_ops.AIClassify(
19971997
_construct_prompt(values, op.prompt_context), # type: ignore
19981998
op.categories, # type: ignore
1999+
_construct_examples(op.examples), # type: ignore
19992000
op.connection_id, # type: ignore
2001+
op.endpoint, # type: ignore
2002+
op.optimization_mode.upper() if op.optimization_mode is not None else None, # type: ignore
2003+
op.max_error_ratio, # type: ignore
20002004
).to_expr()
20012005

20022006

@@ -2040,6 +2044,26 @@ def _construct_prompt(
20402044
return ibis.struct(prompt)
20412045

20422046

2047+
def _construct_examples(
2048+
examples: tuple[tuple[str, str]] | None,
2049+
) -> ibis_types.ArrayValue | None:
2050+
if examples is None:
2051+
return None
2052+
2053+
results: list[ibis_types.StructValue] = []
2054+
2055+
for example in examples:
2056+
ibis_example = ibis.struct(
2057+
{
2058+
"_field_1": example[0],
2059+
"_field_2": example[1],
2060+
}
2061+
)
2062+
results.append(ibis_example)
2063+
2064+
return ibis.array(results)
2065+
2066+
20432067
@scalar_op_compiler.register_nary_op(ops.RowKey, pass_op=True)
20442068
def rowkey_op_impl(*values: ibis_types.Value, op: ops.RowKey) -> ibis_types.Value:
20452069
return bigframes.core.compile.ibis_compiler.default_ordering.gen_row_key(values)
Collapse file

‎packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py
+10Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,16 @@ def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]:
149149
args.append(
150150
sge.Kwarg(this=field, expression=sge.Literal.string(value.upper()))
151151
)
152+
elif field == "examples":
153+
example_expressions = [
154+
sge.Tuple(
155+
expressions=[sge.Literal.string(key), sge.Literal.string(val)]
156+
)
157+
for key, val in value
158+
]
159+
args.append(
160+
sge.Kwarg(this=field, expression=sge.array(*example_expressions))
161+
)
152162
else:
153163
args.append(
154164
sge.Kwarg(this=field, expression=sge.Literal.string(str(value)))
Collapse file

‎packages/bigframes/bigframes/operations/ai_ops.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/operations/ai_ops.py
+4Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,11 @@ class AIClassify(base_ops.NaryOp):
160160

161161
prompt_context: Tuple[str | None, ...]
162162
categories: tuple[str, ...]
163+
examples: tuple[tuple[str, str], ...] | None
163164
connection_id: str | None
165+
endpoint: str | None
166+
optimization_mode: str | None
167+
max_error_ratio: float | None
164168

165169
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
166170
return dtypes.STRING_DTYPE
Collapse file

‎packages/bigframes/tests/system/small/bigquery/test_ai.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/tests/system/small/bigquery/test_ai.py
+9Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,15 @@ def test_ai_classify(session):
355355
assert result.dtype == dtypes.STRING_DTYPE
356356

357357

358+
def test_ai_classify_with_examples(session):
    """Check ai.classify with ``examples`` returns one string result per input row."""
    inputs = bpd.Series(["cat", "orchid"], session=session)
    labels = ["animal", "plant"]
    few_shot = [("dog", "animal")]

    classified = bbq.ai.classify(inputs, labels, examples=few_shot)

    # Same two checks as before: row count is preserved and the dtype is string.
    assert len(classified) == len(inputs)
    assert classified.dtype == dtypes.STRING_DTYPE
365+
366+
358367
def test_ai_classify_multi_model(session, bq_connection):
359368
df = session.from_glob_path(
360369
"gs://bigframes-dev-testing/a_multimodel/images/*",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
SELECT
2+
AI.CLASSIFY(
3+
input => (`string_col`),
4+
categories => ['greeting', 'rejection'],
5+
examples => [('hi', 'greeting'), ('bye', 'rejection')],
6+
endpoint => 'gemini-2.5-flash',
7+
max_error_ratio => 0.1
8+
) AS `result`
9+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Collapse file

‎packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py
+22Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,29 @@ def test_ai_classify(scalar_types_df: dataframe.DataFrame, snapshot, connection_
392392
op = ops.AIClassify(
393393
prompt_context=(None,),
394394
categories=("greeting", "rejection"),
395+
examples=None,
395396
connection_id=connection_id,
397+
endpoint=None,
398+
optimization_mode=None,
399+
max_error_ratio=None,
400+
)
401+
402+
sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"])
403+
404+
snapshot.assert_match(sql, "out.sql")
405+
406+
407+
def test_ai_classify_with_params(scalar_types_df: dataframe.DataFrame, snapshot):
408+
col_name = "string_col"
409+
410+
op = ops.AIClassify(
411+
prompt_context=(None,),
412+
categories=("greeting", "rejection"),
413+
examples=(("hi", "greeting"), ("bye", "rejection")),
414+
connection_id=None,
415+
endpoint="gemini-2.5-flash",
416+
optimization_mode=None,
417+
max_error_ratio=0.1,
396418
)
397419

398420
sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"])
Collapse file

‎packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py
+5-1Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,16 @@ class AIClassify(Value):
155155

156156
input: Value
157157
categories: Value[dt.Array[dt.String]]
158+
examples: Optional[Value]
158159
connection_id: Optional[Value[dt.String]]
160+
endpoint: Optional[Value[dt.String]]
161+
optimization_mode: Optional[Value[dt.String]]
162+
max_error_ratio: Optional[Value[dt.Float64]]
159163

160164
shape = rlz.shape_like("input")
161165

162166
@attribute
163-
def dtype(self) -> dt.Struct:
167+
def dtype(self) -> dt.DataType:
164168
return dt.string
165169

166170

0 commit comments

Comments
0 (0)
Morty Proxy: This is a proxified and sanitized view of the page; visit the original site.