From 53734669f6a489e47822b159fdd84e2e15e30cd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Swe=C3=B1a=20=28Swast=29?= Date: Fri, 10 May 2024 10:59:14 -0500 Subject: [PATCH] docs: use `class_weight="balanced"` in the logistic regression prediction tutorial This aligns the Python code with the SQL at https://cloud.google.com/bigquery/docs/logistic-regression-prediction#create_a_logistic_regression_model ```sql CREATE OR REPLACE MODEL `census.census_model` OPTIONS ( model_type='LOGISTIC_REG', auto_class_weights=TRUE, data_split_method='NO_SPLIT', input_label_cols=['income_bracket'], max_iterations=15) AS SELECT * EXCEPT(dataframe) FROM `census.input_data` WHERE dataframe = 'training' ``` --- .../logistic_regression_prediction_test.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/samples/snippets/logistic_regression_prediction_test.py b/samples/snippets/logistic_regression_prediction_test.py index 6a40369ba8..dd92f8f3e3 100644 --- a/samples/snippets/logistic_regression_prediction_test.py +++ b/samples/snippets/logistic_regression_prediction_test.py @@ -80,7 +80,21 @@ def test_logistic_regression_prediction(random_model_id: str) -> None: X = training_data.drop(columns=["income_bracket", "dataframe"]) y = training_data["income_bracket"] - census_model = bigframes.ml.linear_model.LogisticRegression() + census_model = bigframes.ml.linear_model.LogisticRegression( + # Balance the class labels in the training data by setting + # class_weight="balanced". + # + # By default, the training data is unweighted. If the labels + # in the training data are imbalanced, the model may learn to + # predict the most popular class of labels more heavily. In + # this case, most of the respondents in the dataset are in the + # lower income bracket. This may lead to a model that predicts + # the lower income bracket too heavily. Class weights balance + # the class labels by calculating the weights for each class in + # inverse proportion to the frequency of that class. + class_weight="balanced", + max_iterations=15, + ) census_model.fit(X, y) census_model.to_gbq(