Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings
This repository was archived by the owner on May 7, 2026. It is now read-only.

Commit fe94910

Browse filesBrowse files
authored
feat: use EUC for AI IF, CLASSIFY, and SCORE when connection is not provided (#2507)
Fixes b/489038951 🦕
1 parent a5ddcea commit fe94910
Copy full SHA for fe94910

11 files changed

+33-21Lines changed: 33 additions & 21 deletions

File tree

Expand file treeCollapse file tree
Open diff view settings
Filter options
  • bigframes
  • tests/unit/core/compile/sqlglot/expressions
  • third_party/bigframes_vendored/ibis/expr/operations
Expand file treeCollapse file tree
Open diff view settings
Collapse file

‎bigframes/bigquery/_operations/ai.py‎

Copy file name to clipboardExpand all lines: bigframes/bigquery/_operations/ai.py
+6-6Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ def if_(
745745
or pandas Series.
746746
connection_id (str, optional):
747747
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
748-
If not provided, the connection from the current session will be used.
748+
If not provided, the query uses your end-user credential.
749749
750750
Returns:
751751
bigframes.series.Series: A new series of bools.
@@ -756,7 +756,7 @@ def if_(
756756

757757
operator = ai_ops.AIIf(
758758
prompt_context=tuple(prompt_context),
759-
connection_id=_resolve_connection_id(series_list[0], connection_id),
759+
connection_id=connection_id,
760760
)
761761

762762
return series_list[0]._apply_nary_op(operator, series_list[1:])
@@ -800,7 +800,7 @@ def classify(
800800
Categories to classify the input into.
801801
connection_id (str, optional):
802802
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
803-
If not provided, the connection from the current session will be used.
803+
If not provided, the query uses your end-user credential.
804804
805805
Returns:
806806
bigframes.series.Series: A new series of strings.
@@ -812,7 +812,7 @@ def classify(
812812
operator = ai_ops.AIClassify(
813813
prompt_context=tuple(prompt_context),
814814
categories=tuple(categories),
815-
connection_id=_resolve_connection_id(series_list[0], connection_id),
815+
connection_id=connection_id,
816816
)
817817

818818
return series_list[0]._apply_nary_op(operator, series_list[1:])
@@ -853,7 +853,7 @@ def score(
853853
or pandas Series.
854854
connection_id (str, optional):
855855
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
856-
If not provided, the connection from the current session will be used.
856+
If not provided, the query uses your end-user credential.
857857
858858
Returns:
859859
bigframes.series.Series: A new series of double (float) values.
@@ -864,7 +864,7 @@ def score(
864864

865865
operator = ai_ops.AIScore(
866866
prompt_context=tuple(prompt_context),
867-
connection_id=_resolve_connection_id(series_list[0], connection_id),
867+
connection_id=connection_id,
868868
)
869869

870870
return series_list[0]._apply_nary_op(operator, series_list[1:])
Collapse file

‎bigframes/core/compile/sqlglot/expressions/ai_ops.py‎

Copy file name to clipboardExpand all lines: bigframes/core/compile/sqlglot/expressions/ai_ops.py
+3-3Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,9 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
113113
)
114114
)
115115

116-
endpoit = op_args.get("endpoint", None)
117-
if endpoit is not None:
118-
args.append(sge.Kwarg(this="endpoint", expression=sge.Literal.string(endpoit)))
116+
endpoint = op_args.get("endpoint", None)
117+
if endpoint is not None:
118+
args.append(sge.Kwarg(this="endpoint", expression=sge.Literal.string(endpoint)))
119119

120120
request_type = op_args.get("request_type", None)
121121
if request_type is not None:
Collapse file

‎bigframes/operations/ai_ops.py‎

Copy file name to clipboardExpand all lines: bigframes/operations/ai_ops.py
+3-3Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class AIIf(base_ops.NaryOp):
123123
name: ClassVar[str] = "ai_if"
124124

125125
prompt_context: Tuple[str | None, ...]
126-
connection_id: str
126+
connection_id: str | None
127127

128128
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
129129
return dtypes.BOOL_DTYPE
@@ -135,7 +135,7 @@ class AIClassify(base_ops.NaryOp):
135135

136136
prompt_context: Tuple[str | None, ...]
137137
categories: tuple[str, ...]
138-
connection_id: str
138+
connection_id: str | None
139139

140140
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
141141
return dtypes.STRING_DTYPE
@@ -146,7 +146,7 @@ class AIScore(base_ops.NaryOp):
146146
name: ClassVar[str] = "ai_score"
147147

148148
prompt_context: Tuple[str | None, ...]
149-
connection_id: str
149+
connection_id: str | None
150150

151151
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
152152
return dtypes.FLOAT_DTYPE
Collapse file
+3Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
AI.CLASSIFY(input => (`string_col`), categories => ['greeting', 'rejection']) AS `result`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Collapse file

‎…ots/test_ai_ops/test_ai_classify/out.sql‎ ‎….us.bigframes-default-connection/out.sql‎tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/out.sql renamed to tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/bigframes-dev.us.bigframes-default-connection/out.sql tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/out.sql renamed to tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/bigframes-dev.us.bigframes-default-connection/out.sql

Copy file name to clipboard
File renamed without changes.
Collapse file
+3Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
AI.IF(prompt => (`string_col`, ' is the same as ', `string_col`)) AS `result`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Collapse file

‎…snapshots/test_ai_ops/test_ai_if/out.sql‎ ‎….us.bigframes-default-connection/out.sql‎tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/out.sql renamed to tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/bigframes-dev.us.bigframes-default-connection/out.sql tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/out.sql renamed to tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/bigframes-dev.us.bigframes-default-connection/out.sql

Copy file name to clipboard
File renamed without changes.
Collapse file
+3Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
AI.SCORE(prompt => (`string_col`, ' is the same as ', `string_col`)) AS `result`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Collapse file

‎…pshots/test_ai_ops/test_ai_score/out.sql‎ ‎….us.bigframes-default-connection/out.sql‎tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/out.sql renamed to tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/bigframes-dev.us.bigframes-default-connection/out.sql tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/out.sql renamed to tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/bigframes-dev.us.bigframes-default-connection/out.sql

Copy file name to clipboard
File renamed without changes.
Collapse file

‎tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py‎

Copy file name to clipboardExpand all lines: tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py
+9-6Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -281,12 +281,13 @@ def test_ai_generate_double_with_model_param(
281281
snapshot.assert_match(sql, "out.sql")
282282

283283

284-
def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot):
284+
@pytest.mark.parametrize("connection_id", [None, CONNECTION_ID])
285+
def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot, connection_id):
285286
col_name = "string_col"
286287

287288
op = ops.AIIf(
288289
prompt_context=(None, " is the same as ", None),
289-
connection_id=CONNECTION_ID,
290+
connection_id=connection_id,
290291
)
291292

292293
sql = utils._apply_ops_to_sql(
@@ -296,26 +297,28 @@ def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot):
296297
snapshot.assert_match(sql, "out.sql")
297298

298299

299-
def test_ai_classify(scalar_types_df: dataframe.DataFrame, snapshot):
300+
@pytest.mark.parametrize("connection_id", [None, CONNECTION_ID])
301+
def test_ai_classify(scalar_types_df: dataframe.DataFrame, snapshot, connection_id):
300302
col_name = "string_col"
301303

302304
op = ops.AIClassify(
303305
prompt_context=(None,),
304306
categories=("greeting", "rejection"),
305-
connection_id=CONNECTION_ID,
307+
connection_id=connection_id,
306308
)
307309

308310
sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"])
309311

310312
snapshot.assert_match(sql, "out.sql")
311313

312314

313-
def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot):
315+
@pytest.mark.parametrize("connection_id", [None, CONNECTION_ID])
316+
def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot, connection_id):
314317
col_name = "string_col"
315318

316319
op = ops.AIScore(
317320
prompt_context=(None, " is the same as ", None),
318-
connection_id=CONNECTION_ID,
321+
connection_id=connection_id,
319322
)
320323

321324
sql = utils._apply_ops_to_sql(

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.