Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

test: include model.register test for BQML CMEK #433

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions 4 bigframes/ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def register(self, vertex_ai_model_id: Optional[str] = None) -> BqmlModel:
options={"vertex_ai_model_id": vertex_ai_model_id}
)
# Register the model and wait for it to finish
self._session._start_query_create_model(sql)
self._session._start_query_ml_ddl(sql)

self._model = self._session.bqclient.get_model(self.model_name)
return self
Expand All @@ -264,7 +264,7 @@ def _create_model_ref(

def _create_model_with_sql(self, session: bigframes.Session, sql: str) -> BqmlModel:
# fit the model, synchronously
_, job = session._start_query_create_model(sql)
_, job = session._start_query_ml_ddl(sql)

# real model path in the session specific hidden dataset and table prefix
model_name_full = f"{job.destination.project}.{job.destination.dataset_id}.{job.destination.table_id}"
Expand Down
5 changes: 3 additions & 2 deletions 5 bigframes/session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1592,12 +1592,13 @@ def _start_query(
self.bqclient, sql, job_config, max_results
)

def _start_query_create_model(
def _start_query_ml_ddl(
self,
sql: str,
) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]:
"""
Starts BigQuery ML CREATE MODEL query job and waits for results.
Starts BigQuery ML DDL query job (CREATE MODEL/ALTER MODEL/...) and
waits for results.
"""
job_config = self._prepare_query_job_config()

Expand Down
41 changes: 38 additions & 3 deletions 41 tests/system/small/test_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def test_df_apis(bq_cmek, session_with_bq_cmek, scalars_table_id):
# Read a BQ table and assert encryption
df = session_with_bq_cmek.read_gbq(scalars_table_id)

# Perform a few dataframe operations and assert assertion
# Perform a few dataframe operations and assert encryption
df1 = df.dropna()
_assert_bq_table_is_encrypted(df1, bq_cmek, session_with_bq_cmek)

Expand Down Expand Up @@ -179,15 +179,32 @@ def test_to_gbq(bq_cmek, session_with_bq_cmek, scalars_table_id):
df = session_with_bq_cmek.read_gbq(scalars_table_id)
_assert_bq_table_is_encrypted(df, bq_cmek, session_with_bq_cmek)

# Modify the dataframe and assert assertion
# Modify the dataframe and assert encryption
df = df.dropna().head()
_assert_bq_table_is_encrypted(df, bq_cmek, session_with_bq_cmek)

# Write the result to BQ and assert assertion
# Write the result to BQ and assert encryption
output_table_id = df.to_gbq()
output_table = session_with_bq_cmek.bqclient.get_table(output_table_id)
assert output_table.encryption_configuration.kms_key_name == bq_cmek

# Write the result to BQ custom table and assert encryption
session_with_bq_cmek.bqclient.get_table(output_table_id)
output_table_ref = bigframes.session._io.bigquery.random_table(
session_with_bq_cmek._anonymous_dataset
)
output_table_id = str(output_table_ref)
df.to_gbq(output_table_id)
output_table = session_with_bq_cmek.bqclient.get_table(output_table_id)
assert output_table.encryption_configuration.kms_key_name == bq_cmek

# Lastly, assert that the encryption is not because of any default set at
# the dataset level
output_table_dataset = session_with_bq_cmek.bqclient.get_dataset(
output_table.dataset_id
)
assert output_table_dataset.default_encryption_configuration is None


@pytest.mark.skip(
reason="Internal issue 327544164, cmek does not propagate to the dataframe."
Expand Down Expand Up @@ -254,3 +271,21 @@ def test_bqml(bq_cmek, session_with_bq_cmek, penguins_table_id):
# Assert that model exists in BQ with intended encryption
model_bq = session_with_bq_cmek.bqclient.get_model(new_model._bqml_model.model_name)
assert model_bq.encryption_configuration.kms_key_name == bq_cmek

# Assert that model registration keeps the encryption
# Note that model registration only creates an entry (metadata) to be
# included in the Vertex AI Model Registry. For more details, see
# https://cloud.google.com/bigquery/docs/update_vertex#add-existing.
# When a user deploys the model to an endpoint from the Model Registry,
# they can specify an encryption key to further protect the artifacts at
# rest on the Vertex AI side. For more details, see
# https://cloud.google.com/vertex-ai/docs/general/deployment#deploy_a_model_to_an_endpoint,
# https://cloud.google.com/vertex-ai/docs/general/cmek#create_resources_with_the_kms_key.
# bigframes.ml does not provide any API for the model deployment.
model_registered = new_model.register()
assert (
model_registered._bqml_model.model.encryption_configuration.kms_key_name
== bq_cmek
)
model_bq = session_with_bq_cmek.bqclient.get_model(new_model._bqml_model.model_name)
assert model_bq.encryption_configuration.kms_key_name == bq_cmek
10 changes: 5 additions & 5 deletions 10 tests/unit/ml/test_golden_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def mock_session():
mock_session._anonymous_dataset, TEMP_MODEL_ID.model_id
)
)
mock_session._start_query_create_model.return_value = (None, query_job)
mock_session._start_query_ml_ddl.return_value = (None, query_job)

return mock_session

Expand Down Expand Up @@ -104,7 +104,7 @@ def test_linear_regression_default_fit(
model._bqml_model_factory = bqml_model_factory
model.fit(mock_X, mock_y)

mock_session._start_query_create_model.assert_called_once_with(
mock_session._start_query_ml_ddl.assert_called_once_with(
'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LINEAR_REG",\n data_split_method="NO_SPLIT",\n optimize_strategy="normal_equation",\n fit_intercept=True,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n early_stop=True,\n min_rel_progress=0.01,\n ls_init_learn_rate=0.1,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
)

Expand All @@ -114,7 +114,7 @@ def test_linear_regression_params_fit(bqml_model_factory, mock_session, mock_X,
model._bqml_model_factory = bqml_model_factory
model.fit(mock_X, mock_y)

mock_session._start_query_create_model.assert_called_once_with(
mock_session._start_query_ml_ddl.assert_called_once_with(
'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LINEAR_REG",\n data_split_method="NO_SPLIT",\n optimize_strategy="normal_equation",\n fit_intercept=False,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n early_stop=True,\n min_rel_progress=0.01,\n ls_init_learn_rate=0.1,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
)

Expand Down Expand Up @@ -147,7 +147,7 @@ def test_logistic_regression_default_fit(
model._bqml_model_factory = bqml_model_factory
model.fit(mock_X, mock_y)

mock_session._start_query_create_model.assert_called_once_with(
mock_session._start_query_ml_ddl.assert_called_once_with(
'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LOGISTIC_REG",\n data_split_method="NO_SPLIT",\n fit_intercept=True,\n auto_class_weights=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
)

Expand All @@ -161,7 +161,7 @@ def test_logistic_regression_params_fit(
model._bqml_model_factory = bqml_model_factory
model.fit(mock_X, mock_y)

mock_session._start_query_create_model.assert_called_once_with(
mock_session._start_query_ml_ddl.assert_called_once_with(
'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LOGISTIC_REG",\n data_split_method="NO_SPLIT",\n fit_intercept=False,\n auto_class_weights=True,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
)

Expand Down
Morty Proxy This is a proxified and sanitized view of the page, visit original site.