diff --git a/bigframes/_config/bigquery_options.py b/bigframes/_config/bigquery_options.py index ea1864ed5f..eb56de826a 100644 --- a/bigframes/_config/bigquery_options.py +++ b/bigframes/_config/bigquery_options.py @@ -83,12 +83,14 @@ def project(self, value: Optional[str]): @property def bq_connection(self) -> Optional[str]: - """Name of the BigQuery connection to use. + """Name of the BigQuery connection to use. Should be of the form ... You should either have the connection already created in the location you have chosen, or you should have the Project IAM Admin role to enable the service to create the connection for you if you need it. + + If this option isn't provided, or project or location aren't provided, session will use its default project/location/connection_id as default connection. """ return self._bq_connection diff --git a/bigframes/clients.py b/bigframes/clients.py index b60fcba04a..dcac611e8c 100644 --- a/bigframes/clients.py +++ b/bigframes/clients.py @@ -29,6 +29,8 @@ ) logger = logging.getLogger(__name__) +_BIGFRAMES_DEFAULT_CONNECTION_ID = "bigframes-default-connection" + class BqConnectionManager: """Manager to handle operations with BQ connections.""" @@ -162,3 +164,25 @@ def _get_service_account_if_connection_exists( pass return service_account + + +def get_connection_name_full( + connection_name: Optional[str], default_project: str, default_location: str +) -> str: + """Retrieve the full connection name of the form ... + Use default project, location or connection_id when any of them are missing.""" + if connection_name is None: + return ( + f"{default_project}.{default_location}.{_BIGFRAMES_DEFAULT_CONNECTION_ID}" + ) + + if connection_name.count(".") == 2: + return connection_name + + if connection_name.count(".") == 1: + return f"{default_project}.{connection_name}" + + if connection_name.count(".") == 0: + return f"{default_project}.{default_location}.{connection_name}" + + raise ValueError(f"Invalid connection name format: {connection_name}.") diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index c86e5fb3b6..a61dd34e6d 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -38,8 +38,9 @@ class PaLM2TextGenerator(base.Predictor): session (bigframes.Session or None): BQ session to create the model. If None, use the global default session. connection_name (str or None): - connection to connect with remote service. str of the format ... - if None, use default connection in session context. + connection to connect with remote service. str of the format ... + if None, use default connection in session context. BigQuery DataFrame will try to create the connection and attach + permission if the connection isn't fully setup. """ def __init__( @@ -48,7 +49,14 @@ def __init__( connection_name: Optional[str] = None, ): self.session = session or bpd.get_global_session() - self.connection_name = connection_name or self.session._bq_connection + + connection_name = connection_name or self.session._bq_connection + self.connection_name = clients.get_connection_name_full( + connection_name, + default_project=self.session._project, + default_location=self.session._location, + ) + self._bq_connection_manager = clients.BqConnectionManager( self.session.bqconnectionclient, self.session.resourcemanagerclient ) @@ -180,7 +188,14 @@ def __init__( connection_name: Optional[str] = None, ): self.session = session or bpd.get_global_session() - self.connection_name = connection_name or self.session._bq_connection + + connection_name = connection_name or self.session._bq_connection + self.connection_name = clients.get_connection_name_full( + connection_name, + default_project=self.session._project, + default_location=self.session._location, + ) + self._bq_connection_manager = clients.BqConnectionManager( self.session.bqconnectionclient, self.session.resourcemanagerclient ) diff --git a/bigframes/remote_function.py b/bigframes/remote_function.py index 6fc2f8e59f..37c7a2fc64 100644 --- a/bigframes/remote_function.py +++ b/bigframes/remote_function.py @@ -695,9 +695,12 @@ def remote_function( persistent name. """ + import bigframes.pandas as bpd + + session = session or bpd.get_global_session() # A BigQuery client is required to perform BQ operations - if not bigquery_client and session: + if not bigquery_client: bigquery_client = session.bqclient if not bigquery_client: raise ValueError( @@ -706,7 +709,7 @@ def remote_function( ) # A BigQuery connection client is required to perform BQ connection operations - if not bigquery_connection_client and session: + if not bigquery_connection_client: bigquery_connection_client = session.bqconnectionclient if not bigquery_connection_client: raise ValueError( @@ -716,8 +719,7 @@ def remote_function( # A cloud functions client is required to perform cloud functions operations if not cloud_functions_client: - if session: - cloud_functions_client = session.cloudfunctionsclient + cloud_functions_client = session.cloudfunctionsclient if not cloud_functions_client: raise ValueError( "A cloud functions client must be provided, either directly or via session. " @@ -726,8 +728,7 @@ def remote_function( # A resource manager client is required to get/set IAM operations if not resource_manager_client: - if session: - resource_manager_client = session.resourcemanagerclient + resource_manager_client = session.resourcemanagerclient if not resource_manager_client: raise ValueError( "A resource manager client must be provided, either directly or via session. " @@ -740,15 +741,10 @@ def remote_function( dataset_ref = bigquery.DatasetReference.from_string( dataset, default_project=bigquery_client.project ) - elif session: + else: dataset_ref = bigquery.DatasetReference.from_string( session._session_dataset_id, default_project=bigquery_client.project ) - else: - raise ValueError( - "Project and dataset must be provided, either directly or via session. " - f"{constants.FEEDBACK_LINK}" - ) bq_location, cloud_function_region = get_remote_function_locations( bigquery_client.location @@ -756,40 +752,30 @@ def remote_function( # A connection is required for BQ remote function # https://cloud.google.com/bigquery/docs/reference/standard-sql/remote-functions#create_a_remote_function - if not bigquery_connection and session: - bigquery_connection = session._bq_connection # type: ignore if not bigquery_connection: + bigquery_connection = session._bq_connection # type: ignore + + bigquery_connection = clients.get_connection_name_full( + bigquery_connection, + default_project=dataset_ref.project, + default_location=bq_location, + ) + # Guaranteed to be the form of .. + ( + gcp_project_id, + bq_connection_location, + bq_connection_id, + ) = bigquery_connection.split(".") + if gcp_project_id.casefold() != dataset_ref.project.casefold(): raise ValueError( - "BigQuery connection must be provided, either directly or via session. " - f"{constants.FEEDBACK_LINK}" + "The project_id does not match BigQuery connection gcp_project_id: " + f"{dataset_ref.project}." + ) + if bq_connection_location.casefold() != bq_location.casefold(): + raise ValueError( + "The location does not match BigQuery connection location: " + f"{bq_location}." ) - - # Check connection_id with `LOCATION.CONNECTION_ID` or `PROJECT_ID.LOCATION.CONNECTION_ID` format. - if bigquery_connection.count(".") == 1: - bq_connection_location, bq_connection_id = bigquery_connection.split(".") - if bq_connection_location.casefold() != bq_location.casefold(): - raise ValueError( - "The location does not match BigQuery connection location: " - f"{bq_location}." - ) - bigquery_connection = bq_connection_id - elif bigquery_connection.count(".") == 2: - ( - gcp_project_id, - bq_connection_location, - bq_connection_id, - ) = bigquery_connection.split(".") - if gcp_project_id.casefold() != dataset_ref.project.casefold(): - raise ValueError( - "The project_id does not match BigQuery connection gcp_project_id: " - f"{dataset_ref.project}." - ) - if bq_connection_location.casefold() != bq_location.casefold(): - raise ValueError( - "The location does not match BigQuery connection location: " - f"{bq_location}." - ) - bigquery_connection = bq_connection_id def wrapper(f): if not callable(f): @@ -808,7 +794,7 @@ def wrapper(f): dataset_ref.dataset_id, bigquery_client, bigquery_connection_client, - bigquery_connection, + bq_connection_id, resource_manager_client, ) diff --git a/bigframes/session.py b/bigframes/session.py index ac48c977cb..a7cb78e3ff 100644 --- a/bigframes/session.py +++ b/bigframes/session.py @@ -350,10 +350,14 @@ def resourcemanagerclient(self): @property def _session_dataset_id(self): """A dataset for storing temporary objects local to the session - This is a workaround for BQML models and remote functions that do not + This is a workaround for remote functions that do not yet support session-temporary instances.""" return self._session_dataset.dataset_id + @property + def _project(self): + return self.bqclient.project + def _create_and_bind_bq_session(self): """Create a BQ session and bind the session id with clients to capture BQ activities: go/bigframes-transient-data""" diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py index 7486277487..e546c09f97 100644 --- a/tests/system/small/ml/test_llm.py +++ b/tests/system/small/ml/test_llm.py @@ -26,7 +26,8 @@ def test_create_text_generator_model(palm2_text_generator_model): assert palm2_text_generator_model._bqml_model is not None -def test_create_text_generator_model_defaults(bq_connection): +@pytest.mark.flaky(retries=2, delay=120) +def test_create_text_generator_model_default_session(bq_connection, llm_text_pandas_df): import bigframes.pandas as bpd bpd.reset_session() @@ -36,6 +37,40 @@ def test_create_text_generator_model_defaults(bq_connection): model = llm.PaLM2TextGenerator() assert model is not None assert model._bqml_model is not None + assert model.connection_name.casefold() == "bigframes-dev.us.bigframes-rf-conn" + + llm_text_df = bpd.read_pandas(llm_text_pandas_df) + + df = model.predict(llm_text_df).to_pandas() + TestCase().assertSequenceEqual(df.shape, (3, 1)) + assert "ml_generate_text_llm_result" in df.columns + series = df["ml_generate_text_llm_result"] + assert all(series.str.len() > 20) + + +@pytest.mark.flaky(retries=2, delay=120) +def test_create_text_generator_model_default_connection(llm_text_pandas_df): + from bigframes import _config + import bigframes.pandas as bpd + + bpd.reset_session() + _config.options = _config.Options() # reset configs + + llm_text_df = bpd.read_pandas(llm_text_pandas_df) + + model = llm.PaLM2TextGenerator() + assert model is not None + assert model._bqml_model is not None + assert ( + model.connection_name.casefold() + == "bigframes-dev.us.bigframes-default-connection" + ) + + df = model.predict(llm_text_df).to_pandas() + TestCase().assertSequenceEqual(df.shape, (3, 1)) + assert "ml_generate_text_llm_result" in df.columns + series = df["ml_generate_text_llm_result"] + assert all(series.str.len() > 20) # Marked as flaky only because BQML LLM is in preview, the service only has limited capacity, not stable enough. diff --git a/tests/system/small/test_remote_function.py b/tests/system/small/test_remote_function.py index 77fb81d2c9..d024a57ded 100644 --- a/tests/system/small/test_remote_function.py +++ b/tests/system/small/test_remote_function.py @@ -20,6 +20,7 @@ import bigframes from bigframes import remote_function as rf +import bigframes.pandas as bpd from tests.system.utils import assert_pandas_df_equal_ignore_ordering @@ -465,6 +466,36 @@ def square(x): assert_pandas_df_equal_ignore_ordering(bf_result, pd_result) +@pytest.mark.flaky(retries=2, delay=120) +def test_remote_function_default_connection(scalars_dfs, dataset_id): + @bpd.remote_function([int], int, dataset=dataset_id) + def square(x): + return x * x + + scalars_df, scalars_pandas_df = scalars_dfs + + bf_int64_col = scalars_df["int64_col"] + bf_int64_col_filter = bf_int64_col.notnull() + bf_int64_col_filtered = bf_int64_col[bf_int64_col_filter] + bf_result_col = bf_int64_col_filtered.apply(square) + bf_result = ( + bf_int64_col_filtered.to_frame().assign(result=bf_result_col).to_pandas() + ) + + pd_int64_col = scalars_pandas_df["int64_col"] + pd_int64_col_filter = pd_int64_col.notnull() + pd_int64_col_filtered = pd_int64_col[pd_int64_col_filter] + pd_result_col = pd_int64_col_filtered.apply(lambda x: x * x) + # TODO(shobs): Figure why pandas .apply() changes the dtype, i.e. + # pd_int64_col_filtered.dtype is Int64Dtype() + # pd_int64_col_filtered.apply(lambda x: x * x).dtype is int64. + # For this test let's force the pandas dtype to be same as bigframes' dtype. + pd_result_col = pd_result_col.astype(pd.Int64Dtype()) + pd_result = pd_int64_col_filtered.to_frame().assign(result=pd_result_col) + + assert_pandas_df_equal_ignore_ordering(bf_result, pd_result) + + @pytest.mark.flaky(retries=2, delay=120) def test_dataframe_applymap(session_with_bq_connection, scalars_dfs): def add_one(x): diff --git a/tests/unit/test_clients.py b/tests/unit/test_clients.py new file mode 100644 index 0000000000..a90e5b0320 --- /dev/null +++ b/tests/unit/test_clients.py @@ -0,0 +1,57 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from bigframes import clients + + +def test_get_connection_name_full_none(): + connection_name = clients.get_connection_name_full( + None, default_project="default-project", default_location="us" + ) + assert connection_name == "default-project.us.bigframes-default-connection" + + +def test_get_connection_name_full_connection_id(): + connection_name = clients.get_connection_name_full( + "connection-id", default_project="default-project", default_location="us" + ) + assert connection_name == "default-project.us.connection-id" + + +def test_get_connection_name_full_location_connection_id(): + connection_name = clients.get_connection_name_full( + "eu.connection-id", default_project="default-project", default_location="us" + ) + assert connection_name == "default-project.eu.connection-id" + + +def test_get_connection_name_full_all(): + connection_name = clients.get_connection_name_full( + "my-project.eu.connection-id", + default_project="default-project", + default_location="us", + ) + assert connection_name == "my-project.eu.connection-id" + + +def test_get_connection_name_full_raise_value_error(): + + with pytest.raises(ValueError): + clients.get_connection_name_full( + "my-project.eu.connection-id.extra_field", + default_project="default-project", + default_location="us", + )