diff --git a/bigframes/clients.py b/bigframes/clients.py index dcac611e8c..4ba9d93d69 100644 --- a/bigframes/clients.py +++ b/bigframes/clients.py @@ -29,8 +29,6 @@ ) logger = logging.getLogger(__name__) -_BIGFRAMES_DEFAULT_CONNECTION_ID = "bigframes-default-connection" - class BqConnectionManager: """Manager to handle operations with BQ connections.""" @@ -46,6 +44,23 @@ def __init__( self._bq_connection_client = bq_connection_client self._cloud_resource_manager_client = cloud_resource_manager_client + @classmethod + def resolve_full_connection_name( + cls, connection_name: str, default_project: str, default_location: str + ) -> str: + """Retrieve the full connection name of the form PROJECT.LOCATION.CONNECTION_ID. + Use default project, location or connection_id when any of them are missing.""" + if connection_name.count(".") == 2: + return connection_name + + if connection_name.count(".") == 1: + return f"{default_project}.{connection_name}" + + if connection_name.count(".") == 0: + return f"{default_project}.{default_location}.{connection_name}" + + raise ValueError(f"Invalid connection name format: {connection_name}.") + def create_bq_connection( self, project_id: str, location: str, connection_id: str, iam_role: str ): @@ -164,25 +179,3 @@ def _get_service_account_if_connection_exists( pass return service_account - - -def get_connection_name_full( - connection_name: Optional[str], default_project: str, default_location: str -) -> str: - """Retrieve the full connection name of the form PROJECT.LOCATION.CONNECTION_ID.
- Use default project, location or connection_id when any of them are missing.""" - if connection_name is None: - return ( - f"{default_project}.{default_location}.{_BIGFRAMES_DEFAULT_CONNECTION_ID}" - ) - - if connection_name.count(".") == 2: - return connection_name - - if connection_name.count(".") == 1: - return f"{default_project}.{connection_name}" - - if connection_name.count(".") == 0: - return f"{default_project}.{default_location}.{connection_name}" - - raise ValueError(f"Invalid connection name format: {connection_name}.") diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index a61dd34e6d..d78f467537 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -49,17 +49,17 @@ def __init__( connection_name: Optional[str] = None, ): self.session = session or bpd.get_global_session() + self._bq_connection_manager = clients.BqConnectionManager( + self.session.bqconnectionclient, self.session.resourcemanagerclient + ) connection_name = connection_name or self.session._bq_connection - self.connection_name = clients.get_connection_name_full( + self.connection_name = self._bq_connection_manager.resolve_full_connection_name( connection_name, default_project=self.session._project, default_location=self.session._location, ) - self._bq_connection_manager = clients.BqConnectionManager( - self.session.bqconnectionclient, self.session.resourcemanagerclient - ) self._bqml_model_factory = globals.bqml_model_factory() self._bqml_model: core.BqmlModel = self._create_bqml_model() @@ -188,17 +188,17 @@ def __init__( connection_name: Optional[str] = None, ): self.session = session or bpd.get_global_session() + self._bq_connection_manager = clients.BqConnectionManager( + self.session.bqconnectionclient, self.session.resourcemanagerclient + ) connection_name = connection_name or self.session._bq_connection - self.connection_name = clients.get_connection_name_full( + self.connection_name = self._bq_connection_manager.resolve_full_connection_name( connection_name, 
default_project=self.session._project, default_location=self.session._location, ) - self._bq_connection_manager = clients.BqConnectionManager( - self.session.bqconnectionclient, self.session.resourcemanagerclient - ) self._bqml_model_factory = globals.bqml_model_factory() self._bqml_model: core.BqmlModel = self._create_bqml_model() diff --git a/bigframes/remote_function.py b/bigframes/remote_function.py index 81ba26600b..fd9aec825f 100644 --- a/bigframes/remote_function.py +++ b/bigframes/remote_function.py @@ -772,7 +772,7 @@ def remote_function( if not bigquery_connection: bigquery_connection = session._bq_connection # type: ignore - bigquery_connection = clients.get_connection_name_full( + bigquery_connection = clients.BqConnectionManager.resolve_full_connection_name( bigquery_connection, default_project=dataset_ref.project, default_location=bq_location, diff --git a/bigframes/session.py b/bigframes/session.py index 6ad65000ce..4f509f0704 100644 --- a/bigframes/session.py +++ b/bigframes/session.py @@ -97,6 +97,8 @@ _BIGQUERYCONNECTION_REGIONAL_ENDPOINT = "{location}-bigqueryconnection.googleapis.com" _BIGQUERYSTORAGE_REGIONAL_ENDPOINT = "{location}-bigquerystorage.googleapis.com" +_BIGFRAMES_DEFAULT_CONNECTION_ID = "bigframes-default-connection" + _MAX_CLUSTER_COLUMNS = 4 # TODO(swast): Need to connect to regional endpoints when performing remote @@ -321,7 +323,7 @@ def __init__( ), ) - self._bq_connection = context.bq_connection + self._bq_connection = context.bq_connection or _BIGFRAMES_DEFAULT_CONNECTION_ID # Now that we're starting the session, don't allow the options to be # changed. 
diff --git a/tests/unit/test_clients.py b/tests/unit/test_clients.py index a90e5b0320..f89cc21397 100644 --- a/tests/unit/test_clients.py +++ b/tests/unit/test_clients.py @@ -17,29 +17,22 @@ from bigframes import clients -def test_get_connection_name_full_none(): - connection_name = clients.get_connection_name_full( - None, default_project="default-project", default_location="us" - ) - assert connection_name == "default-project.us.bigframes-default-connection" - - def test_get_connection_name_full_connection_id(): - connection_name = clients.get_connection_name_full( + connection_name = clients.BqConnectionManager.resolve_full_connection_name( "connection-id", default_project="default-project", default_location="us" ) assert connection_name == "default-project.us.connection-id" def test_get_connection_name_full_location_connection_id(): - connection_name = clients.get_connection_name_full( + connection_name = clients.BqConnectionManager.resolve_full_connection_name( "eu.connection-id", default_project="default-project", default_location="us" ) assert connection_name == "default-project.eu.connection-id" def test_get_connection_name_full_all(): - connection_name = clients.get_connection_name_full( + connection_name = clients.BqConnectionManager.resolve_full_connection_name( "my-project.eu.connection-id", default_project="default-project", default_location="us", @@ -48,9 +41,8 @@ def test_get_connection_name_full_all(): def test_get_connection_name_full_raise_value_error(): - with pytest.raises(ValueError): - clients.get_connection_name_full( + clients.BqConnectionManager.resolve_full_connection_name( "my-project.eu.connection-id.extra_field", default_project="default-project", default_location="us",