Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings
This repository was archived by the owner on May 7, 2026. It is now read-only.

Commit 1126cec

Browse filesBrowse files
feat: add df.bigquery.ai.forecast method to pandas dataframe accessor (#2518)
Adds the `.bigquery.ai.forecast()` method to pandas DataFrame objects, which proxies to `bigframes.bigquery.ai.forecast()`. Added unit tests and mocked session responses. --- *PR created automatically by Jules for task [14604090974587392182](https://jules.google.com/task/14604090974587392182) started by @tswast* --------- Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> Co-authored-by: tswast <247555+tswast@users.noreply.github.com>
1 parent edceb35 commit 1126cec
Copy full SHA for 1126cec

3 files changed

+137-1Lines changed: 137 additions & 1 deletion

File tree

Expand file treeCollapse file tree
Open diff view settings
Filter options
Expand file treeCollapse file tree
Open diff view settings
Collapse file

‎bigframes/bigquery/_operations/ai.py‎

Copy file name to clipboardExpand all lines: bigframes/bigquery/_operations/ai.py
+11Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,7 @@ def forecast(
880880
id_cols: Iterable[str] | None = None,
881881
horizon: int = 10,
882882
confidence_level: float = 0.95,
883+
output_historical_time_series: bool = False,
883884
context_window: int | None = None,
884885
) -> dataframe.DataFrame:
885886
"""
@@ -914,6 +915,15 @@ def forecast(
914915
confidence_level (float, default 0.95):
915916
A FLOAT64 value that specifies the percentage of the future values that fall in the prediction interval.
916917
The default value is 0.95. The valid input range is [0, 1).
918+
output_historical_time_series (bool, default False):
919+
A BOOL value that determines whether the input data is returned
920+
along with the forecasted data. Set this argument to TRUE to return
921+
input data. The default value is FALSE.
922+
923+
Returning the input data along with the forecasted data lets you
924+
compare the historical value of the data column with the forecasted
925+
value of the data column, or chart the change in the data column
926+
values over time.
917927
context_window (int, optional):
918928
An int value that specifies the context window length used by BigQuery ML's built-in TimesFM model.
919929
The context window length determines how many of the most recent data points from the input time series are use by the model.
@@ -945,6 +955,7 @@ def forecast(
945955
"timestamp_col": timestamp_col,
946956
"model": model,
947957
"horizon": horizon,
958+
"output_historical_time_series": output_historical_time_series,
948959
"confidence_level": confidence_level,
949960
}
950961
if id_cols:
Collapse file

‎bigframes/extensions/pandas/dataframe_accessor.py‎

Copy file name to clipboardExpand all lines: bigframes/extensions/pandas/dataframe_accessor.py
+87-1Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import cast
15+
from typing import cast, Iterable, Optional
1616

1717
import pandas
1818
import pandas.api.extensions
@@ -21,6 +21,85 @@
2121
import bigframes.pandas as bpd
2222

2323

24+
class AIAccessor:
25+
"""
26+
Pandas DataFrame accessor for BigQuery AI functions.
27+
"""
28+
29+
def __init__(self, pandas_obj: pandas.DataFrame):
30+
self._obj = pandas_obj
31+
32+
def forecast(
33+
self,
34+
*,
35+
data_col: str,
36+
timestamp_col: str,
37+
model: str = "TimesFM 2.0",
38+
id_cols: Optional[Iterable[str]] = None,
39+
horizon: int = 10,
40+
confidence_level: float = 0.95,
41+
context_window: Optional[int] = None,
42+
output_historical_time_series: bool = False,
43+
session=None,
44+
) -> pandas.DataFrame:
45+
"""
46+
Forecast time series at future horizon using BigQuery AI.FORECAST.
47+
48+
See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-ai-forecast
49+
50+
Args:
51+
data_col (str):
52+
A str value that specifies the name of the data column. The data column contains the data to forecast.
53+
The data column must use one of the following data types: INT64, NUMERIC and FLOAT64
54+
timestamp_col (str):
55+
A str value that specified the name of the time points column.
56+
The time points column provides the time points used to generate the forecast.
57+
The time points column must use one of the following data types: TIMESTAMP, DATE and DATETIME
58+
model (str, default "TimesFM 2.0"):
59+
A str value that specifies the name of the model. "TimesFM 2.0" and "TimesFM 2.5" are supported.
60+
id_cols (Iterable[str], optional):
61+
An iterable of str value that specifies the names of one or more ID columns. Each ID identifies a unique time series to forecast.
62+
Specify one or more values for this argument in order to forecast multiple time series using a single query.
63+
The columns that you specify must use one of the following data types: STRING, INT64, ARRAY<STRING> and ARRAY<INT64>
64+
horizon (int, default 10):
65+
An int value that specifies the number of time points to forecast. The default value is 10. The valid input range is [1, 10,000].
66+
confidence_level (float, default 0.95):
67+
A FLOAT64 value that specifies the percentage of the future values that fall in the prediction interval.
68+
The default value is 0.95. The valid input range is [0, 1).
69+
context_window (int, optional):
70+
An int value that specifies the context window length used by BigQuery ML's built-in TimesFM model.
71+
The context window length determines how many of the most recent data points from the input time series are use by the model.
72+
If you don't specify a value, the AI.FORECAST function automatically chooses the smallest possible context window length to use
73+
that is still large enough to cover the number of time series data points in your input data.
74+
output_historical_time_series (bool, default False):
75+
A boolean value that determines whether to include the input time series history in the forecast.
76+
session (bigframes.session.Session, optional):
77+
The BigFrames session to use. If not provided, the default global session is used.
78+
79+
Returns:
80+
pandas.DataFrame:
81+
The forecast DataFrame result.
82+
"""
83+
import bigframes.bigquery.ai
84+
85+
if session is None:
86+
session = bf_session.get_global_session()
87+
88+
bf_df = cast(bpd.DataFrame, session.read_pandas(self._obj))
89+
result = bigframes.bigquery.ai.forecast(
90+
bf_df,
91+
data_col=data_col,
92+
timestamp_col=timestamp_col,
93+
model=model,
94+
id_cols=id_cols,
95+
horizon=horizon,
96+
confidence_level=confidence_level,
97+
context_window=context_window,
98+
output_historical_time_series=output_historical_time_series,
99+
)
100+
return result.to_pandas(ordered=True)
101+
102+
24103
@pandas.api.extensions.register_dataframe_accessor("bigquery")
25104
class BigQueryDataFrameAccessor:
26105
"""
@@ -32,6 +111,13 @@ class BigQueryDataFrameAccessor:
32111
def __init__(self, pandas_obj: pandas.DataFrame):
33112
self._obj = pandas_obj
34113

114+
@property
115+
def ai(self) -> "AIAccessor":
116+
"""
117+
Accessor for BigQuery AI functions.
118+
"""
119+
return AIAccessor(self._obj)
120+
35121
def sql_scalar(self, sql_template: str, *, output_dtype=None, session=None):
36122
"""
37123
Compute a new pandas Series by applying a SQL scalar function to the DataFrame.
Collapse file

‎tests/unit/core/compile/sqlglot/test_dataframe_accessor.py‎

Copy file name to clipboardExpand all lines: tests/unit/core/compile/sqlglot/test_dataframe_accessor.py
+39Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,42 @@ def to_pandas(series, *, ordered):
4343

4444
session.read_pandas.assert_called_once()
4545
snapshot.assert_match(result, "out.sql")
46+
47+
48+
def test_ai_forecast(snapshot, monkeypatch):
49+
import bigframes.bigquery.ai
50+
import bigframes.session
51+
52+
session = mock.create_autospec(bigframes.session.Session)
53+
bf_df = mock.create_autospec(bpd.DataFrame)
54+
session.read_pandas.return_value = bf_df
55+
56+
def mock_ai_forecast(df, **kwargs):
57+
assert df is bf_df
58+
result_df = mock.create_autospec(bpd.DataFrame)
59+
result_df.to_pandas.return_value = kwargs
60+
return result_df
61+
62+
import bigframes.bigquery.ai
63+
64+
monkeypatch.setattr(bigframes.bigquery.ai, "forecast", mock_ai_forecast)
65+
66+
df = pd.DataFrame({"date": ["2020-01-01"], "value": [1.0]})
67+
result = df.bigquery.ai.forecast(
68+
timestamp_col="date",
69+
data_col="value",
70+
horizon=5,
71+
session=session,
72+
)
73+
74+
session.read_pandas.assert_called_once()
75+
assert result == {
76+
"timestamp_col": "date",
77+
"data_col": "value",
78+
"model": "TimesFM 2.0",
79+
"id_cols": None,
80+
"horizon": 5,
81+
"confidence_level": 0.95,
82+
"context_window": None,
83+
"output_historical_time_series": False,
84+
}

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.