Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings
This repository was archived by the owner on May 7, 2026. It is now read-only.

Commit b925aa2

Browse filesBrowse files
authored
feat: add bigquery.ai.generate_table function (#2453)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:
- [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)

Fixes #<issue_number_goes_here> 🦕
1 parent a6aafaa commit b925aa2
Copy full SHA for b925aa2

4 files changed

+163 lines changed: 163 additions & 0 deletions

File tree

Expand file treeCollapse file tree
Open diff view settings
Filter options
Expand file treeCollapse file tree
Open diff view settings
Collapse file

‎bigframes/bigquery/_operations/ai.py‎

Copy file name to clipboardExpand all lines: bigframes/bigquery/_operations/ai.py
+95Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,101 @@ def generate_text(
601601
return session.read_gbq_query(query)
602602

603603

604+
@log_adapter.method_logger(custom_base_name="bigquery_ai")
def generate_table(
    model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
    data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
    *,
    output_schema: str,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_output_tokens: Optional[int] = None,
    stop_sequences: Optional[List[str]] = None,
    request_type: Optional[str] = None,
) -> dataframe.DataFrame:
    """
    Generates a table using a BigQuery ML model.

    See the `AI.GENERATE_TABLE function syntax
    <https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-table>`_
    for additional reference.

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> import bigframes.bigquery as bbq
        >>> # The user is responsible for constructing a DataFrame that contains
        >>> # the necessary columns for the model's prompt. For example, a
        >>> # DataFrame with a 'prompt' column for text classification.
        >>> df = bpd.DataFrame({'prompt': ["some text to classify"]})
        >>> result = bbq.ai.generate_table(
        ...     "project.dataset.model_name",
        ...     data=df,
        ...     output_schema="category STRING"
        ... ) # doctest: +SKIP

    Args:
        model (bigframes.ml.base.BaseEstimator or str):
            The model to use for table generation.
        data (bigframes.pandas.DataFrame, bigframes.pandas.Series, pandas.DataFrame or pandas.Series):
            The data containing the prompts to generate the table from. If a
            Series is provided, it is treated as the 'prompt' column. If a
            DataFrame is provided, it must contain a 'prompt' column, or you
            must rename the column you wish to use as the prompt to 'prompt'.
        output_schema (str):
            A string defining the output schema (e.g., "col1 STRING, col2 INT64").
        temperature (float, optional):
            A FLOAT64 value that controls the degree of randomness in token
            selection during sampling. The value must be in the range
            ``[0.0, 1.0]``.
        top_p (float, optional):
            A FLOAT64 value that changes how the model selects tokens for
            output.
        max_output_tokens (int, optional):
            An INT64 value that sets the maximum number of tokens in the
            generated table.
        stop_sequences (List[str], optional):
            An ARRAY<STRING> value that contains the stop sequences for the model.
        request_type (str, optional):
            A STRING value that contains the request type for the model.

    Returns:
        bigframes.pandas.DataFrame:
            The generated table.
    """
    # A bare Series carries no column name the model can rely on, so it is
    # renamed to the 'prompt' column before the SQL is built.
    data = _to_dataframe(data, series_rename="prompt")
    model_name, session = bq_utils.get_model_name_and_session(model, data)
    table_sql = bq_utils.to_sql(data)

    # 'output_schema' is always sent; the remaining options are only added to
    # the STRUCT when the caller supplied them, so unset options never appear
    # in the generated SQL.
    struct_fields_bq: Dict[str, bigframes.core.sql.literals.STRUCT_VALUES] = {
        "output_schema": output_schema
    }
    if temperature is not None:
        struct_fields_bq["temperature"] = temperature
    if top_p is not None:
        struct_fields_bq["top_p"] = top_p
    if max_output_tokens is not None:
        struct_fields_bq["max_output_tokens"] = max_output_tokens
    if stop_sequences is not None:
        struct_fields_bq["stop_sequences"] = stop_sequences
    if request_type is not None:
        struct_fields_bq["request_type"] = request_type

    struct_sql = bigframes.core.sql.literals.struct_literal(struct_fields_bq)
    query = f"""
    SELECT *
    FROM AI.GENERATE_TABLE(
        MODEL `{model_name}`,
        ({table_sql}),
        {struct_sql}
    )
    """

    # Without a resolved session, run the query through the module-level
    # bigframes.pandas entry point instead.
    if session is None:
        return bpd.read_gbq_query(query)
    else:
        return session.read_gbq_query(query)
697+
698+
604699
@log_adapter.method_logger(custom_base_name="bigquery_ai")
605700
def if_(
606701
prompt: PROMPT_TYPE,
Collapse file

‎bigframes/bigquery/ai.py‎

Copy file name to clipboardExpand all lines: bigframes/bigquery/ai.py
+2Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
generate_double,
2525
generate_embedding,
2626
generate_int,
27+
generate_table,
2728
generate_text,
2829
if_,
2930
score,
@@ -37,6 +38,7 @@
3738
"generate_double",
3839
"generate_embedding",
3940
"generate_int",
41+
"generate_table",
4042
"generate_text",
4143
"if_",
4244
"score",
Collapse file

‎tests/system/large/bigquery/test_ai.py‎

Copy file name to clipboardExpand all lines: tests/system/large/bigquery/test_ai.py
+17Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,20 @@ def test_generate_text_with_options(text_model):
9494

9595
# It basically asserts that the results are still returned.
9696
assert len(result) == 2
97+
98+
99+
def test_generate_table(text_model):
    """System test: generate_table returns rows with the requested schema columns."""
    prompts = bpd.DataFrame(
        {"prompt": ["Generate a table of 2 programming languages and their creators."]}
    )

    generated = ai.generate_table(
        text_model,
        prompts,
        output_schema="language STRING, creator STRING",
    )

    # Every column named in output_schema must be present in the result.
    for expected_column in ("language", "creator"):
        assert expected_column in generated.columns
    # The model may not always return the exact number of rows requested,
    # so only require that something came back.
    assert len(generated) > 0
Collapse file

‎tests/unit/bigquery/test_ai.py‎

Copy file name to clipboardExpand all lines: tests/unit/bigquery/test_ai.py
+49Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,55 @@ def test_generate_text_defaults(mock_dataframe, mock_session):
220220
assert "STRUCT()" in query
221221

222222

223+
def test_generate_table_with_dataframe(mock_dataframe, mock_session):
    """The generated SQL wraps the DataFrame's query and carries only output_schema."""
    target_model = "project.dataset.model"

    bbq.ai.generate_table(
        target_model,
        mock_dataframe,
        output_schema="col1 STRING, col2 INT64",
    )

    mock_session.read_gbq_query.assert_called_once()
    positional_args, _ = mock_session.read_gbq_query.call_args
    raw_sql = positional_args[0]

    # Collapse all runs of whitespace so the check ignores SQL formatting.
    flattened_sql = " ".join(raw_sql.split())

    expected_fragments = (
        "SELECT * FROM AI.GENERATE_TABLE(",
        f"MODEL `{target_model}`,",
        "(SELECT * FROM my_table),",
        "STRUCT('col1 STRING, col2 INT64' AS output_schema)",
    )
    for fragment in expected_fragments:
        assert fragment in flattened_sql
247+
248+
249+
def test_generate_table_with_options(mock_dataframe, mock_session):
    """Optional arguments that are supplied appear in the STRUCT, in declaration order."""
    target_model = "project.dataset.model"

    bbq.ai.generate_table(
        target_model,
        mock_dataframe,
        output_schema="col1 STRING",
        temperature=0.5,
        max_output_tokens=100,
    )

    mock_session.read_gbq_query.assert_called_once()
    positional_args, _ = mock_session.read_gbq_query.call_args
    # Collapse all runs of whitespace so the check ignores SQL formatting.
    flattened_sql = " ".join(positional_args[0].split())

    expected_struct = "STRUCT('col1 STRING' AS output_schema, 0.5 AS temperature, 100 AS max_output_tokens)"

    assert f"MODEL `{target_model}`" in flattened_sql
    assert "(SELECT * FROM my_table)" in flattened_sql
    assert expected_struct in flattened_sql
271+
223272
@mock.patch("bigframes.pandas.read_pandas")
224273
def test_generate_text_with_pandas_dataframe(
225274
read_pandas_mock, mock_dataframe, mock_session

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.