Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings
This repository was archived by the owner on May 7, 2026. It is now read-only.

Commit 5bd0029

Browse filesBrowse files
authored
feat: add bigquery.ai.generate_text function (#2433)
* Added the API for ai.generate_text * Fixed SQL syntax bug of ai.generate_embedding * Refactored the code base to keep util functions organized. Fixes b/481092205
1 parent e6de52d commit 5bd0029
Copy full SHA for 5bd0029

7 files changed

+452-119Lines changed: 452 additions & 119 deletions

File tree

Expand file treeCollapse file tree
Open diff view settings
Filter options
Expand file treeCollapse file tree
Open diff view settings
Collapse file

‎bigframes/bigquery/_operations/ai.py‎

Copy file name to clipboardExpand all lines: bigframes/bigquery/_operations/ai.py
+143-21Lines changed: 143 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from bigframes import clients, dataframe, dtypes
2727
from bigframes import pandas as bpd
2828
from bigframes import series, session
29+
from bigframes.bigquery._operations import utils as bq_utils
2930
from bigframes.core import convert
3031
from bigframes.core.logging import log_adapter
3132
import bigframes.core.sql.literals
@@ -391,7 +392,7 @@ def generate_double(
391392

392393
@log_adapter.method_logger(custom_base_name="bigquery_ai")
393394
def generate_embedding(
394-
model_name: str,
395+
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
395396
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
396397
*,
397398
output_dimensionality: Optional[int] = None,
@@ -415,9 +416,8 @@ def generate_embedding(
415416
... ) # doctest: +SKIP
416417
417418
Args:
418-
model_name (str):
419-
The name of a remote model from Vertex AI, such as the
420-
multimodalembedding@001 model.
419+
model (bigframes.ml.base.BaseEstimator or str):
420+
The model to use for text embedding.
421421
data (bigframes.pandas.DataFrame or bigframes.pandas.Series):
422422
The data to generate embeddings for. If a Series is provided, it is
423423
treated as the 'content' column. If a DataFrame is provided, it
@@ -454,20 +454,9 @@ def generate_embedding(
454454
<https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-ai-generate-embedding#output>`_
455455
for details.
456456
"""
457-
if isinstance(data, (pd.DataFrame, pd.Series)):
458-
data = bpd.read_pandas(data)
459-
460-
if isinstance(data, series.Series):
461-
data = data.copy()
462-
data.name = "content"
463-
data_df = data.to_frame()
464-
elif isinstance(data, dataframe.DataFrame):
465-
data_df = data
466-
else:
467-
raise ValueError(f"Unsupported data type: {type(data)}")
468-
469-
# We need to get the SQL for the input data to pass as a subquery to the TVF
470-
source_sql = data_df.sql
457+
data = _to_dataframe(data, series_rename="content")
458+
model_name, session = bq_utils.get_model_name_and_session(model, data)
459+
table_sql = bq_utils.to_sql(data)
471460

472461
struct_fields: Dict[str, bigframes.core.sql.literals.STRUCT_VALUES] = {}
473462
if output_dimensionality is not None:
@@ -488,12 +477,128 @@ def generate_embedding(
488477
SELECT *
489478
FROM AI.GENERATE_EMBEDDING(
490479
MODEL `{model_name}`,
491-
({source_sql}),
492-
{bigframes.core.sql.literals.struct_literal(struct_fields)})
480+
({table_sql}),
481+
{bigframes.core.sql.literals.struct_literal(struct_fields)}
493482
)
494483
"""
495484

496-
return data_df._session.read_gbq(query)
485+
if session is None:
486+
return bpd.read_gbq_query(query)
487+
else:
488+
return session.read_gbq_query(query)
489+
490+
491+
@log_adapter.method_logger(custom_base_name="bigquery_ai")
492+
def generate_text(
493+
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
494+
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
495+
*,
496+
temperature: Optional[float] = None,
497+
max_output_tokens: Optional[int] = None,
498+
top_k: Optional[int] = None,
499+
top_p: Optional[float] = None,
500+
stop_sequences: Optional[List[str]] = None,
501+
ground_with_google_search: Optional[bool] = None,
502+
request_type: Optional[str] = None,
503+
) -> dataframe.DataFrame:
504+
"""
505+
Generates text using a BigQuery ML model.
506+
507+
See the `BigQuery ML GENERATE_TEXT function syntax
508+
<https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-text>`_
509+
for additional reference.
510+
511+
**Examples:**
512+
513+
>>> import bigframes.pandas as bpd
514+
>>> import bigframes.bigquery as bbq
515+
>>> df = bpd.DataFrame({"prompt": ["write a poem about apples"]})
516+
>>> bbq.ai.generate_text(
517+
... "project.dataset.model_name",
518+
... df
519+
... ) # doctest: +SKIP
520+
521+
Args:
522+
model (bigframes.ml.base.BaseEstimator or str):
523+
The model to use for text generation.
524+
data (bigframes.pandas.DataFrame or bigframes.pandas.Series):
525+
The data to generate embeddings for. If a Series is provided, it is
526+
treated as the 'content' column. If a DataFrame is provided, it
527+
must contain a 'content' column, or you must rename the column you
528+
wish to embed to 'content'.
529+
temperature (float, optional):
530+
A FLOAT64 value that is used for sampling promiscuity. The value
531+
must be in the range ``[0.0, 1.0]``. A lower temperature works well
532+
for prompts that expect a more deterministic and less open-ended
533+
or creative response, while a higher temperature can lead to more
534+
diverse or creative results. A temperature of ``0`` is
535+
deterministic, meaning that the highest probability response is
536+
always selected.
537+
max_output_tokens (int, optional):
538+
An INT64 value that sets the maximum number of tokens in the
539+
generated text.
540+
top_k (int, optional):
541+
An INT64 value that changes how the model selects tokens for
542+
output. A ``top_k`` of ``1`` means the next selected token is the
543+
most probable among all tokens in the model's vocabulary. A
544+
``top_k`` of ``3`` means that the next token is selected from
545+
among the three most probable tokens by using temperature. The
546+
default value is ``40``.
547+
top_p (float, optional):
548+
A FLOAT64 value that changes how the model selects tokens for
549+
output. Tokens are selected from most probable to least probable
550+
until the sum of their probabilities equals the ``top_p`` value.
551+
For example, if tokens A, B, and C have a probability of 0.3, 0.2,
552+
and 0.1 and the ``top_p`` value is ``0.5``, then the model will
553+
select either A or B as the next token by using temperature. The
554+
default value is ``0.95``.
555+
stop_sequences (List[str], optional):
556+
An ARRAY<STRING> value that contains the stop sequences for the model.
557+
ground_with_google_search (bool, optional):
558+
A BOOL value that determines whether to ground the model with Google Search.
559+
request_type (str, optional):
560+
A STRING value that contains the request type for the model.
561+
562+
Returns:
563+
bigframes.pandas.DataFrame:
564+
The generated text.
565+
"""
566+
data = _to_dataframe(data, series_rename="prompt")
567+
model_name, session = bq_utils.get_model_name_and_session(model, data)
568+
table_sql = bq_utils.to_sql(data)
569+
570+
struct_fields: Dict[
571+
str,
572+
Union[str, int, float, bool, Mapping[str, str], List[str], Mapping[str, Any]],
573+
] = {}
574+
if temperature is not None:
575+
struct_fields["TEMPERATURE"] = temperature
576+
if max_output_tokens is not None:
577+
struct_fields["MAX_OUTPUT_TOKENS"] = max_output_tokens
578+
if top_k is not None:
579+
struct_fields["TOP_K"] = top_k
580+
if top_p is not None:
581+
struct_fields["TOP_P"] = top_p
582+
if stop_sequences is not None:
583+
struct_fields["STEP_SEQUENCES"] = stop_sequences
584+
if ground_with_google_search is not None:
585+
struct_fields["GROUND_WITH_GOOGLE_SEARCH"] = ground_with_google_search
586+
if request_type is not None:
587+
struct_fields["REQUEST_TYPE"] = request_type
588+
589+
query = f"""
590+
SELECT *
591+
FROM AI.GENERATE_TEXT(
592+
MODEL `{model_name}`,
593+
({table_sql}),
594+
{bigframes.core.sql.literals.struct_literal(struct_fields)}
595+
)
596+
"""
597+
598+
if session is None:
599+
return bpd.read_gbq_query(query)
600+
else:
601+
return session.read_gbq_query(query)
497602

498603

499604
@log_adapter.method_logger(custom_base_name="bigquery_ai")
@@ -811,3 +916,20 @@ def _resolve_connection_id(series: series.Series, connection_id: str | None):
811916
series._session._project,
812917
series._session._location,
813918
)
919+
920+
921+
def _to_dataframe(
922+
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
923+
series_rename: str,
924+
) -> dataframe.DataFrame:
925+
if isinstance(data, (pd.DataFrame, pd.Series)):
926+
data = bpd.read_pandas(data)
927+
928+
if isinstance(data, series.Series):
929+
data = data.copy()
930+
data.name = series_rename
931+
return data.to_frame()
932+
elif isinstance(data, dataframe.DataFrame):
933+
return data
934+
935+
raise ValueError(f"Unsupported data type: {type(data)}")
Collapse file

‎bigframes/bigquery/_operations/ml.py‎

Copy file name to clipboardExpand all lines: bigframes/bigquery/_operations/ml.py
+21-63Lines changed: 21 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -14,66 +14,20 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import cast, List, Mapping, Optional, Union
17+
from typing import List, Mapping, Optional, Union
1818

1919
import bigframes_vendored.constants
2020
import google.cloud.bigquery
2121
import pandas as pd
2222

23+
from bigframes.bigquery._operations import utils
2324
import bigframes.core.logging.log_adapter as log_adapter
2425
import bigframes.core.sql.ml
2526
import bigframes.dataframe as dataframe
2627
import bigframes.ml.base
2728
import bigframes.session
2829

2930

30-
# Helper to convert DataFrame to SQL string
31-
def _to_sql(df_or_sql: Union[pd.DataFrame, dataframe.DataFrame, str]) -> str:
32-
import bigframes.pandas as bpd
33-
34-
if isinstance(df_or_sql, str):
35-
return df_or_sql
36-
37-
if isinstance(df_or_sql, pd.DataFrame):
38-
bf_df = bpd.read_pandas(df_or_sql)
39-
else:
40-
bf_df = cast(dataframe.DataFrame, df_or_sql)
41-
42-
# Cache dataframes to make sure base table is not a snapshot.
43-
# Cached dataframe creates a full copy, never uses snapshot.
44-
# This is a workaround for internal issue b/310266666.
45-
bf_df.cache()
46-
sql, _, _ = bf_df._to_sql_query(include_index=False)
47-
return sql
48-
49-
50-
def _get_model_name_and_session(
51-
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
52-
# Other dataframe arguments to extract session from
53-
*dataframes: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]],
54-
) -> tuple[str, Optional[bigframes.session.Session]]:
55-
if isinstance(model, pd.Series):
56-
try:
57-
model_ref = model["modelReference"]
58-
model_name = f"{model_ref['projectId']}.{model_ref['datasetId']}.{model_ref['modelId']}" # type: ignore
59-
except KeyError:
60-
raise ValueError("modelReference must be present in the pandas Series.")
61-
elif isinstance(model, str):
62-
model_name = model
63-
else:
64-
if model._bqml_model is None:
65-
raise ValueError("Model must be fitted to be used in ML operations.")
66-
return model._bqml_model.model_name, model._bqml_model.session
67-
68-
session = None
69-
for df in dataframes:
70-
if isinstance(df, dataframe.DataFrame):
71-
session = df._session
72-
break
73-
74-
return model_name, session
75-
76-
7731
def _get_model_metadata(
7832
*,
7933
bqclient: google.cloud.bigquery.Client,
@@ -143,8 +97,12 @@ def create_model(
14397
"""
14498
import bigframes.pandas as bpd
14599

146-
training_data_sql = _to_sql(training_data) if training_data is not None else None
147-
custom_holiday_sql = _to_sql(custom_holiday) if custom_holiday is not None else None
100+
training_data_sql = (
101+
utils.to_sql(training_data) if training_data is not None else None
102+
)
103+
custom_holiday_sql = (
104+
utils.to_sql(custom_holiday) if custom_holiday is not None else None
105+
)
148106

149107
# Determine session from DataFrames if not provided
150108
if session is None:
@@ -227,8 +185,8 @@ def evaluate(
227185
"""
228186
import bigframes.pandas as bpd
229187

230-
model_name, session = _get_model_name_and_session(model, input_)
231-
table_sql = _to_sql(input_) if input_ is not None else None
188+
model_name, session = utils.get_model_name_and_session(model, input_)
189+
table_sql = utils.to_sql(input_) if input_ is not None else None
232190

233191
sql = bigframes.core.sql.ml.evaluate(
234192
model_name=model_name,
@@ -281,8 +239,8 @@ def predict(
281239
"""
282240
import bigframes.pandas as bpd
283241

284-
model_name, session = _get_model_name_and_session(model, input_)
285-
table_sql = _to_sql(input_)
242+
model_name, session = utils.get_model_name_and_session(model, input_)
243+
table_sql = utils.to_sql(input_)
286244

287245
sql = bigframes.core.sql.ml.predict(
288246
model_name=model_name,
@@ -340,8 +298,8 @@ def explain_predict(
340298
"""
341299
import bigframes.pandas as bpd
342300

343-
model_name, session = _get_model_name_and_session(model, input_)
344-
table_sql = _to_sql(input_)
301+
model_name, session = utils.get_model_name_and_session(model, input_)
302+
table_sql = utils.to_sql(input_)
345303

346304
sql = bigframes.core.sql.ml.explain_predict(
347305
model_name=model_name,
@@ -383,7 +341,7 @@ def global_explain(
383341
"""
384342
import bigframes.pandas as bpd
385343

386-
model_name, session = _get_model_name_and_session(model)
344+
model_name, session = utils.get_model_name_and_session(model)
387345
sql = bigframes.core.sql.ml.global_explain(
388346
model_name=model_name,
389347
class_level_explain=class_level_explain,
@@ -419,8 +377,8 @@ def transform(
419377
"""
420378
import bigframes.pandas as bpd
421379

422-
model_name, session = _get_model_name_and_session(model, input_)
423-
table_sql = _to_sql(input_)
380+
model_name, session = utils.get_model_name_and_session(model, input_)
381+
table_sql = utils.to_sql(input_)
424382

425383
sql = bigframes.core.sql.ml.transform(
426384
model_name=model_name,
@@ -500,8 +458,8 @@ def generate_text(
500458
"""
501459
import bigframes.pandas as bpd
502460

503-
model_name, session = _get_model_name_and_session(model, input_)
504-
table_sql = _to_sql(input_)
461+
model_name, session = utils.get_model_name_and_session(model, input_)
462+
table_sql = utils.to_sql(input_)
505463

506464
sql = bigframes.core.sql.ml.generate_text(
507465
model_name=model_name,
@@ -565,8 +523,8 @@ def generate_embedding(
565523
"""
566524
import bigframes.pandas as bpd
567525

568-
model_name, session = _get_model_name_and_session(model, input_)
569-
table_sql = _to_sql(input_)
526+
model_name, session = utils.get_model_name_and_session(model, input_)
527+
table_sql = utils.to_sql(input_)
570528

571529
sql = bigframes.core.sql.ml.generate_embedding(
572530
model_name=model_name,

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.