diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py index 7b4638157e..168bc584f7 100644 --- a/bigframes/ml/core.py +++ b/bigframes/ml/core.py @@ -205,6 +205,11 @@ def arima_evaluate(self, show_all_candidate_models: bool = False): return self._session.read_gbq(sql) + def arima_coefficients(self) -> bpd.DataFrame: + sql = self._model_manipulation_sql_generator.ml_arima_coefficients() + + return self._session.read_gbq(sql) + def centroids(self) -> bpd.DataFrame: assert self._model.model_type == "KMEANS" diff --git a/bigframes/ml/forecasting.py b/bigframes/ml/forecasting.py index 5bd01c8826..783e7741b8 100644 --- a/bigframes/ml/forecasting.py +++ b/bigframes/ml/forecasting.py @@ -269,6 +269,27 @@ def predict( options={"horizon": horizon, "confidence_level": confidence_level} ) + @property + def coef_( + self, + ) -> bpd.DataFrame: + """Inspect the coefficients of the model. + + ..note:: + + Output matches that of the ML.ARIMA_COEFFICIENTS function. + See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-arima-coefficients + for the outputs relevant to this model type. + + Returns: + bigframes.dataframe.DataFrame: + A DataFrame with the coefficients for the model. + """ + + if not self._bqml_model: + raise RuntimeError("A model must be fitted before inspect coefficients") + return self._bqml_model.arima_coefficients() + def detect_anomalies( self, X: Union[bpd.DataFrame, bpd.Series], diff --git a/bigframes/ml/sql.py b/bigframes/ml/sql.py index 3679be16c6..ea693e3437 100644 --- a/bigframes/ml/sql.py +++ b/bigframes/ml/sql.py @@ -318,6 +318,10 @@ def ml_evaluate(self, source_df: Optional[bpd.DataFrame] = None) -> str: return f"""SELECT * FROM ML.EVALUATE(MODEL `{self._model_name}`, ({source_sql}))""" + def ml_arima_coefficients(self) -> str: + """Encode ML.ARIMA_COEFFICIENTS for BQML""" + return f"""SELECT * FROM ML.ARIMA_COEFFICIENTS(MODEL `{self._model_name}`)""" + # ML evaluation TVFs def ml_llm_evaluate( self, source_df: bpd.DataFrame, task_type: Optional[str] = None diff --git a/tests/system/large/ml/test_forecasting.py b/tests/system/large/ml/test_forecasting.py index b333839e2e..ef74398c2e 100644 --- a/tests/system/large/ml/test_forecasting.py +++ b/tests/system/large/ml/test_forecasting.py @@ -13,6 +13,7 @@ # limitations under the License. import pandas as pd +import pytest from bigframes.ml import forecasting @@ -31,15 +32,22 @@ ] -def test_arima_plus_model_fit_score( - time_series_df_default_index, dataset_id, new_time_series_df -): +@pytest.fixture(scope="module") +def arima_model(time_series_df_default_index): model = forecasting.ARIMAPlus() X_train = time_series_df_default_index[["parsed_date"]] y_train = time_series_df_default_index[["total_visits"]] model.fit(X_train, y_train) + return model + + +def test_arima_plus_model_fit_score( + dataset_id, + new_time_series_df, + arima_model, +): - result = model.score( + result = arima_model.score( new_time_series_df[["parsed_date"]], new_time_series_df[["total_visits"]] ).to_pandas() expected = pd.DataFrame( @@ -56,29 +64,39 @@ def test_arima_plus_model_fit_score( pd.testing.assert_frame_equal(result, expected, check_exact=False, rtol=0.1) # save, load to ensure configuration was kept - reloaded_model = model.to_gbq(f"{dataset_id}.temp_arima_plus_model", replace=True) + reloaded_model = arima_model.to_gbq( + f"{dataset_id}.temp_arima_plus_model", replace=True + ) assert ( f"{dataset_id}.temp_arima_plus_model" in reloaded_model._bqml_model.model_name ) -def test_arima_plus_model_fit_summary(time_series_df_default_index, dataset_id): - model = forecasting.ARIMAPlus() - X_train = time_series_df_default_index[["parsed_date"]] - y_train = time_series_df_default_index[["total_visits"]] - model.fit(X_train, y_train) +def test_arima_plus_model_fit_summary(dataset_id, arima_model): - result = model.summary() + result = arima_model.summary() assert result.shape == (1, 12) assert all(column in result.columns for column in ARIMA_EVALUATE_OUTPUT_COL) # save, load to ensure configuration was kept - reloaded_model = model.to_gbq(f"{dataset_id}.temp_arima_plus_model", replace=True) + reloaded_model = arima_model.to_gbq( + f"{dataset_id}.temp_arima_plus_model", replace=True + ) assert ( f"{dataset_id}.temp_arima_plus_model" in reloaded_model._bqml_model.model_name ) +def test_arima_coefficients(arima_model): + got = arima_model.coef_ + expected_columns = { + "ar_coefficients", + "ma_coefficients", + "intercept_or_drift", + } + assert set(got.columns) == expected_columns + + def test_arima_plus_model_fit_params(time_series_df_default_index, dataset_id): model = forecasting.ARIMAPlus( horizon=100, diff --git a/tests/unit/ml/test_sql.py b/tests/unit/ml/test_sql.py index 1a5e8fe962..4dd90b2c4a 100644 --- a/tests/unit/ml/test_sql.py +++ b/tests/unit/ml/test_sql.py @@ -47,6 +47,16 @@ def mock_df(): return mock_df +def test_ml_arima_coefficients( + model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator, +): + sql = model_manipulation_sql_generator.ml_arima_coefficients() + assert ( + sql + == """SELECT * FROM ML.ARIMA_COEFFICIENTS(MODEL `my_project_id.my_dataset_id.my_model_id`)""" + ) + + def test_options_correct(base_sql_generator: ml_sql.BaseSqlGenerator): sql = base_sql_generator.options( model_type="lin_reg", input_label_cols=["col_a"], l1_reg=0.6