Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit fcb4579

Browse filesBrowse files
authored
feat(bigframes): implement ai.embed (#16759)
Fixes b/497836685 🦕
1 parent ef3940a commit fcb4579
Copy full SHA for fcb4579

15 files changed

+347-1Lines changed: 347 additions & 1 deletion

File tree

Expand file treeCollapse file tree
Open diff view settings
Filter options
Expand file treeCollapse file tree
Open diff view settings
Collapse file

‎packages/bigframes/bigframes/bigquery/_operations/ai.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/bigquery/_operations/ai.py
+107Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,113 @@ def generate_table(
705705
return session.read_gbq_query(query)
706706

707707

708+
@log_adapter.method_logger(custom_base_name="bigquery_ai")
709+
def embed(
710+
content: str | series.Series | pd.Series,
711+
*,
712+
endpoint: str | None = None,
713+
model: str | None = None,
714+
task_type: (
715+
Literal[
716+
"retrieval_query",
717+
"retrieval_document",
718+
"semantic_similarity",
719+
"classification",
720+
"clustering",
721+
"question_answering",
722+
"fact_verification",
723+
"code_retrieval_query",
724+
]
725+
| None
726+
) = None,
727+
title: str | None = None,
728+
model_params: Mapping[Any, Any] | None = None,
729+
connection_id: str | None = None,
730+
) -> series.Series:
731+
"""
732+
Creates embeddings from text or image data in BigQuery.
733+
734+
**Examples:**
735+
736+
>>> import bigframes.pandas as bpd
737+
>>> import bigframes.bigquery as bbq
738+
>>> bbq.ai.embed("dog", endpoint="text-embedding-005") # doctest: +SKIP
739+
0 {'result': array([ 1.78243860e-03, -1.10658340...
740+
741+
>>> s = bpd.Series(['dog']) # doctest: +SKIP
742+
>>> bbq.ai.embed(s, endpoint='text-embedding-005') # doctest: +SKIP
743+
0 {'result': array([ 1.78243860e-03, -1.10658340...
744+
745+
Args:
746+
content (str | Series):
747+
A string literal or a Series (either BigFrames series or pandas Series) that provides the text or image to embed.
748+
endpoint (str, optional):
749+
A string value that specifies a supported Vertex AI embedding model endpoint to use.
750+
The endpoint value that you specify must include the model version, for example,
751+
`"text-embedding-005"`. If you specify this parameter, you can't specify the
752+
`model` parameter.
753+
model (str, optional):
754+
A string value that specifies a built-in embedding model. The only supported value is
755+
`"embeddinggemma-300m"`. If you specify this parameter, you can't specify the `endpoint`,
756+
`title`, `model_params`, or `connection_id` parameters.
757+
task_type (str, optional):
758+
A string literal that specifies the intended downstream application to help the model
759+
produce better quality embeddings. Accepts `"retrieval_query"`, `"retrieval_document"`,
760+
`"semantic_similarity"`, `"classification"`, `"clustering"`, `"question_answering"`,
761+
`"fact_verification"`, `"code_retrieval_query"`.
762+
title (str, optional):
763+
A string value that specifies the document title, which the model uses to improve
764+
embedding quality. You can only use this parameter if you specify `"retrieval_document"`
765+
for the `task_type` value.
766+
model_params (Mapping[Any, Any], optional):
767+
A JSON literal that provides additional parameters to the model. For example,
768+
`{"outputDimensionality": 768}` lets you specify the number of dimensions to use when
769+
generating embeddings.
770+
connection_id (str, optional):
771+
A STRING value specifying the connection to use to communicate with the model, in the
772+
format `PROJECT_ID.LOCATION.CONNECTION_ID`. For example, `myproject.us.myconnection`.
773+
If not provided, the query uses your end-user credential.
774+
775+
Returns:
776+
bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:
777+
* "result": an ARRAY<FLOAT64> value containing the generated embeddings.
778+
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
779+
"""
780+
781+
if model is not None:
782+
if any(x is not None for x in [endpoint, title, model_params, connection_id]):
783+
raise ValueError(
784+
"You cannot specify endpoint, title, model_params, or connection_id when the model is set."
785+
)
786+
elif endpoint is None:
787+
raise ValueError(
788+
"You must specify exactly one of 'endpoint' or 'model' argument."
789+
)
790+
791+
if title is not None and task_type != "retrieval_document":
792+
raise ValueError(
793+
"You can only use 'title' parameter if you specify retrieval_document for the task_type value."
794+
)
795+
796+
operator = ai_ops.AIEmbed(
797+
endpoint=endpoint,
798+
model=model,
799+
task_type=task_type,
800+
title=title,
801+
model_params=json.dumps(model_params) if model_params else None,
802+
connection_id=connection_id,
803+
)
804+
805+
if isinstance(content, str):
806+
return series.Series([content])._apply_unary_op(operator)
807+
elif isinstance(content, pd.Series):
808+
return series.Series(content)._apply_unary_op(operator)
809+
elif isinstance(content, series.Series):
810+
return content._apply_unary_op(operator)
811+
else:
812+
raise ValueError(f"Unsupported 'content' parameter type: {type(content)}")
813+
814+
708815
@log_adapter.method_logger(custom_base_name="bigquery_ai")
709816
def if_(
710817
prompt: PROMPT_TYPE,
Collapse file

‎packages/bigframes/bigframes/bigquery/ai.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/bigquery/ai.py
+2Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858

5959
from bigframes.bigquery._operations.ai import (
6060
classify,
61+
embed,
6162
forecast,
6263
generate,
6364
generate_bool,
@@ -72,6 +73,7 @@
7273

7374
__all__ = [
7475
"classify",
76+
"embed",
7577
"forecast",
7678
"generate",
7779
"generate_bool",
Collapse file

‎packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py
+13Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1965,6 +1965,19 @@ def ai_generate_double(
19651965
).to_expr()
19661966

19671967

1968+
@scalar_op_compiler.register_unary_op(ops.AIEmbed, pass_op=True)
1969+
def ai_embed(value: ibis_types.Value, op: ops.AIEmbed) -> ibis_types.StructValue:
1970+
return ai_ops.AIEmbed(
1971+
value, # type: ignore
1972+
connection_id=op.connection_id, # type: ignore
1973+
endpoint=op.endpoint, # type: ignore
1974+
model=op.model, # type: ignore
1975+
task_type=op.task_type.upper() if op.task_type is not None else None, # type: ignore
1976+
title=op.title, # type: ignore
1977+
model_params=op.model_params, # type: ignore
1978+
).to_expr()
1979+
1980+
19681981
@scalar_op_compiler.register_nary_op(ops.AIIf, pass_op=True)
19691982
def ai_if(*values: ibis_types.Value, op: ops.AIIf) -> ibis_types.StructValue:
19701983
return ai_ops.AIIf(
Collapse file

‎packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py
+10-1Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
from dataclasses import asdict
18+
from typing import Any
1819

1920
import bigframes_vendored.sqlglot.expressions as sge
2021

@@ -23,6 +24,7 @@
2324
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2425

2526
register_nary_op = expression_compiler.expression_compiler.register_nary_op
27+
register_unary_op = expression_compiler.expression_compiler.register_unary_op
2628

2729

2830
@register_nary_op(ops.AIGenerate, pass_op=True)
@@ -53,6 +55,13 @@ def _(*exprs: TypedExpr, op: ops.AIGenerateDouble) -> sge.Expression:
5355
return sge.func("AI.GENERATE_DOUBLE", *args)
5456

5557

58+
@register_unary_op(ops.AIEmbed, pass_op=True)
59+
def _(expr: TypedExpr, op: ops.AIEmbed) -> sge.Expression:
60+
args: list[Any] = [expr.expr] + _construct_named_args(op)
61+
62+
return sge.func("AI.EMBED", *args)
63+
64+
5665
@register_nary_op(ops.AIIf, pass_op=True)
5766
def _(*exprs: TypedExpr, op: ops.AIIf) -> sge.Expression:
5867
args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op)
@@ -94,7 +103,7 @@ def _construct_prompt(
94103
return sge.Kwarg(this=param_name, expression=sge.Tuple(expressions=prompt))
95104

96105

97-
def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
106+
def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]:
98107
args = []
99108

100109
op_args = asdict(op)
Collapse file

‎packages/bigframes/bigframes/operations/__init__.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/operations/__init__.py
+2Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from bigframes.operations.ai_ops import (
1818
AIClassify,
19+
AIEmbed,
1920
AIGenerate,
2021
AIGenerateBool,
2122
AIGenerateDouble,
@@ -434,6 +435,7 @@
434435
"AIGenerateBool",
435436
"AIGenerateDouble",
436437
"AIGenerateInt",
438+
"AIEmbed",
437439
"AIIf",
438440
"AIScore",
439441
# Numpy ops mapping
Collapse file

‎packages/bigframes/bigframes/operations/ai_ops.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/operations/ai_ops.py
+22Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,28 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
118118
)
119119

120120

121+
@dataclasses.dataclass(frozen=True)
122+
class AIEmbed(base_ops.UnaryOp):
123+
name: ClassVar[str] = "ai_embed"
124+
125+
endpoint: str | None
126+
model: str | None
127+
task_type: str | None
128+
title: str | None
129+
model_params: str | None
130+
connection_id: str | None
131+
132+
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
133+
return pd.ArrowDtype(
134+
pa.struct(
135+
(
136+
pa.field("result", pa.list_(pa.float64())),
137+
pa.field("status", pa.string()),
138+
)
139+
)
140+
)
141+
142+
121143
@dataclasses.dataclass(frozen=True)
122144
class AIIf(base_ops.NaryOp):
123145
name: ClassVar[str] = "ai_if"
Collapse file

‎packages/bigframes/tests/system/small/bigquery/test_ai.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/tests/system/small/bigquery/test_ai.py
+63Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,69 @@ def test_ai_generate_double_multi_model(session):
255255
)
256256

257257

258+
def test_ai_embed_series_content(session):
259+
content = bpd.Series(["dog"], session=session)
260+
261+
result = bbq.ai.embed(content, endpoint="text-embedding-005")
262+
263+
assert _contains_no_nulls(result)
264+
assert result.dtype == pd.ArrowDtype(
265+
pa.struct(
266+
(
267+
pa.field("result", pa.list_(pa.float64())),
268+
pa.field("status", pa.string()),
269+
)
270+
)
271+
)
272+
273+
274+
def test_ai_embed_string_content(session):
275+
with mock.patch(
276+
"bigframes.core.global_session.get_global_session"
277+
) as mock_get_session:
278+
mock_get_session.return_value = session
279+
280+
result = bbq.ai.embed("dog", endpoint="text-embedding-005")
281+
282+
assert _contains_no_nulls(result)
283+
assert result.dtype == pd.ArrowDtype(
284+
pa.struct(
285+
(
286+
pa.field("result", pa.list_(pa.float64())),
287+
pa.field("status", pa.string()),
288+
)
289+
)
290+
)
291+
292+
293+
def test_ai_embed_no_endpoint_or_model_raises_error(session):
294+
content = bpd.Series(["dog"], session=session)
295+
296+
with pytest.raises(ValueError):
297+
bbq.ai.embed(content)
298+
299+
300+
def test_ai_embed_both_model_and_endpoint_are_set_raises_error(session):
301+
content = bpd.Series(["dog"], session=session)
302+
303+
with pytest.raises(ValueError):
304+
bbq.ai.embed(
305+
content, endpoint="text-embedding-005", model="embeddinggemma-300m model"
306+
)
307+
308+
309+
def test_ai_embed_title_and_task_type_mismatch_raises_error(session):
310+
content = bpd.Series(["dog"], session=session)
311+
312+
with pytest.raises(ValueError):
313+
bbq.ai.embed(
314+
content,
315+
endpoint="text-embedding-005",
316+
title="my title",
317+
task_type="text_similarity",
318+
)
319+
320+
258321
def test_ai_if(session):
259322
s1 = bpd.Series(["apple", "bear"], session=session)
260323
s2 = bpd.Series(["fruit", "tree"], session=session)
Collapse file
+3Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
AI.EMBED(`string_col`, endpoint => 'text-embedding-005') AS `result`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
SELECT
2+
AI.EMBED(
3+
`string_col`,
4+
endpoint => 'text-embedding-005',
5+
connection_id => 'bigframes-dev.us.bigframes-default-connection'
6+
) AS `result`
7+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Collapse file
+3Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
AI.EMBED(`string_col`, model => 'embeddinggemma-300m') AS `result`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.