Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings
This repository was archived by the owner on May 7, 2026. It is now read-only.

Commit f7fd189

Browse filesBrowse files
authored
feat: Update bigquery.ai.generate_table output_schema to allow Mapping type (#2463)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes #<issue_number_goes_here> 🦕
1 parent ca9fb13 commit f7fd189
Copy full SHA for f7fd189

3 files changed

+58-4Lines changed: 58 additions & 4 deletions

File tree

Expand file treeCollapse file tree
Open diff view settings
Filter options
Expand file treeCollapse file tree
Open diff view settings
Collapse file

‎bigframes/bigquery/_operations/ai.py‎

Copy file name to clipboardExpand all lines: bigframes/bigquery/_operations/ai.py
+15-4Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ def generate_table(
606606
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
607607
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
608608
*,
609-
output_schema: str,
609+
output_schema: Union[str, Mapping[str, str]],
610610
temperature: Optional[float] = None,
611611
top_p: Optional[float] = None,
612612
max_output_tokens: Optional[int] = None,
@@ -642,8 +642,10 @@ def generate_table(
642642
treated as the 'prompt' column. If a DataFrame is provided, it
643643
must contain a 'prompt' column, or you must rename the column you
644644
wish to generate table to 'prompt'.
645-
output_schema (str):
646-
A string defining the output schema (e.g., "col1 STRING, col2 INT64").
645+
output_schema (str | Mapping[str, str]):
646+
A string defining the output schema (e.g., "col1 STRING, col2 INT64"),
647+
or a mapping value that specifies the schema of the output, in the form {field_name: data_type}.
648+
Supported data types include `STRING`, `INT64`, `FLOAT64`, `BOOL`, `ARRAY`, and `STRUCT`.
647649
temperature (float, optional):
648650
A FLOAT64 value that is used for sampling promiscuity. The value
649651
must be in the range ``[0.0, 1.0]``.
@@ -666,8 +668,17 @@ def generate_table(
666668
model_name, session = bq_utils.get_model_name_and_session(model, data)
667669
table_sql = bq_utils.to_sql(data)
668670

671+
if isinstance(output_schema, Mapping):
672+
output_schema_str = ", ".join(
673+
[f"{name} {sql_type}" for name, sql_type in output_schema.items()]
674+
)
675+
# Validate user input
676+
output_schemas.parse_sql_fields(output_schema_str)
677+
else:
678+
output_schema_str = output_schema
679+
669680
struct_fields_bq: Dict[str, bigframes.core.sql.literals.STRUCT_VALUES] = {
670-
"output_schema": output_schema
681+
"output_schema": output_schema_str
671682
}
672683
if temperature is not None:
673684
struct_fields_bq["temperature"] = temperature
Collapse file

‎tests/system/large/bigquery/test_ai.py‎

Copy file name to clipboardExpand all lines: tests/system/large/bigquery/test_ai.py
+17Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,20 @@ def test_generate_table(text_model):
111111
assert "creator" in result.columns
112112
# The model may not always return the exact number of rows requested.
113113
assert len(result) > 0
114+
115+
116+
def test_generate_table_with_mapping_schema(text_model):
117+
df = bpd.DataFrame(
118+
{"prompt": ["Generate a table of 2 programming languages and their creators."]}
119+
)
120+
121+
result = ai.generate_table(
122+
text_model,
123+
df,
124+
output_schema={"language": "STRING", "creator": "STRING"},
125+
)
126+
127+
assert "language" in result.columns
128+
assert "creator" in result.columns
129+
# The model may not always return the exact number of rows requested.
130+
assert len(result) > 0
Collapse file

‎tests/unit/bigquery/test_ai.py‎

Copy file name to clipboardExpand all lines: tests/unit/bigquery/test_ai.py
+26Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,32 @@ def test_generate_table_with_options(mock_dataframe, mock_session):
269269
)
270270

271271

272+
def test_generate_table_with_mapping_schema(mock_dataframe, mock_session):
273+
model_name = "project.dataset.model"
274+
275+
bbq.ai.generate_table(
276+
model_name,
277+
mock_dataframe,
278+
output_schema={"col1": "STRING", "col2": "INT64"},
279+
)
280+
281+
mock_session.read_gbq_query.assert_called_once()
282+
query = mock_session.read_gbq_query.call_args[0][0]
283+
284+
# Normalize whitespace for comparison
285+
query = " ".join(query.split())
286+
287+
expected_part_1 = "SELECT * FROM AI.GENERATE_TABLE("
288+
expected_part_2 = f"MODEL `{model_name}`,"
289+
expected_part_3 = "(SELECT * FROM my_table),"
290+
expected_part_4 = "STRUCT('col1 STRING, col2 INT64' AS output_schema)"
291+
292+
assert expected_part_1 in query
293+
assert expected_part_2 in query
294+
assert expected_part_3 in query
295+
assert expected_part_4 in query
296+
297+
272298
@mock.patch("bigframes.pandas.read_pandas")
273299
def test_generate_text_with_pandas_dataframe(
274300
read_pandas_mock, mock_dataframe, mock_session

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.