diff --git a/bigframes/remote_function.py b/bigframes/remote_function.py
index 37c7a2fc64..81ba26600b 100644
--- a/bigframes/remote_function.py
+++ b/bigframes/remote_function.py
@@ -202,10 +202,22 @@ def create_bq_remote_function(
             OPTIONS (
                 endpoint = "{endpoint}"
             )"""
+        logger.info(f"Creating BQ remote function: {create_function_ddl}")
+
+        # Make sure the dataset exists
+        dataset = bigquery.Dataset(
+            bigquery.DatasetReference.from_string(
+                self._bq_dataset, default_project=self._gcp_project_id
+            )
+        )
+        dataset.location = self._bq_location
+        self._bq_client.create_dataset(dataset, exists_ok=True)
+
         # TODO: Use session._start_query() so we get progress bar
         query_job = self._bq_client.query(create_function_ddl)  # Make an API request.
         query_job.result()  # Wait for the job to complete.
+        logger.info(f"Created remote function {query_job.ddl_target_routine}")
 
     def get_cloud_function_fully_qualified_parent(self):
@@ -465,17 +477,22 @@ def get_remote_function_specs(self, remote_function_name):
         routines = self._bq_client.list_routines(
             f"{self._gcp_project_id}.{self._bq_dataset}"
         )
-        for routine in routines:
-            if routine.reference.routine_id == remote_function_name:
-                # TODO(shobs): Use first class properties when they are available
-                # https://github.com/googleapis/python-bigquery/issues/1552
-                rf_options = routine._properties.get("remoteFunctionOptions")
-                if rf_options:
-                    http_endpoint = rf_options.get("endpoint")
-                    bq_connection = rf_options.get("connection")
-                    if bq_connection:
-                        bq_connection = os.path.basename(bq_connection)
-                break
+        try:
+            for routine in routines:
+                if routine.reference.routine_id == remote_function_name:
+                    # TODO(shobs): Use first class properties when they are available
+                    # https://github.com/googleapis/python-bigquery/issues/1552
+                    rf_options = routine._properties.get("remoteFunctionOptions")
+                    if rf_options:
+                        http_endpoint = rf_options.get("endpoint")
+                        bq_connection = rf_options.get("connection")
+                        if bq_connection:
+                            bq_connection = os.path.basename(bq_connection)
+                    break
+        except google.api_core.exceptions.NotFound:
+            # The dataset might not exist, in which case the http_endpoint doesn't, either.
+            # Note: list_routines doesn't make an API request until we iterate on the response object.
+            pass
 
         return (http_endpoint, bq_connection)
diff --git a/bigframes/session.py b/bigframes/session.py
index a7cb78e3ff..6ad65000ce 100644
--- a/bigframes/session.py
+++ b/bigframes/session.py
@@ -381,17 +381,12 @@ def _create_and_bind_bq_session(self):
             ]
         )
 
-        # Dataset for storing BQML models and remote functions, which don't yet
+        # Dataset for storing remote functions, which don't
         # support proper session temporary storage yet
         self._session_dataset = bigquery.Dataset(
             f"{self.bqclient.project}.bigframes_temp_{self._location.lower().replace('-', '_')}"
         )
         self._session_dataset.location = self._location
-        self._session_dataset.default_table_expiration_ms = 24 * 60 * 60 * 1000
-
-        # TODO: handle case when the dataset does not exist and the user does
-        # not have permission to create one (bigquery.datasets.create IAM)
-        self.bqclient.create_dataset(self._session_dataset, exists_ok=True)
 
     def close(self):
         """Terminates the BQ session; otherwise the session will be terminated automatically after
diff --git a/tests/system/conftest.py b/tests/system/conftest.py
index 3153bd1559..ed22a3e8da 100644
--- a/tests/system/conftest.py
+++ b/tests/system/conftest.py
@@ -134,15 +134,28 @@ def cleanup_datasets(bigquery_client: bigquery.Client) -> None:
     )
 
 
+def get_dataset_id(project_id: str):
+    """Get a fully qualified dataset id belonging to the given project."""
+    dataset_id = f"{project_id}.{prefixer.create_prefix()}_dataset_id"
+    return dataset_id
+
+
 @pytest.fixture(scope="session")
 def dataset_id(bigquery_client: bigquery.Client):
     """Create (and clean up) a temporary dataset."""
-    project_id = bigquery_client.project
-    dataset_id = f"{project_id}.{prefixer.create_prefix()}_dataset_id"
-    dataset = bigquery.Dataset(dataset_id)
-    bigquery_client.create_dataset(dataset)
+    dataset_id = get_dataset_id(bigquery_client.project)
+    bigquery_client.create_dataset(dataset_id)
+    yield dataset_id
+    bigquery_client.delete_dataset(dataset_id, delete_contents=True)
+
+
+@pytest.fixture
+def dataset_id_not_created(bigquery_client: bigquery.Client):
+    """Return a temporary dataset id without creating the dataset, and clean
+    it up after it has been used."""
+    dataset_id = get_dataset_id(bigquery_client.project)
     yield dataset_id
-    bigquery_client.delete_dataset(dataset, delete_contents=True)
+    bigquery_client.delete_dataset(dataset_id, delete_contents=True)
 
 
 @pytest.fixture(scope="session")
diff --git a/tests/system/large/test_remote_function.py b/tests/system/large/test_remote_function.py
index 2f231f40c9..f270099182 100644
--- a/tests/system/large/test_remote_function.py
+++ b/tests/system/large/test_remote_function.py
@@ -408,6 +408,49 @@ def add_one(x):
     )
 
 
+@pytest.mark.flaky(retries=2, delay=120)
+def test_remote_function_explicit_dataset_not_created(
+    session, scalars_dfs, dataset_id_not_created, bq_cf_connection, functions_client
+):
+    try:
+
+        @session.remote_function(
+            [int],
+            int,
+            dataset_id_not_created,
+            bq_cf_connection,
+            reuse=False,
+        )
+        def square(x):
+            return x * x
+
+        scalars_df, scalars_pandas_df = scalars_dfs
+
+        bf_int64_col = scalars_df["int64_col"]
+        bf_int64_col_filter = bf_int64_col.notnull()
+        bf_int64_col_filtered = bf_int64_col[bf_int64_col_filter]
+        bf_result_col = bf_int64_col_filtered.apply(square)
+        bf_result = (
+            bf_int64_col_filtered.to_frame().assign(result=bf_result_col).to_pandas()
+        )
+
+        pd_int64_col = scalars_pandas_df["int64_col"]
+        pd_int64_col_filter = pd_int64_col.notnull()
+        pd_int64_col_filtered = pd_int64_col[pd_int64_col_filter]
+        pd_result_col = pd_int64_col_filtered.apply(lambda x: x * x)
+        # TODO(shobs): Figure out why pandas .apply() changes the dtype, i.e.
+        # pd_int64_col_filtered.dtype is Int64Dtype()
+        # pd_int64_col_filtered.apply(lambda x: x * x).dtype is int64.
+        # For this test let's force the pandas dtype to be the same as bigframes' dtype.
+        pd_result_col = pd_result_col.astype(pandas.Int64Dtype())
+        pd_result = pd_int64_col_filtered.to_frame().assign(result=pd_result_col)
+
+        assert_pandas_df_equal_ignore_ordering(bf_result, pd_result)
+    finally:
+        # clean up the gcp assets created for the remote function
+        cleanup_remote_function_assets(session.bqclient, functions_client, square)
+
+
 @pytest.mark.flaky(retries=2, delay=120)
 def test_remote_udf_referring_outside_var(
     session, scalars_dfs, dataset_id, bq_cf_connection, functions_client
diff --git a/tests/system/small/test_session.py b/tests/system/small/test_session.py
index 53ddfa3c49..7655325bfc 100644
--- a/tests/system/small/test_session.py
+++ b/tests/system/small/test_session.py
@@ -894,11 +894,6 @@ def test_session_id(session):
     # TODO(chelsealin): Verify the session id can be bound with a load job.
 
 
-def test_session_dataset_exists_and_configured(session: bigframes.Session):
-    dataset = session.bqclient.get_dataset(session._session_dataset_id)
-    assert dataset.default_table_expiration_ms == 24 * 60 * 60 * 1000
-
-
 @pytest.mark.flaky(retries=2)
 def test_to_close_session():
     session = bigframes.Session()
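
Note (not part of the patch): a minimal standalone sketch of the two google-cloud-bigquery behaviors the remote_function.py hunks rely on. The client setup, dataset ids, and location below are illustrative assumptions, not values taken from bigframes:

    from google.api_core.exceptions import NotFound
    from google.cloud import bigquery

    client = bigquery.Client()

    # Create-if-missing, mirroring create_bq_remote_function: exists_ok=True
    # makes create_dataset a no-op when the dataset already exists, so the
    # call is safe to issue before every CREATE FUNCTION DDL statement.
    dataset = bigquery.Dataset(
        bigquery.DatasetReference.from_string(
            "scratch_dataset", default_project=client.project  # hypothetical id
        )
    )
    dataset.location = "US"  # assumed location
    client.create_dataset(dataset, exists_ok=True)

    # list_routines() returns a lazy page iterator, so a missing dataset only
    # raises NotFound once iteration begins; hence get_remote_function_specs
    # wraps the for loop in try/except rather than the list_routines() call.
    try:
        for routine in client.list_routines(f"{client.project}.no_such_dataset"):
            print(routine.reference.routine_id)
    except NotFound:
        pass  # dataset absent, so the routine and its endpoint are absent too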