diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py
index 5aad77a394..1e2224c9bc 100644
--- a/bigframes/ml/core.py
+++ b/bigframes/ml/core.py
@@ -126,8 +126,8 @@ def generate_text_embedding(
             ),
         )
 
-    def forecast(self) -> bpd.DataFrame:
-        sql = self._model_manipulation_sql_generator.ml_forecast()
+    def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
+        sql = self._model_manipulation_sql_generator.ml_forecast(struct_options=options)
         return self._session.read_gbq(sql, index_col="forecast_timestamp").reset_index()
 
     def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
diff --git a/bigframes/ml/forecasting.py b/bigframes/ml/forecasting.py
index 995201062b..03b9857cc5 100644
--- a/bigframes/ml/forecasting.py
+++ b/bigframes/ml/forecasting.py
@@ -86,21 +86,38 @@ def _fit(
             options=self._bqml_options,
         )
 
-    def predict(self, X=None) -> bpd.DataFrame:
+    def predict(
+        self, X=None, horizon: int = 3, confidence_level: float = 0.95
+    ) -> bpd.DataFrame:
         """Predict the closest cluster for each sample in X.
 
         Args:
             X (default None):
                 ignored, to be compatible with other APIs.
+            horizon (int, default: 3):
+                an int value that specifies the number of time points to forecast.
+                The default value is 3, and the maximum value is 1000.
+            confidence_level (float, default 0.95):
+                a float value that specifies percentage of the future values that fall in the prediction interval.
+                The valid input range is [0.0, 1.0).
 
         Returns:
             bigframes.dataframe.DataFrame: The predicted DataFrames. Which
                 contains 2 columns "forecast_timestamp" and "forecast_value".
         """
+        if horizon < 1 or horizon > 1000:
+            raise ValueError(f"horizon must be [1, 1000], but is {horizon}.")
+        if confidence_level < 0.0 or confidence_level >= 1.0:
+            raise ValueError(
+                f"confidence_level must be [0.0, 1.0), but is {confidence_level}."
+            )
+
         if not self._bqml_model:
             raise RuntimeError("A model must be fitted before predict")
 
