diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py index 810fb1f7bd..2f3b532a74 100644 --- a/bigframes/ml/core.py +++ b/bigframes/ml/core.py @@ -123,6 +123,12 @@ def predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame: self._model_manipulation_sql_generator.ml_predict, ) + def explain_predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame: + return self._apply_ml_tvf( + input_data, + self._model_manipulation_sql_generator.ml_explain_predict, + ) + def transform(self, input_data: bpd.DataFrame) -> bpd.DataFrame: return self._apply_ml_tvf( input_data, diff --git a/bigframes/ml/linear_model.py b/bigframes/ml/linear_model.py index ae4e1944cc..1a1a5e0ca0 100644 --- a/bigframes/ml/linear_model.py +++ b/bigframes/ml/linear_model.py @@ -160,6 +160,34 @@ def predict(self, X: utils.ArrayType) -> bpd.DataFrame: return self._bqml_model.predict(X) + def predict_explain( + self, + X: utils.ArrayType, + ) -> bpd.DataFrame: + """ + Explain predictions for a linear regression model. + + .. note:: + Output matches that of the BigQuery ML.EXPLAIN_PREDICT function. + See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-explain-predict + + Args: + X (bigframes.dataframe.DataFrame or bigframes.series.Series or + pandas.core.frame.DataFrame or pandas.core.series.Series): + Series or DataFrame whose predictions are to be explained. + + Returns: + bigframes.pandas.DataFrame: + The predicted DataFrame with explanation columns. 
+ """ + # TODO(b/377366612): Add support for `top_k_features` parameter + if not self._bqml_model: + raise RuntimeError("A model must be fitted before predict") + + (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session) + + return self._bqml_model.explain_predict(X) + def score( self, X: utils.ArrayType, diff --git a/bigframes/ml/sql.py b/bigframes/ml/sql.py index 1ef43d9ce5..93b8a3a051 100644 --- a/bigframes/ml/sql.py +++ b/bigframes/ml/sql.py @@ -304,6 +304,11 @@ def ml_predict(self, source_sql: str) -> str: return f"""SELECT * FROM ML.PREDICT(MODEL {self._model_ref_sql()}, ({source_sql}))""" + def ml_explain_predict(self, source_sql: str) -> str: + """Encode ML.EXPLAIN_PREDICT for BQML""" + return f"""SELECT * FROM ML.EXPLAIN_PREDICT(MODEL {self._model_ref_sql()}, + ({source_sql}))""" + def ml_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str: """Encode ML.FORECAST for BQML""" struct_options_sql = self.struct_options(**struct_options) diff --git a/tests/system/small/ml/test_core.py b/tests/system/small/ml/test_core.py index 30b75f502d..3ea31353b1 100644 --- a/tests/system/small/ml/test_core.py +++ b/tests/system/small/ml/test_core.py @@ -261,6 +261,28 @@ def test_model_predict(penguins_bqml_linear_model: core.BqmlModel, new_penguins_ ) +def test_model_predict_explain( + penguins_bqml_linear_model: core.BqmlModel, new_penguins_df +): + predictions = penguins_bqml_linear_model.explain_predict( + new_penguins_df + ).to_pandas() + expected = pd.DataFrame( + { + "predicted_body_mass_g": [4030.1, 3280.8, 3177.9], + "approximation_error": [0.0, 0.0, 0.0], + }, + dtype="Float64", + index=pd.Index([1633, 1672, 1690], name="tag_number", dtype="Int64"), + ) + pd.testing.assert_frame_equal( + predictions[["predicted_body_mass_g", "approximation_error"]].sort_index(), + expected, + check_exact=False, + rtol=0.1, + ) + + def test_model_predict_with_unnamed_index( penguins_bqml_linear_model: core.BqmlModel, new_penguins_df ): @@ 
-289,6 +311,39 @@ def test_model_predict_with_unnamed_index( ) +def test_model_predict_explain_with_unnamed_index( + penguins_bqml_linear_model: core.BqmlModel, new_penguins_df +): + # This will result in an index that lacks a name, which the ML library will + # need to persist through the call to ML.EXPLAIN_PREDICT + new_penguins_df = new_penguins_df.reset_index() + + # remove the middle tag number to ensure we're really keeping the unnamed index + new_penguins_df = typing.cast( + bigframes.dataframe.DataFrame, + new_penguins_df[new_penguins_df.tag_number != 1672], + ) + + predictions = penguins_bqml_linear_model.explain_predict( + new_penguins_df + ).to_pandas() + + expected = pd.DataFrame( + { + "predicted_body_mass_g": [4030.1, 3177.9], + "approximation_error": [0.0, 0.0], + }, + dtype="Float64", + index=pd.Index([0, 2], dtype="Int64"), + ) + pd.testing.assert_frame_equal( + predictions[["predicted_body_mass_g", "approximation_error"]].sort_index(), + expected, + check_exact=False, + rtol=0.1, + ) + + def test_model_detect_anomalies( penguins_bqml_pca_model: core.BqmlModel, new_penguins_df ): diff --git a/tests/system/small/ml/test_linear_model.py b/tests/system/small/ml/test_linear_model.py index 6d0a361f55..0832c559c1 100644 --- a/tests/system/small/ml/test_linear_model.py +++ b/tests/system/small/ml/test_linear_model.py @@ -16,6 +16,8 @@ import pandas import pytest +from bigframes.ml import linear_model + def test_linear_reg_model_score(penguins_linear_model, penguins_df_default_index): df = penguins_df_default_index.dropna() @@ -106,6 +108,72 @@ def test_linear_reg_model_predict(penguins_linear_model, new_penguins_df): ) +def test_linear_reg_model_predict_explain(penguins_linear_model, new_penguins_df): + predictions = penguins_linear_model.predict_explain(new_penguins_df).to_pandas() + assert predictions.shape == (3, 12) + result = predictions[["predicted_body_mass_g", "approximation_error"]] + expected = pandas.DataFrame( + { + "predicted_body_mass_g": [4030.1, 
3280.8, 3177.9], + "approximation_error": [ + 0.0, + 0.0, + 0.0, + ], + }, + dtype="Float64", + index=pandas.Index([1633, 1672, 1690], name="tag_number", dtype="Int64"), + ) + pandas.testing.assert_frame_equal( + result.sort_index(), + expected, + check_exact=False, + rtol=0.1, + ) + + +def test_linear_reg_model_predict_params( + penguins_linear_model: linear_model.LinearRegression, new_penguins_df +): + predictions = penguins_linear_model.predict(new_penguins_df).to_pandas() + assert predictions.shape[0] >= 1 + prediction_columns = set(predictions.columns) + expected_columns = { + "predicted_body_mass_g", + "species", + "island", + "culmen_length_mm", + "culmen_depth_mm", + "flipper_length_mm", + "body_mass_g", + "sex", + } + assert expected_columns <= prediction_columns + + +def test_linear_reg_model_predict_explain_params( + penguins_linear_model: linear_model.LinearRegression, new_penguins_df +): + predictions = penguins_linear_model.predict_explain(new_penguins_df).to_pandas() + assert predictions.shape[0] >= 1 + prediction_columns = set(predictions.columns) + expected_columns = { + "predicted_body_mass_g", + "top_feature_attributions", + "baseline_prediction_value", + "prediction_value", + "approximation_error", + "species", + "island", + "culmen_length_mm", + "culmen_depth_mm", + "flipper_length_mm", + "body_mass_g", + "sex", + } + assert expected_columns <= prediction_columns + + def test_to_gbq_saved_linear_reg_model_scores( penguins_linear_model, table_id_unique, penguins_df_default_index ): diff --git a/tests/unit/ml/test_sql.py b/tests/unit/ml/test_sql.py index ee0821dfe9..9d18649efe 100644 --- a/tests/unit/ml/test_sql.py +++ b/tests/unit/ml/test_sql.py @@ -342,6 +342,18 @@ def test_ml_predict_correct( ) +def test_ml_explain_predict_correct( + model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator, + mock_df: bpd.DataFrame, +): + sql = model_manipulation_sql_generator.ml_explain_predict(source_sql=mock_df.sql) + assert ( + sql + == 
"""SELECT * FROM ML.EXPLAIN_PREDICT(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`, + (input_X_y_sql))""" + ) + + def test_ml_llm_evaluate_correct( model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator, mock_df: bpd.DataFrame,