From c7a41ff4a8a81ac42700a8a45ccb8a92dff6212a Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Mon, 14 Oct 2024 09:32:05 +0000 Subject: [PATCH 1/6] feat: support data split method in `LinearRegression` --- bigframes/ml/linear_model.py | 21 ++++++- tests/system/large/ml/test_linear_model.py | 70 ++++++++++++++++++++++ tests/unit/ml/test_golden_sql.py | 2 +- 3 files changed, 91 insertions(+), 2 deletions(-) diff --git a/bigframes/ml/linear_model.py b/bigframes/ml/linear_model.py index 8fe1d6ec27..3f779759e9 100644 --- a/bigframes/ml/linear_model.py +++ b/bigframes/ml/linear_model.py @@ -42,6 +42,9 @@ "warm_start": "warmStart", "calculate_p_values": "calculatePValues", "enable_global_explain": "enableGlobalExplain", + "data_split_method": "dataSplitMethod", + "data_split_eval_fraction": "dataSplitEvalFraction", + "data_split_col": "dataSplitColumn", } @@ -55,6 +58,15 @@ class LinearRegression( def __init__( self, *, + data_split_method: Literal[ + "auto_split", + "random", + "custom", + "seq", + "no_split", + ] = "no_split", + data_split_eval_fraction: Optional[float] = None, + data_split_col: Optional[str] = None, optimize_strategy: Literal[ "auto_strategy", "batch_gradient_descent", "normal_equation" ] = "auto_strategy", @@ -70,6 +82,9 @@ def __init__( calculate_p_values: bool = False, enable_global_explain: bool = False, ): + self.data_split_method = data_split_method + self.data_split_eval_fraction = data_split_eval_fraction + self.data_split_col = data_split_col self.optimize_strategy = optimize_strategy self.fit_intercept = fit_intercept self.l1_reg = l1_reg @@ -104,7 +119,7 @@ def _bqml_options(self) -> dict: """The model options as they will be set for BQML""" options = { "model_type": "LINEAR_REG", - "data_split_method": "NO_SPLIT", + "data_split_method": self.data_split_method, "optimize_strategy": self.optimize_strategy, "fit_intercept": self.fit_intercept, "l2_reg": self.l2_reg, @@ -114,6 +129,10 @@ def _bqml_options(self) -> dict: "calculate_p_values": self.calculate_p_values, "enable_global_explain": self.enable_global_explain, } + if self.data_split_eval_fraction is not None: + options["data_split_eval_fraction"] = self.data_split_eval_fraction + if self.data_split_col is not None: + options["data_split_col"] = self.data_split_col if self.l1_reg is not None: options["l1_reg"] = self.l1_reg if self.learning_rate is not None: diff --git a/tests/system/large/ml/test_linear_model.py b/tests/system/large/ml/test_linear_model.py index f593ac2983..bda9ffe928 100644 --- a/tests/system/large/ml/test_linear_model.py +++ b/tests/system/large/ml/test_linear_model.py @@ -56,6 +56,72 @@ def test_linear_regression_configure_fit_score(penguins_df_default_index, datase assert reloaded_model.ls_init_learning_rate is None assert reloaded_model.max_iterations == 20 assert reloaded_model.tol == 0.01 + assert reloaded_model.data_split_method == "NO_SPLIT" + + +def test_linear_regression_custom_split_fit_score( + penguins_df_default_index, dataset_id +): + + import random + + import bigframes.dtypes + import bigframes.series + + penguins_eval_split_col = "penguins_eval_split_col" + + model = bigframes.ml.linear_model.LinearRegression( + data_split_method="custom", + data_split_col=penguins_eval_split_col, + ) + + df = penguins_df_default_index.dropna() + X_train = df[ + [ + "species", + "island", + "culmen_length_mm", + "culmen_depth_mm", + "flipper_length_mm", + "sex", + ] + ] + X_train[penguins_eval_split_col] = bigframes.series.Series( + [ + random.choice([False, False, True, 
bigframes.dtypes.pd.NA]) + for i in range(len(X_train)) + ], + dtype=bigframes.dtypes.BOOL_DTYPE, + session=X_train._session, + ) + y_train = df[["body_mass_g"]] + model.fit(X_train, y_train) + + # Check score to ensure the model was fitted + result = model.score(X_train, y_train).to_pandas() + utils.check_pandas_df_schema_and_index( + result, columns=utils.ML_REGRESSION_METRICS, index=1 + ) + + # save, load, check parameters to ensure configuration was kept + reloaded_model = model.to_gbq(f"{dataset_id}.temp_configured_model", replace=True) + assert reloaded_model._bqml_model is not None + assert ( + f"{dataset_id}.temp_configured_model" in reloaded_model._bqml_model.model_name + ) + assert reloaded_model.optimize_strategy == "NORMAL_EQUATION" + assert reloaded_model.fit_intercept is True + assert reloaded_model.calculate_p_values is False + assert reloaded_model.enable_global_explain is False + assert reloaded_model.l1_reg is None + assert reloaded_model.l2_reg == 0.0 + assert reloaded_model.learning_rate is None + assert reloaded_model.learning_rate_strategy == "line_search" + assert reloaded_model.ls_init_learning_rate is None + assert reloaded_model.max_iterations == 20 + assert reloaded_model.tol == 0.01 + assert reloaded_model.data_split_method == "CUSTOM" + assert reloaded_model.data_split_col == penguins_eval_split_col def test_linear_regression_customized_params_fit_score( @@ -70,6 +136,8 @@ def test_linear_regression_customized_params_fit_score( optimize_strategy="batch_gradient_descent", learning_rate_strategy="constant", learning_rate=0.2, + data_split_method="random", + data_split_eval_fraction=0.1, ) df = penguins_df_default_index.dropna() @@ -109,6 +177,8 @@ def test_linear_regression_customized_params_fit_score( assert reloaded_model.tol == 0.02 assert reloaded_model.learning_rate_strategy == "CONSTANT" assert reloaded_model.learning_rate == 0.2 + assert reloaded_model.data_split_method == "RANDOM" + assert reloaded_model.data_split_eval_fraction == 0.1 def test_unordered_mode_linear_regression_configure_fit_score_predict( diff --git a/tests/unit/ml/test_golden_sql.py b/tests/unit/ml/test_golden_sql.py index aa7e919b24..ff1a5a3c27 100644 --- a/tests/unit/ml/test_golden_sql.py +++ b/tests/unit/ml/test_golden_sql.py @@ -106,7 +106,7 @@ def test_linear_regression_default_fit( model.fit(mock_X, mock_y) mock_session._start_query_ml_ddl.assert_called_once_with( - 'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LINEAR_REG",\n data_split_method="NO_SPLIT",\n optimize_strategy="auto_strategy",\n fit_intercept=True,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql' + 'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LINEAR_REG",\n data_split_method="no_split",\n optimize_strategy="auto_strategy",\n fit_intercept=True,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql' ) From 98b9f85616b82be1b853493d18e28a8e1895c77e Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Mon, 14 Oct 2024 09:42:17 +0000 Subject: [PATCH 2/6] place new params after the existing ones --- bigframes/ml/linear_model.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 
deletions(-) diff --git a/bigframes/ml/linear_model.py b/bigframes/ml/linear_model.py index 3f779759e9..3edbfa6f52 100644 --- a/bigframes/ml/linear_model.py +++ b/bigframes/ml/linear_model.py @@ -58,15 +58,6 @@ class LinearRegression( def __init__( self, *, - data_split_method: Literal[ - "auto_split", - "random", - "custom", - "seq", - "no_split", - ] = "no_split", - data_split_eval_fraction: Optional[float] = None, - data_split_col: Optional[str] = None, optimize_strategy: Literal[ "auto_strategy", "batch_gradient_descent", "normal_equation" ] = "auto_strategy", @@ -81,10 +72,16 @@ def __init__( ls_init_learning_rate: Optional[float] = None, calculate_p_values: bool = False, enable_global_explain: bool = False, + data_split_method: Literal[ + "auto_split", + "random", + "custom", + "seq", + "no_split", + ] = "no_split", + data_split_eval_fraction: Optional[float] = None, + data_split_col: Optional[str] = None, ): - self.data_split_method = data_split_method - self.data_split_eval_fraction = data_split_eval_fraction - self.data_split_col = data_split_col self.optimize_strategy = optimize_strategy self.fit_intercept = fit_intercept self.l1_reg = l1_reg @@ -97,6 +94,9 @@ def __init__( self.ls_init_learning_rate = ls_init_learning_rate self.calculate_p_values = calculate_p_values self.enable_global_explain = enable_global_explain + self.data_split_method = data_split_method + self.data_split_eval_fraction = data_split_eval_fraction + self.data_split_col = data_split_col self._bqml_model: Optional[core.BqmlModel] = None self._bqml_model_factory = globals.bqml_model_factory() @@ -129,10 +129,6 @@ def _bqml_options(self) -> dict: "calculate_p_values": self.calculate_p_values, "enable_global_explain": self.enable_global_explain, } - if self.data_split_eval_fraction is not None: - options["data_split_eval_fraction"] = self.data_split_eval_fraction - if self.data_split_col is not None: - options["data_split_col"] = self.data_split_col if self.l1_reg is not None: options["l1_reg"] = self.l1_reg if self.learning_rate is not None: @@ -142,6 +138,10 @@ def _bqml_options(self) -> dict: # Even presenting warm_start returns error for NORMAL_EQUATION optimizer if self.warm_start: options["warm_start"] = self.warm_start + if self.data_split_eval_fraction is not None: + options["data_split_eval_fraction"] = self.data_split_eval_fraction + if self.data_split_col is not None: + options["data_split_col"] = self.data_split_col return options From 70e17a58aee8dd206398ba96c3c03b157da4266e Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Mon, 14 Oct 2024 18:57:39 +0000 Subject: [PATCH 3/6] fix unit test to expect no_split instead of NO_SPLIT --- tests/unit/ml/test_golden_sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/ml/test_golden_sql.py b/tests/unit/ml/test_golden_sql.py index 9f1b5fffe8..fde17e07ad 100644 --- a/tests/unit/ml/test_golden_sql.py +++ b/tests/unit/ml/test_golden_sql.py @@ -116,7 +116,7 @@ def test_linear_regression_params_fit(bqml_model_factory, mock_session, mock_X, model.fit(mock_X, mock_y) mock_session._start_query_ml_ddl.assert_called_once_with( - "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LINEAR_REG',\n data_split_method='NO_SPLIT',\n optimize_strategy='auto_strategy',\n fit_intercept=False,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql" + 
"CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LINEAR_REG',\n data_split_method='no_split',\n optimize_strategy='auto_strategy',\n fit_intercept=False,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql" ) From 5211186973a8e31b3388208d396615ccd8662f89 Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Mon, 14 Oct 2024 23:15:59 +0000 Subject: [PATCH 4/6] update documentation for LinearRegression --- .../sklearn/linear_model/_base.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/third_party/bigframes_vendored/sklearn/linear_model/_base.py b/third_party/bigframes_vendored/sklearn/linear_model/_base.py index 69f98697af..a3229e8e08 100644 --- a/third_party/bigframes_vendored/sklearn/linear_model/_base.py +++ b/third_party/bigframes_vendored/sklearn/linear_model/_base.py @@ -91,6 +91,21 @@ class LinearRegression(RegressorMixin, LinearModel): Specifies whether to compute p-values and standard errors during training. Default to False. enable_global_explain (bool, default False): Whether to compute global explanations using explainable AI to evaluate global feature importance to the model. Default to False. + data_split_method (str, default "no_split"): + The method to split input data into training and evaluation sets. + Possible values are "auto_split", "random", "custom", "seq" and + "no_split". Default to "no_split". For details see + [here](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-glm#data_split_method). + data_split_eval_fraction (float or None, default None): + Specifies the fraction of the data used for evaluation. Accurate to + two decimal places. Default to None, in which all the data would be + used for training and evaluation. For more details see + [here](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-glm#data_split_eval_fraction). + data_split_col (str or None, default None): + Identifies the column used to split the data when + ``data_split_method`` is set to "custom" or "seq". Default to None. + For more details see + [here](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-glm#data_split_col). 
""" def fit( From 32a6f71598411cd9441ada789e86ba7cf3b8b2e5 Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Tue, 15 Oct 2024 02:05:37 +0000 Subject: [PATCH 5/6] support data split in LogisticRegression model --- bigframes/ml/linear_model.py | 18 +++++- tests/system/large/ml/test_linear_model.py | 63 +++++++++++++++++++ tests/unit/ml/test_golden_sql.py | 4 +- .../sklearn/linear_model/_logistic.py | 15 +++++ 4 files changed, 97 insertions(+), 3 deletions(-) diff --git a/bigframes/ml/linear_model.py b/bigframes/ml/linear_model.py index 3edbfa6f52..93db6fc34b 100644 --- a/bigframes/ml/linear_model.py +++ b/bigframes/ml/linear_model.py @@ -228,6 +228,15 @@ def __init__( calculate_p_values: bool = False, enable_global_explain: bool = False, class_weight: Optional[Union[Literal["balanced"], Dict[str, float]]] = None, + data_split_method: Literal[ + "auto_split", + "random", + "custom", + "seq", + "no_split", + ] = "no_split", + data_split_eval_fraction: Optional[float] = None, + data_split_col: Optional[str] = None, ): self.optimize_strategy = optimize_strategy self.fit_intercept = fit_intercept @@ -242,6 +251,9 @@ def __init__( self.calculate_p_values = calculate_p_values self.enable_global_explain = enable_global_explain self.class_weight = class_weight + self.data_split_method = data_split_method + self.data_split_eval_fraction = data_split_eval_fraction + self.data_split_col = data_split_col self._auto_class_weight = class_weight == "balanced" self._bqml_model: Optional[core.BqmlModel] = None self._bqml_model_factory = globals.bqml_model_factory() @@ -272,7 +284,7 @@ def _bqml_options(self) -> dict: """The model options as they will be set for BQML""" options = { "model_type": "LOGISTIC_REG", - "data_split_method": "NO_SPLIT", + "data_split_method": self.data_split_method, "fit_intercept": self.fit_intercept, "auto_class_weights": self._auto_class_weight, "optimize_strategy": self.optimize_strategy, @@ -294,6 +306,10 @@ def _bqml_options(self) -> dict: # Even presenting warm_start returns error for NORMAL_EQUATION optimizer if self.warm_start: options["warm_start"] = self.warm_start + if self.data_split_eval_fraction is not None: + options["data_split_eval_fraction"] = self.data_split_eval_fraction + if self.data_split_col is not None: + options["data_split_col"] = self.data_split_col return options diff --git a/tests/system/large/ml/test_linear_model.py b/tests/system/large/ml/test_linear_model.py index bda9ffe928..e910bc7c4f 100644 --- a/tests/system/large/ml/test_linear_model.py +++ b/tests/system/large/ml/test_linear_model.py @@ -270,6 +270,65 @@ def test_logistic_regression_configure_fit_score(penguins_df_default_index, data ) assert reloaded_model.fit_intercept is True assert reloaded_model.class_weight is None + assert reloaded_model.data_split_method == "NO_SPLIT" + + +def test_logistic_regression_custom_split_fit_score( + penguins_df_default_index, dataset_id +): + import random + + import bigframes.dtypes + import bigframes.series + + penguins_eval_split_col = "penguins_eval_split_col" + + model = bigframes.ml.linear_model.LogisticRegression( + data_split_method="custom", + data_split_col=penguins_eval_split_col, + ) + + df = penguins_df_default_index.dropna() + X_train = df[ + [ + "species", + "island", + "culmen_length_mm", + "culmen_depth_mm", + "flipper_length_mm", + "body_mass_g", + ] + ] + X_train[penguins_eval_split_col] = bigframes.series.Series( + [ + random.choice([False, False, True, bigframes.dtypes.pd.NA]) + for i in range(len(X_train)) + ], + 
dtype=bigframes.dtypes.BOOL_DTYPE, + session=X_train._session, + ) + y_train = df[["sex"]] + model.fit(X_train, y_train) + + # Check score to ensure the model was fitted + result = model.score(X_train, y_train).to_pandas() + utils.check_pandas_df_schema_and_index( + result, columns=utils.ML_CLASSFICATION_METRICS, index=1 + ) + + # save, load, check parameters to ensure configuration was kept + reloaded_model = model.to_gbq( + f"{dataset_id}.temp_configured_logistic_reg_model", replace=True + ) + assert reloaded_model._bqml_model is not None + assert ( + f"{dataset_id}.temp_configured_logistic_reg_model" + in reloaded_model._bqml_model.model_name + ) + assert reloaded_model.fit_intercept is True + assert reloaded_model.class_weight is None + assert reloaded_model.data_split_method == "CUSTOM" + assert reloaded_model.data_split_col == penguins_eval_split_col def test_logistic_regression_customized_params_fit_score( @@ -285,6 +344,8 @@ def test_logistic_regression_customized_params_fit_score( optimize_strategy="batch_gradient_descent", learning_rate_strategy="constant", learning_rate=0.2, + data_split_method="random", + data_split_eval_fraction=0.1, ) df = penguins_df_default_index.dropna() X_train = df[ @@ -325,3 +386,5 @@ def test_logistic_regression_customized_params_fit_score( assert reloaded_model.tol == 0.02 assert reloaded_model.learning_rate_strategy == "CONSTANT" assert reloaded_model.learning_rate == 0.2 + assert reloaded_model.data_split_method == "RANDOM" + assert reloaded_model.data_split_eval_fraction == 0.1 diff --git a/tests/unit/ml/test_golden_sql.py b/tests/unit/ml/test_golden_sql.py index fde17e07ad..7192789ec7 100644 --- a/tests/unit/ml/test_golden_sql.py +++ b/tests/unit/ml/test_golden_sql.py @@ -149,7 +149,7 @@ def test_logistic_regression_default_fit( model.fit(mock_X, mock_y) mock_session._start_query_ml_ddl.assert_called_once_with( - "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='NO_SPLIT',\n fit_intercept=True,\n auto_class_weights=False,\n optimize_strategy='auto_strategy',\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql" + "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='no_split',\n fit_intercept=True,\n auto_class_weights=False,\n optimize_strategy='auto_strategy',\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql" ) @@ -171,7 +171,7 @@ def test_logistic_regression_params_fit( model.fit(mock_X, mock_y) mock_session._start_query_ml_ddl.assert_called_once_with( - "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='NO_SPLIT',\n fit_intercept=False,\n auto_class_weights=True,\n optimize_strategy='batch_gradient_descent',\n l2_reg=0.2,\n max_iterations=30,\n learn_rate_strategy='constant',\n min_rel_progress=0.02,\n calculate_p_values=False,\n enable_global_explain=False,\n l1_reg=0.2,\n learn_rate=0.2,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql" + "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='no_split',\n 
fit_intercept=False,\n auto_class_weights=True,\n optimize_strategy='batch_gradient_descent',\n l2_reg=0.2,\n max_iterations=30,\n learn_rate_strategy='constant',\n min_rel_progress=0.02,\n calculate_p_values=False,\n enable_global_explain=False,\n l1_reg=0.2,\n learn_rate=0.2,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql" ) diff --git a/third_party/bigframes_vendored/sklearn/linear_model/_logistic.py b/third_party/bigframes_vendored/sklearn/linear_model/_logistic.py index c52a37018c..7ebdb8fab0 100644 --- a/third_party/bigframes_vendored/sklearn/linear_model/_logistic.py +++ b/third_party/bigframes_vendored/sklearn/linear_model/_logistic.py @@ -61,6 +61,21 @@ class LogisticRegression(LinearClassifierMixin, BaseEstimator): Specifies whether to compute p-values and standard errors during training. Default to False. enable_global_explain (bool, default False): Whether to compute global explanations using explainable AI to evaluate global feature importance to the model. Default to False. + data_split_method (str, default "no_split"): + The method to split input data into training and evaluation sets. + Possible values are "auto_split", "random", "custom", "seq" and + "no_split". Default to "no_split". For details see + [here](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-glm#data_split_method). + data_split_eval_fraction (float or None, default None): + Specifies the fraction of the data used for evaluation. Accurate to + two decimal places. Default to None, in which all the data would be + used for training and evaluation. For more details see + [here](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-glm#data_split_eval_fraction). + data_split_col (str or None, default None): + Identifies the column used to split the data when + ``data_split_method`` is set to "custom" or "seq". Default to None. + For more details see + [here](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-glm#data_split_col). 
""" def fit( From 542d593f25a4bd2bf7a59e630ade2b1e97528c2e Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Tue, 15 Oct 2024 07:33:46 +0000 Subject: [PATCH 6/6] support data split in xgb and random forest models --- bigframes/ml/ensemble.py | 105 ++++++++++++++-- tests/unit/ml/test_golden_sql.py | 115 +++++++++++++++++- .../sklearn/ensemble/_forest.py | 32 ++++- .../bigframes_vendored/xgboost/sklearn.py | 30 +++++ 4 files changed, 264 insertions(+), 18 deletions(-) diff --git a/bigframes/ml/ensemble.py b/bigframes/ml/ensemble.py index 0194d768b8..bd594ec0e9 100644 --- a/bigframes/ml/ensemble.py +++ b/bigframes/ml/ensemble.py @@ -17,7 +17,7 @@ from __future__ import annotations -from typing import Dict, List, Literal, Optional, Union +from typing import List, Literal, Optional, Union import bigframes_vendored.sklearn.ensemble._forest import bigframes_vendored.xgboost.sklearn @@ -47,6 +47,9 @@ "max_iterations": "maxIterations", "enable_global_explain": "enableGlobalExplain", "xgboost_version": "xgboostVersion", + "data_split_method": "dataSplitMethod", + "data_split_eval_fraction": "dataSplitEvalFraction", + "data_split_col": "dataSplitColumn", } @@ -78,6 +81,15 @@ def __init__( tol: float = 0.01, enable_global_explain: bool = False, xgboost_version: Literal["0.9", "1.1"] = "0.9", + data_split_method: Literal[ + "auto_split", + "random", + "custom", + "seq", + "no_split", + ] = "no_split", + data_split_eval_fraction: Optional[float] = None, + data_split_col: Optional[str] = None, ): self.n_estimators = n_estimators self.booster = booster @@ -97,6 +109,9 @@ def __init__( self.tol = tol self.enable_global_explain = enable_global_explain self.xgboost_version = xgboost_version + self.data_split_method = data_split_method + self.data_split_eval_fraction = data_split_eval_fraction + self.data_split_col = data_split_col self._bqml_model: Optional[core.BqmlModel] = None self._bqml_model_factory = globals.bqml_model_factory() @@ -115,11 +130,11 @@ def _from_bq( return model @property - def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]: + def _bqml_options(self) -> dict: """The model options as they will be set for BQML""" - return { + options = { "model_type": "BOOSTED_TREE_REGRESSOR", - "data_split_method": "NO_SPLIT", + "data_split_method": self.data_split_method, "early_stop": True, "num_parallel_tree": self.n_estimators, "booster_type": self.booster, @@ -140,6 +155,13 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]: "xgboost_version": self.xgboost_version, } + if self.data_split_eval_fraction is not None: + options["data_split_eval_fraction"] = self.data_split_eval_fraction + if self.data_split_col is not None: + options["data_split_col"] = self.data_split_col + + return options + def _fit( self, X: Union[bpd.DataFrame, bpd.Series], @@ -227,6 +249,15 @@ def __init__( tol: float = 0.01, enable_global_explain: bool = False, xgboost_version: Literal["0.9", "1.1"] = "0.9", + data_split_method: Literal[ + "auto_split", + "random", + "custom", + "seq", + "no_split", + ] = "no_split", + data_split_eval_fraction: Optional[float] = None, + data_split_col: Optional[str] = None, ): self.n_estimators = n_estimators self.booster = booster @@ -246,6 +277,9 @@ def __init__( self.tol = tol self.enable_global_explain = enable_global_explain self.xgboost_version = xgboost_version + self.data_split_method = data_split_method + self.data_split_eval_fraction = data_split_eval_fraction + self.data_split_col = data_split_col self._bqml_model: 
Optional[core.BqmlModel] = None self._bqml_model_factory = globals.bqml_model_factory() @@ -264,11 +298,11 @@ def _from_bq( return model @property - def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]: + def _bqml_options(self) -> dict: """The model options as they will be set for BQML""" - return { + options = { "model_type": "BOOSTED_TREE_CLASSIFIER", - "data_split_method": "NO_SPLIT", + "data_split_method": self.data_split_method, "early_stop": True, "num_parallel_tree": self.n_estimators, "booster_type": self.booster, @@ -289,6 +323,13 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]: "xgboost_version": self.xgboost_version, } + if self.data_split_eval_fraction is not None: + options["data_split_eval_fraction"] = self.data_split_eval_fraction + if self.data_split_col is not None: + options["data_split_col"] = self.data_split_col + + return options + def _fit( self, X: Union[bpd.DataFrame, bpd.Series], @@ -370,6 +411,15 @@ def __init__( tol: float = 0.01, enable_global_explain: bool = False, xgboost_version: Literal["0.9", "1.1"] = "0.9", + data_split_method: Literal[ + "auto_split", + "random", + "custom", + "seq", + "no_split", + ] = "no_split", + data_split_eval_fraction: Optional[float] = None, + data_split_col: Optional[str] = None, ): self.n_estimators = n_estimators self.tree_method = tree_method @@ -385,6 +435,9 @@ def __init__( self.tol = tol self.enable_global_explain = enable_global_explain self.xgboost_version = xgboost_version + self.data_split_method = data_split_method + self.data_split_eval_fraction = data_split_eval_fraction + self.data_split_col = data_split_col self._bqml_model: Optional[core.BqmlModel] = None self._bqml_model_factory = globals.bqml_model_factory() @@ -403,9 +456,9 @@ def _from_bq( return model @property - def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]: + def _bqml_options(self) -> dict: """The model options as they will be set for BQML""" - return { + options = { "model_type": "RANDOM_FOREST_REGRESSOR", "early_stop": True, "num_parallel_tree": self.n_estimators, @@ -420,11 +473,18 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]: "l1_reg": self.reg_alpha, "l2_reg": self.reg_lambda, "min_rel_progress": self.tol, - "data_split_method": "NO_SPLIT", + "data_split_method": self.data_split_method, "enable_global_explain": self.enable_global_explain, "xgboost_version": self.xgboost_version, } + if self.data_split_eval_fraction is not None: + options["data_split_eval_fraction"] = self.data_split_eval_fraction + if self.data_split_col is not None: + options["data_split_col"] = self.data_split_col + + return options + def _fit( self, X: Union[bpd.DataFrame, bpd.Series], @@ -526,6 +586,15 @@ def __init__( tol: float = 0.01, enable_global_explain: bool = False, xgboost_version: Literal["0.9", "1.1"] = "0.9", + data_split_method: Literal[ + "auto_split", + "random", + "custom", + "seq", + "no_split", + ] = "no_split", + data_split_eval_fraction: Optional[float] = None, + data_split_col: Optional[str] = None, ): self.n_estimators = n_estimators self.tree_method = tree_method @@ -541,6 +610,9 @@ def __init__( self.tol = tol self.enable_global_explain = enable_global_explain self.xgboost_version = xgboost_version + self.data_split_method = data_split_method + self.data_split_eval_fraction = data_split_eval_fraction + self.data_split_col = data_split_col self._bqml_model: Optional[core.BqmlModel] = None self._bqml_model_factory = globals.bqml_model_factory() 
@@ -559,9 +631,9 @@ def _from_bq( return model @property - def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]: + def _bqml_options(self) -> dict: """The model options as they will be set for BQML""" - return { + options = { "model_type": "RANDOM_FOREST_CLASSIFIER", "early_stop": True, "num_parallel_tree": self.n_estimators, @@ -576,11 +648,18 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]: "l1_reg": self.reg_alpha, "l2_reg": self.reg_lambda, "min_rel_progress": self.tol, - "data_split_method": "NO_SPLIT", + "data_split_method": self.data_split_method, "enable_global_explain": self.enable_global_explain, "xgboost_version": self.xgboost_version, } + if self.data_split_eval_fraction is not None: + options["data_split_eval_fraction"] = self.data_split_eval_fraction + if self.data_split_col is not None: + options["data_split_col"] = self.data_split_col + + return options + def _fit( self, X: Union[bpd.DataFrame, bpd.Series], diff --git a/tests/unit/ml/test_golden_sql.py b/tests/unit/ml/test_golden_sql.py index 7192789ec7..accd4887b4 100644 --- a/tests/unit/ml/test_golden_sql.py +++ b/tests/unit/ml/test_golden_sql.py @@ -20,7 +20,7 @@ import pytest_mock import bigframes -from bigframes.ml import core, linear_model +from bigframes.ml import core, ensemble, linear_model import bigframes.pandas as bpd TEMP_MODEL_ID = bigquery.ModelReference.from_string( @@ -111,12 +111,14 @@ def test_linear_regression_default_fit( def test_linear_regression_params_fit(bqml_model_factory, mock_session, mock_X, mock_y): - model = linear_model.LinearRegression(fit_intercept=False) + model = linear_model.LinearRegression( + fit_intercept=False, data_split_method="auto_split" + ) model._bqml_model_factory = bqml_model_factory model.fit(mock_X, mock_y) mock_session._start_query_ml_ddl.assert_called_once_with( - "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LINEAR_REG',\n data_split_method='no_split',\n optimize_strategy='auto_strategy',\n fit_intercept=False,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql" + "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LINEAR_REG',\n data_split_method='auto_split',\n optimize_strategy='auto_strategy',\n fit_intercept=False,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql" ) @@ -166,12 +168,13 @@ def test_logistic_regression_params_fit( optimize_strategy="batch_gradient_descent", learning_rate_strategy="constant", learning_rate=0.2, + data_split_method="auto_split", ) model._bqml_model_factory = bqml_model_factory model.fit(mock_X, mock_y) mock_session._start_query_ml_ddl.assert_called_once_with( - "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='no_split',\n fit_intercept=False,\n auto_class_weights=True,\n optimize_strategy='batch_gradient_descent',\n l2_reg=0.2,\n max_iterations=30,\n learn_rate_strategy='constant',\n min_rel_progress=0.02,\n calculate_p_values=False,\n enable_global_explain=False,\n l1_reg=0.2,\n learn_rate=0.2,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql" + "CREATE OR REPLACE MODEL 
`test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='auto_split',\n fit_intercept=False,\n auto_class_weights=True,\n optimize_strategy='batch_gradient_descent',\n l2_reg=0.2,\n max_iterations=30,\n learn_rate_strategy='constant',\n min_rel_progress=0.02,\n calculate_p_values=False,\n enable_global_explain=False,\n l1_reg=0.2,\n learn_rate=0.2,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql" ) @@ -194,3 +197,107 @@ def test_logistic_regression_score(mock_session, bqml_model, mock_X, mock_y): mock_session.read_gbq.assert_called_once_with( "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_y_sql))" ) + + +def test_xgb_regressor_default_fit(bqml_model_factory, mock_session, mock_X, mock_y): + model = ensemble.XGBRegressor() + model._bqml_model_factory = bqml_model_factory + model.fit(mock_X, mock_y) + + mock_session._start_query_ml_ddl.assert_called_once_with( + "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='BOOSTED_TREE_REGRESSOR',\n data_split_method='no_split',\n early_stop=True,\n num_parallel_tree=1,\n booster_type='gbtree',\n tree_method='auto',\n min_tree_child_weight=1,\n colsample_bytree=1.0,\n colsample_bylevel=1.0,\n colsample_bynode=1.0,\n min_split_loss=0.0,\n max_tree_depth=6,\n subsample=1.0,\n l1_reg=0.0,\n l2_reg=1.0,\n learn_rate=0.3,\n max_iterations=20,\n min_rel_progress=0.01,\n enable_global_explain=False,\n xgboost_version='0.9',\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql" + ) + + +def test_xgb_regressor_params_fit(bqml_model_factory, mock_session, mock_X, mock_y): + model = ensemble.XGBRegressor( + data_split_method="seq", + data_split_eval_fraction=0.2, + data_split_col="split_col", + ) + model._bqml_model_factory = bqml_model_factory + model.fit(mock_X, mock_y) + + mock_session._start_query_ml_ddl.assert_called_once_with( + "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='BOOSTED_TREE_REGRESSOR',\n data_split_method='seq',\n early_stop=True,\n num_parallel_tree=1,\n booster_type='gbtree',\n tree_method='auto',\n min_tree_child_weight=1,\n colsample_bytree=1.0,\n colsample_bylevel=1.0,\n colsample_bynode=1.0,\n min_split_loss=0.0,\n max_tree_depth=6,\n subsample=1.0,\n l1_reg=0.0,\n l2_reg=1.0,\n learn_rate=0.3,\n max_iterations=20,\n min_rel_progress=0.01,\n enable_global_explain=False,\n xgboost_version='0.9',\n data_split_eval_fraction=0.2,\n data_split_col='split_col',\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql" + ) + + +def test_xgb_classifier_default_fit(bqml_model_factory, mock_session, mock_X, mock_y): + model = ensemble.XGBClassifier() + model._bqml_model_factory = bqml_model_factory + model.fit(mock_X, mock_y) + + mock_session._start_query_ml_ddl.assert_called_once_with( + "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='BOOSTED_TREE_CLASSIFIER',\n data_split_method='no_split',\n early_stop=True,\n num_parallel_tree=1,\n booster_type='gbtree',\n tree_method='auto',\n min_tree_child_weight=1,\n colsample_bytree=1.0,\n colsample_bylevel=1.0,\n colsample_bynode=1.0,\n min_split_loss=0.0,\n max_tree_depth=6,\n subsample=1.0,\n l1_reg=0.0,\n l2_reg=1.0,\n learn_rate=0.3,\n max_iterations=20,\n min_rel_progress=0.01,\n enable_global_explain=False,\n xgboost_version='0.9',\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql" + ) + + +def test_xgb_classifier_params_fit(bqml_model_factory, 
mock_session, mock_X, mock_y): + model = ensemble.XGBClassifier( + data_split_method="seq", + data_split_eval_fraction=0.2, + data_split_col="split_col", + ) + model._bqml_model_factory = bqml_model_factory + model.fit(mock_X, mock_y) + + mock_session._start_query_ml_ddl.assert_called_once_with( + "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='BOOSTED_TREE_CLASSIFIER',\n data_split_method='seq',\n early_stop=True,\n num_parallel_tree=1,\n booster_type='gbtree',\n tree_method='auto',\n min_tree_child_weight=1,\n colsample_bytree=1.0,\n colsample_bylevel=1.0,\n colsample_bynode=1.0,\n min_split_loss=0.0,\n max_tree_depth=6,\n subsample=1.0,\n l1_reg=0.0,\n l2_reg=1.0,\n learn_rate=0.3,\n max_iterations=20,\n min_rel_progress=0.01,\n enable_global_explain=False,\n xgboost_version='0.9',\n data_split_eval_fraction=0.2,\n data_split_col='split_col',\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql" + ) + + +def test_randomforest_regressor_default_fit( + bqml_model_factory, mock_session, mock_X, mock_y +): + model = ensemble.RandomForestRegressor() + model._bqml_model_factory = bqml_model_factory + model.fit(mock_X, mock_y) + + mock_session._start_query_ml_ddl.assert_called_once_with( + "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='RANDOM_FOREST_REGRESSOR',\n early_stop=True,\n num_parallel_tree=100,\n tree_method='auto',\n min_tree_child_weight=1,\n colsample_bytree=1.0,\n colsample_bylevel=1.0,\n colsample_bynode=0.8,\n min_split_loss=0.0,\n max_tree_depth=15,\n subsample=0.8,\n l1_reg=0.0,\n l2_reg=1.0,\n min_rel_progress=0.01,\n data_split_method='no_split',\n enable_global_explain=False,\n xgboost_version='0.9',\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql" + ) + + +def test_randomforest_regressor_params_fit( + bqml_model_factory, mock_session, mock_X, mock_y +): + model = ensemble.RandomForestRegressor( + data_split_method="seq", + data_split_eval_fraction=0.2, + data_split_col="split_col", + ) + model._bqml_model_factory = bqml_model_factory + model.fit(mock_X, mock_y) + + mock_session._start_query_ml_ddl.assert_called_once_with( + "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='RANDOM_FOREST_REGRESSOR',\n early_stop=True,\n num_parallel_tree=100,\n tree_method='auto',\n min_tree_child_weight=1,\n colsample_bytree=1.0,\n colsample_bylevel=1.0,\n colsample_bynode=0.8,\n min_split_loss=0.0,\n max_tree_depth=15,\n subsample=0.8,\n l1_reg=0.0,\n l2_reg=1.0,\n min_rel_progress=0.01,\n data_split_method='seq',\n enable_global_explain=False,\n xgboost_version='0.9',\n data_split_eval_fraction=0.2,\n data_split_col='split_col',\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql" + ) + + +def test_randomforest_classifier_default_fit( + bqml_model_factory, mock_session, mock_X, mock_y +): + model = ensemble.RandomForestClassifier() + model._bqml_model_factory = bqml_model_factory + model.fit(mock_X, mock_y) + + mock_session._start_query_ml_ddl.assert_called_once_with( + "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='RANDOM_FOREST_CLASSIFIER',\n early_stop=True,\n num_parallel_tree=100,\n tree_method='auto',\n min_tree_child_weight=1,\n colsample_bytree=1.0,\n colsample_bylevel=1.0,\n colsample_bynode=0.8,\n min_split_loss=0.0,\n max_tree_depth=15,\n subsample=0.8,\n l1_reg=0.0,\n l2_reg=1.0,\n min_rel_progress=0.01,\n data_split_method='no_split',\n enable_global_explain=False,\n 
xgboost_version='0.9',\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql" + ) + + +def test_randomforest_classifier_params_fit( + bqml_model_factory, mock_session, mock_X, mock_y +): + model = ensemble.RandomForestClassifier( + data_split_method="seq", + data_split_eval_fraction=0.2, + data_split_col="split_col", + ) + model._bqml_model_factory = bqml_model_factory + model.fit(mock_X, mock_y) + + mock_session._start_query_ml_ddl.assert_called_once_with( + "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='RANDOM_FOREST_CLASSIFIER',\n early_stop=True,\n num_parallel_tree=100,\n tree_method='auto',\n min_tree_child_weight=1,\n colsample_bytree=1.0,\n colsample_bylevel=1.0,\n colsample_bynode=0.8,\n min_split_loss=0.0,\n max_tree_depth=15,\n subsample=0.8,\n l1_reg=0.0,\n l2_reg=1.0,\n min_rel_progress=0.01,\n data_split_method='seq',\n enable_global_explain=False,\n xgboost_version='0.9',\n data_split_eval_fraction=0.2,\n data_split_col='split_col',\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql" + ) diff --git a/third_party/bigframes_vendored/sklearn/ensemble/_forest.py b/third_party/bigframes_vendored/sklearn/ensemble/_forest.py index 92794bb68e..8db6322510 100644 --- a/third_party/bigframes_vendored/sklearn/ensemble/_forest.py +++ b/third_party/bigframes_vendored/sklearn/ensemble/_forest.py @@ -122,6 +122,21 @@ class RandomForestRegressor(ForestRegressor): Whether to compute global explanations using explainable AI to evaluate global feature importance to the model. Default to False. xgboost_version (Optional[str]): Specifies the Xgboost version for model training. Default to "0.9". Possible values: "0.9", "1.1". + data_split_method (str, default "no_split"): + The method to split input data into training and evaluation sets. + Possible values are "auto_split", "random", "custom", "seq" and + "no_split". Default to "no_split". For details see + [here](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-glm#data_split_method). + data_split_eval_fraction (float or None, default None): + Specifies the fraction of the data used for evaluation. Accurate to + two decimal places. Default to None, in which all the data would be + used for training and evaluation. For more details see + [here](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-glm#data_split_eval_fraction). + data_split_col (str or None, default None): + Identifies the column used to split the data when + ``data_split_method`` is set to "custom" or "seq". Default to None. + For more details see + [here](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-glm#data_split_col). """ @@ -186,5 +201,20 @@ class RandomForestClassifier(ForestClassifier): enable_global_explain (Optional[bool]): Whether to compute global explanations using explainable AI to evaluate global feature importance to the model. Default to False. xgboost_version (Optional[str]): - Specifies the Xgboost version for model training. Default to "0.9". Possible values: "0.9", "1.1".ß + Specifies the Xgboost version for model training. Default to "0.9". Possible values: "0.9", "1.1". + data_split_method (str, default "no_split"): + The method to split input data into training and evaluation sets. + Possible values are "auto_split", "random", "custom", "seq" and + "no_split". Default to "no_split". 
For details see + [here](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-glm#data_split_method). + data_split_eval_fraction (float or None, default None): + Specifies the fraction of the data used for evaluation. Accurate to + two decimal places. Default to None, in which all the data would be + used for training and evaluation. For more details see + [here](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-glm#data_split_eval_fraction). + data_split_col (str or None, default None): + Identifies the column used to split the data when + ``data_split_method`` is set to "custom" or "seq". Default to None. + For more details see + [here](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-glm#data_split_col). """ diff --git a/third_party/bigframes_vendored/xgboost/sklearn.py b/third_party/bigframes_vendored/xgboost/sklearn.py index da1396af02..d2b56a79a5 100644 --- a/third_party/bigframes_vendored/xgboost/sklearn.py +++ b/third_party/bigframes_vendored/xgboost/sklearn.py @@ -94,6 +94,21 @@ class XGBRegressor(XGBModel, XGBRegressorBase): Whether to compute global explanations using explainable AI to evaluate global feature importance to the model. Default to False. xgboost_version (Optional[str]): Specifies the Xgboost version for model training. Default to "0.9". Possible values: "0.9", "1.1". + data_split_method (str, default "no_split"): + The method to split input data into training and evaluation sets. + Possible values are "auto_split", "random", "custom", "seq" and + "no_split". Default to "no_split". For details see + [here](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-glm#data_split_method). + data_split_eval_fraction (float or None, default None): + Specifies the fraction of the data used for evaluation. Accurate to + two decimal places. Default to None, in which all the data would be + used for training and evaluation. For more details see + [here](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-glm#data_split_eval_fraction). + data_split_col (str or None, default None): + Identifies the column used to split the data when + ``data_split_method`` is set to "custom" or "seq". Default to None. + For more details see + [here](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-glm#data_split_col). """ @@ -141,4 +156,19 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase): Whether to compute global explanations using explainable AI to evaluate global feature importance to the model. Default to False. xgboost_version (Optional[str]): Specifies the Xgboost version for model training. Default to "0.9". Possible values: "0.9", "1.1". + data_split_method (str, default "no_split"): + The method to split input data into training and evaluation sets. + Possible values are "auto_split", "random", "custom", "seq" and + "no_split". Default to "no_split". For details see + [here](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-glm#data_split_method). + data_split_eval_fraction (float or None, default None): + Specifies the fraction of the data used for evaluation. Accurate to + two decimal places. Default to None, in which all the data would be + used for training and evaluation. 
For more details see + [here](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-glm#data_split_eval_fraction). + data_split_col (str or None, default None): + Identifies the column used to split the data when + ``data_split_method`` is set to "custom" or "seq". Default to None. + For more details see + [here](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-glm#data_split_col). """
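

Usage note (editorial, not part of the patch series): the sketch below shows how the new split options introduced above might be exercised from user code. It is a minimal example, assuming a configured BigQuery DataFrames session with access to the public `bigquery-public-data.ml_datasets.penguins` table; the column names follow the system tests in these patches.

    import bigframes.pandas as bpd
    from bigframes.ml.linear_model import LinearRegression

    # Load and clean the public penguins table used by the system tests above.
    df = bpd.read_gbq("bigquery-public-data.ml_datasets.penguins").dropna()
    X = df[["culmen_length_mm", "culmen_depth_mm", "flipper_length_mm"]]
    y = df[["body_mass_g"]]

    # Ask BQML to hold out a random ~20% of rows for its built-in evaluation,
    # instead of the previously hard-coded NO_SPLIT behaviour.
    model = LinearRegression(
        data_split_method="random",
        data_split_eval_fraction=0.2,
    )
    model.fit(X, y)

    # score() runs ML.EVALUATE, mirroring the assertions in the tests above.
    print(model.score(X, y).to_pandas())

The same three parameters (`data_split_method`, `data_split_eval_fraction`, `data_split_col`) apply to `LogisticRegression` and to the XGBoost and random forest estimators touched by the later patches in this series.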