-        return self._bqml_model.forecast()
+        return self._bqml_model.forecast(
+            options={"horizon": horizon, "confidence_level": confidence_level}
+        )
 
     def score(
         self,
diff --git a/bigframes/ml/sql.py b/bigframes/ml/sql.py
index 5fb40624dd..25caaf1ac6 100644
--- a/bigframes/ml/sql.py
+++ b/bigframes/ml/sql.py
@@ -223,9 +223,11 @@ def ml_predict(self, source_df: bpd.DataFrame) -> str:
         return f"""SELECT * FROM ML.PREDICT(MODEL `{self._model_name}`,
   ({self._source_sql(source_df)}))"""
 
-    def ml_forecast(self) -> str:
+    def ml_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str:
         """Encode ML.FORECAST for BQML"""
-        return f"""SELECT * FROM ML.FORECAST(MODEL `{self._model_name}`)"""
+        struct_options_sql = self.struct_options(**struct_options)
+        return f"""SELECT * FROM ML.FORECAST(MODEL `{self._model_name}`,
+  {struct_options_sql})"""
 
     def ml_generate_text(
         self, source_df: bpd.DataFrame, struct_options: Mapping[str, Union[int, float]]
diff --git a/tests/system/small/ml/test_core.py b/tests/system/small/ml/test_core.py
index 22cbbb1932..915c4aa444 100644
--- a/tests/system/small/ml/test_core.py
+++ b/tests/system/small/ml/test_core.py
@@ -336,17 +336,18 @@ def test_model_generate_text(
 
 def test_model_forecast(time_series_bqml_arima_plus_model: core.BqmlModel):
     utc = pytz.utc
-    forecast = time_series_bqml_arima_plus_model.forecast().to_pandas()[
-        ["forecast_timestamp", "forecast_value"]
-    ]
+    forecast = time_series_bqml_arima_plus_model.forecast(
+        {"horizon": 4, "confidence_level": 0.8}
+    ).to_pandas()[["forecast_timestamp", "forecast_value"]]
     expected = pd.DataFrame(
         {
             "forecast_timestamp": [
                 datetime(2017, 8, 2, tzinfo=utc),
                 datetime(2017, 8, 3, tzinfo=utc),
                 datetime(2017, 8, 4, tzinfo=utc),
+                datetime(2017, 8, 5, tzinfo=utc),
             ],
-            "forecast_value": [2724.472284, 2593.368389, 2353.613034],
+            "forecast_value": [2724.472284, 2593.368389, 2353.613034, 1781.623071],
         }
     )
     expected["forecast_value"] = expected["forecast_value"].astype(pd.Float64Dtype())
diff --git a/tests/system/small/ml/test_forecasting.py b/tests/system/small/ml/test_forecasting.py
index 948db59650..be8d9c2bac 100644
--- a/tests/system/small/ml/test_forecasting.py
+++ b/tests/system/small/ml/test_forecasting.py
@@ -18,8 +18,10 @@
 import pyarrow as pa
 import pytz
 
+from bigframes.ml import forecasting
 
-def test_model_predict(time_series_arima_plus_model):
+
+def test_model_predict_default(time_series_arima_plus_model: forecasting.ARIMAPlus):
     utc = pytz.utc
     predictions = time_series_arima_plus_model.predict().to_pandas()
     assert predictions.shape == (3, 8)
@@ -47,7 +49,40 @@ def test_model_predict(time_series_arima_plus_model):
     )
 
 
-def test_model_score(time_series_arima_plus_model, new_time_series_df):
+def test_model_predict_params(time_series_arima_plus_model: forecasting.ARIMAPlus):
+    utc = pytz.utc
+    predictions = time_series_arima_plus_model.predict(
+        horizon=4, confidence_level=0.9
+    ).to_pandas()
+    assert predictions.shape == (4, 8)
+    result = predictions[["forecast_timestamp", "forecast_value"]]
+    expected = pd.DataFrame(
+        {
+            "forecast_timestamp": [
+                datetime(2017, 8, 2, tzinfo=utc),
+                datetime(2017, 8, 3, tzinfo=utc),
+                datetime(2017, 8, 4, tzinfo=utc),
+                datetime(2017, 8, 5, tzinfo=utc),
+            ],
+            "forecast_value": [2724.472284, 2593.368389, 2353.613034, 1781.623071],
+        }
+    )
+    expected["forecast_value"] = expected["forecast_value"].astype(pd.Float64Dtype())
+    expected["forecast_timestamp"] = expected["forecast_timestamp"].astype(
+        pd.ArrowDtype(pa.timestamp("us", tz="UTC"))
+    )
+
+    pd.testing.assert_frame_equal(
+        result,
+        expected,
+        rtol=0.1,
+        check_index_type=False,
+    )
+
+
+def test_model_score(
+    time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
+):
     result = time_series_arima_plus_model.score(
         new_time_series_df[["parsed_date"]], new_time_series_df[["total_visits"]]
     ).to_pandas()
@@ -69,7 +104,9 @@ def test_model_score(time_series_arima_plus_model, new_time_series_df):
     )
 
 
-def test_model_score_series(time_series_arima_plus_model, new_time_series_df):
+def test_model_score_series(
+    time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
+):
     result = time_series_arima_plus_model.score(
         new_time_series_df["parsed_date"], new_time_series_df["total_visits"]
     ).to_pandas()
diff --git a/tests/unit/ml/test_sql.py b/tests/unit/ml/test_sql.py
index 9223058540..73d19cc0bb 100644
--- a/tests/unit/ml/test_sql.py
+++ b/tests/unit/ml/test_sql.py
@@ -293,6 +293,22 @@ def test_ml_centroids_produces_correct_sql(
     )
 
 
+def test_forecast_correct_sql(
+    model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
+    mock_df: bpd.DataFrame,
+):
+    sql = model_manipulation_sql_generator.ml_forecast(
+        struct_options={"option_key1": 1, "option_key2": 2.2},
+    )
+    assert (
+        sql
+        == """SELECT * FROM ML.FORECAST(MODEL `my_project_id.my_dataset_id.my_model_id`,
+  STRUCT(
+  1 AS option_key1,
+  2.2 AS option_key2))"""
+    )
+
+
 def test_ml_generate_text_produces_correct_sql(
     model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
     mock_df: bpd.DataFrame,