Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit d4afa2c

Browse filesBrowse files
authored
feat(bigframes): implement ai.similarity (#16771)
Fixes b/497837587
1 parent f6e916c commit d4afa2c
Copy full SHA for d4afa2c

14 files changed

+262Lines changed: 262 additions & 0 deletions

File tree

Expand file treeCollapse file tree
Open diff view settings
Filter options
Expand file treeCollapse file tree
Open diff view settings
Collapse file

‎packages/bigframes/bigframes/bigquery/_operations/ai.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/bigquery/_operations/ai.py
+81Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -976,6 +976,87 @@ def score(
976976
return series_list[0]._apply_nary_op(operator, series_list[1:])
977977

978978

979+
@log_adapter.method_logger(custom_base_name="bigquery_ai")
def similarity(
    content1: str | series.Series | pd.Series,
    content2: str | series.Series | pd.Series,
    *,
    endpoint: str | None = None,
    model: str | None = None,
    model_params: Mapping[Any, Any] | None = None,
    connection_id: str | None = None,
) -> series.Series:
    """
    Returns a FLOAT64 value that represents the cosine similarity between the two inputs.

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> import bigframes.bigquery as bbq
        >>> df = bpd.DataFrame({'word': ['happy', 'sad']})
        >>> bbq.ai.similarity(df['word'], 'glad', endpoint='text-embedding-005') # doctest: +SKIP
        0    0.916601
        1    0.660579

    Args:
        content1 (str | Series):
            A string or series that provides the first value to compare.
            Either a BigFrames Series or a pandas Series is allowed.
        content2 (str | Series):
            A string or series that provides the second value to compare.
            Either a BigFrames Series or a pandas Series is allowed.
        endpoint (str, optional):
            Specifies the Vertex AI endpoint to use for the text embedding model.
            If you specify the model name, such as `'text-embedding-005'`, rather
            than a URL, then BigQuery ML automatically identifies the model and
            uses the model's full endpoint.
        model (str, optional):
            Specifies a built-in text embedding model. The only supported value
            is the embeddinggemma-300m model.
            If you specify this parameter, you can't specify the `endpoint`,
            `model_params`, or `connection_id` parameters.
        model_params (Mapping[Any, Any], optional):
            Provides additional parameters to the model. You can use any of the
            parameters object fields.
            One of these fields, `outputDimensionality`, lets you specify the
            number of dimensions to use when generating embeddings.
        connection_id (str, optional):
            Specifies the connection to use to communicate with the model.
            For example, `myproject.us.myconnection`.

    Returns:
        bigframes.series.Series: A new series of FLOAT64 values representing the cosine similarity.

    Raises:
        ValueError: If neither `model` nor `endpoint` is given, or if `model`
            is combined with `endpoint`, `model_params`, or `connection_id`.
    """
    # Argument validation: `model` is mutually exclusive with the
    # endpoint-related options, and at least one of the two must be present.
    if model is None:
        if endpoint is None:
            raise ValueError("You must specify either 'model' or 'endpoint'.")
    elif not (endpoint is None and model_params is None and connection_id is None):
        raise ValueError(
            "If 'model' is specified, you cannot specify 'endpoint', 'model_params', or 'connection_id'."
        )

    # The operator stores the parameters as a JSON string; an empty or missing
    # mapping is passed on as None.
    serialized_params = json.dumps(model_params) if model_params else None
    operator = ai_ops.AISimilarity(
        endpoint=endpoint,
        model=model,
        model_params=serialized_params,
        connection_id=connection_id,
    )

    # Pick the session of the first BigFrames series among the inputs (if any)
    # so that every conversion below happens in that same session.
    if isinstance(content1, series.Series):
        bf_session = content1._session
    elif isinstance(content2, series.Series):
        bf_session = content2._session
    else:
        bf_session = None

    if isinstance(content1, str):
        if isinstance(content2, str):
            # Two literals: wrap the first one in a single-row series.
            literal_series = series.Series([content1], session=bf_session)
            return literal_series._apply_binary_op(content2, operator)

        # Only the first input is a literal, so the second must be a series.
        # The series becomes the left operand and the literal the right one.
        right_series = convert.to_bf_series(
            content2, default_index=None, session=bf_session
        )
        return right_series._apply_binary_op(content1, operator)

    # The first input must be a series; the second may be a literal or a series.
    left_series = convert.to_bf_series(
        content1, default_index=None, session=bf_session
    )
    return left_series._apply_binary_op(content2, operator)
1058+
1059+
9791060
@log_adapter.method_logger(custom_base_name="bigquery_ai")
9801061
def forecast(
9811062
df: dataframe.DataFrame | pd.DataFrame,
Collapse file

‎packages/bigframes/bigframes/bigquery/ai.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/bigquery/ai.py
+2Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
generate_text,
7070
if_,
7171
score,
72+
similarity,
7273
)
7374

7475
__all__ = [
@@ -84,4 +85,5 @@
8485
"generate_text",
8586
"if_",
8687
"score",
88+
"similarity",
8789
]
Collapse file

‎packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py
+14Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2005,6 +2005,20 @@ def ai_score(*values: ibis_types.Value, op: ops.AIScore) -> ibis_types.StructVal
20052005
).to_expr()
20062006

20072007

2008+
@scalar_op_compiler.register_binary_op(ops.AISimilarity, pass_op=True)
def ai_similarity(
    content1: ibis_types.Value, content2: ibis_types.Value, op: ops.AISimilarity
) -> ibis_types.Value:
    """Compile an ``ops.AISimilarity`` operator into an ibis expression.

    The two operand values are forwarded positionally, followed by the
    operator's ``endpoint``, ``model``, ``model_params`` and ``connection_id``.
    """
    similarity_node = ai_ops.AISimilarity(
        content1,  # type: ignore
        content2,  # type: ignore
        op.endpoint,  # type: ignore
        op.model,  # type: ignore
        op.model_params,  # type: ignore
        op.connection_id,  # type: ignore
    )
    return similarity_node.to_expr()
2020+
2021+
20082022
def _construct_prompt(
20092023
col_refs: tuple[ibis_types.Value], prompt_context: tuple[str | None]
20102024
) -> ibis_types.StructValue:
Collapse file

‎packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py
+11Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2525

2626
register_nary_op = expression_compiler.expression_compiler.register_nary_op
27+
register_binary_op = expression_compiler.expression_compiler.register_binary_op
2728
register_unary_op = expression_compiler.expression_compiler.register_unary_op
2829

2930

@@ -85,6 +86,16 @@ def _(*exprs: TypedExpr, op: ops.AIScore) -> sge.Expression:
8586
return sge.func("AI.SCORE", *args)
8687

8788

89+
@register_binary_op(ops.AISimilarity, pass_op=True)
def _(content1: TypedExpr, content2: TypedExpr, op: ops.AISimilarity) -> sge.Expression:
    """Compile ``ops.AISimilarity`` into an ``AI.SIMILARITY(...)`` call.

    Both operands are passed as named arguments (``content1 => ...``,
    ``content2 => ...``), followed by whatever named arguments
    ``_construct_named_args`` derives from the operator.
    """
    content_kwargs = [
        sge.Kwarg(this=arg_name, expression=operand.expr)
        for arg_name, operand in (("content1", content1), ("content2", content2))
    ]
    return sge.func("AI.SIMILARITY", *content_kwargs, *_construct_named_args(op))
97+
98+
8899
def _construct_prompt(
89100
exprs: tuple[TypedExpr, ...],
90101
prompt_context: tuple[str | None, ...],
Collapse file

‎packages/bigframes/bigframes/operations/__init__.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/operations/__init__.py
+2Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
AIGenerateInt,
2424
AIIf,
2525
AIScore,
26+
AISimilarity,
2627
)
2728
from bigframes.operations.array_ops import (
2829
ArrayIndexOp,
@@ -438,6 +439,7 @@
438439
"AIEmbed",
439440
"AIIf",
440441
"AIScore",
442+
"AISimilarity",
441443
# Numpy ops mapping
442444
"NUMPY_TO_BINOP",
443445
"NUMPY_TO_OP",
Collapse file

‎packages/bigframes/bigframes/operations/ai_ops.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/bigframes/operations/ai_ops.py
+13Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,16 @@ class AIScore(base_ops.NaryOp):
172172

173173
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
174174
return dtypes.FLOAT_DTYPE
175+
176+
177+
@dataclasses.dataclass(frozen=True)
class AISimilarity(base_ops.BinaryOp):
    """Binary operator for a similarity score between two content operands.

    The SQLGlot compiler renders this operator as an ``AI.SIMILARITY`` call.
    All option fields may be None; validation of which combinations are legal
    is done by the public ``similarity`` wrapper, not by this class.
    """

    name: ClassVar[str] = "ai_similarity"

    # Endpoint (or model name) of the embedding model to call.
    endpoint: str | None
    # Name of a built-in embedding model.
    model: str | None
    # Extra model parameters as a string; the public wrapper passes the result
    # of json.dumps() here.
    model_params: str | None
    # Connection used to communicate with the model.
    connection_id: str | None

    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
        # The result is always a float, regardless of the operand types.
        return dtypes.FLOAT_DTYPE
Collapse file

‎packages/bigframes/tests/system/small/bigquery/test_ai.py‎

Copy file name to clipboardExpand all lines: packages/bigframes/tests/system/small/bigquery/test_ai.py
+48Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,5 +433,53 @@ def test_forecast_w_params(time_series_df_default_index: dataframe.DataFrame):
433433
)
434434

435435

436+
def test_ai_similarity(session):
    """A BigFrames series compared with a pandas series yields non-null floats."""
    left = bpd.Series(["happy", "sad"], session=session)
    right = pd.Series(["glad", "angry"])

    scores = bbq.ai.similarity(left, right, endpoint="text-embedding-005")

    assert _contains_no_nulls(scores)
    assert scores.dtype == dtypes.FLOAT_DTYPE
444+
445+
446+
def test_ai_similarity_one_content_is_string_literal(session):
    """A string literal as the first input works with a series as the second."""
    candidates = bpd.Series(["glad", "angry"], session=session)

    scores = bbq.ai.similarity("happy", candidates, model="embeddinggemma-300m")

    assert _contains_no_nulls(scores)
    assert scores.dtype == dtypes.FLOAT_DTYPE
454+
455+
456+
def test_ai_similarity_both_contents_are_string_literals(session):
    """Two string literals still produce a non-null float series."""
    scores = bbq.ai.similarity("happy", "glad", endpoint="text-embedding-005")

    assert _contains_no_nulls(scores)
    assert scores.dtype == dtypes.FLOAT_DTYPE
464+
465+
466+
def test_ai_similarity_no_endpoint_or_model__raises_error(session):
    """Omitting both `endpoint` and `model` is rejected with a ValueError."""
    left = bpd.Series(["happy", "sad"], session=session)
    right = bpd.Series(["glad", "angry"], session=session)

    with pytest.raises(ValueError):
        bbq.ai.similarity(left, right)
472+
473+
474+
def test_ai_similarity_both_endpoint_and_model__raises_error(session):
    """Passing `endpoint` together with `model` is rejected with a ValueError."""
    conflicting_options = {
        "endpoint": "text-embedding-005",
        "model": "embeddinggemma-300m",
    }

    with pytest.raises(ValueError):
        bbq.ai.similarity("happy", "glad", **conflicting_options)
482+
483+
436484
def _contains_no_nulls(s: series.Series) -> bool:
437485
return len(s) == s.count()
Collapse file
+3Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
AI.SIMILARITY(content1 => `string_col`, content2 => `string_col`, endpoint => 'text-embedding-005') AS `result`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
SELECT
2+
AI.SIMILARITY(
3+
content1 => `string_col`,
4+
content2 => `string_col`,
5+
endpoint => 'text-embedding-005',
6+
connection_id => 'bigframes-dev.us.bigframes-default-connection'
7+
) AS `result`
8+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
AI.SIMILARITY(content1 => `string_col`, content2 => `string_col`, model => 'embeddinggemma-300m') AS `result`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.