Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

chore: address last comments of PR#87 #102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 17 additions & 24 deletions 41 bigframes/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
)
logger = logging.getLogger(__name__)

_BIGFRAMES_DEFAULT_CONNECTION_ID = "bigframes-default-connection"


class BqConnectionManager:
"""Manager to handle operations with BQ connections."""
Expand All @@ -46,6 +44,23 @@ def __init__(
self._bq_connection_client = bq_connection_client
self._cloud_resource_manager_client = cloud_resource_manager_client

@classmethod
def resolve_full_connection_name(
cls, connection_name: str, default_project: str, default_location: str
) -> str:
"""Retrieve the full connection name of the form <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
Use default project, location or connection_id when any of them are missing."""
if connection_name.count(".") == 2:
return connection_name

if connection_name.count(".") == 1:
return f"{default_project}.{connection_name}"

if connection_name.count(".") == 0:
return f"{default_project}.{default_location}.{connection_name}"

raise ValueError(f"Invalid connection name format: {connection_name}.")

def create_bq_connection(
self, project_id: str, location: str, connection_id: str, iam_role: str
):
Expand Down Expand Up @@ -164,25 +179,3 @@ def _get_service_account_if_connection_exists(
pass

return service_account


def get_connection_name_full(
connection_name: Optional[str], default_project: str, default_location: str
) -> str:
"""Retrieve the full connection name of the form <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
Use default project, location or connection_id when any of them are missing."""
if connection_name is None:
return (
f"{default_project}.{default_location}.{_BIGFRAMES_DEFAULT_CONNECTION_ID}"
)

if connection_name.count(".") == 2:
return connection_name

if connection_name.count(".") == 1:
return f"{default_project}.{connection_name}"

if connection_name.count(".") == 0:
return f"{default_project}.{default_location}.{connection_name}"

raise ValueError(f"Invalid connection name format: {connection_name}.")
16 changes: 8 additions & 8 deletions 16 bigframes/ml/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,17 @@ def __init__(
connection_name: Optional[str] = None,
):
self.session = session or bpd.get_global_session()
self._bq_connection_manager = clients.BqConnectionManager(
self.session.bqconnectionclient, self.session.resourcemanagerclient
)

connection_name = connection_name or self.session._bq_connection
self.connection_name = clients.get_connection_name_full(
self.connection_name = self._bq_connection_manager.resolve_full_connection_name(
connection_name,
default_project=self.session._project,
default_location=self.session._location,
)

self._bq_connection_manager = clients.BqConnectionManager(
self.session.bqconnectionclient, self.session.resourcemanagerclient
)
self._bqml_model_factory = globals.bqml_model_factory()
self._bqml_model: core.BqmlModel = self._create_bqml_model()

Expand Down Expand Up @@ -188,17 +188,17 @@ def __init__(
connection_name: Optional[str] = None,
):
self.session = session or bpd.get_global_session()
self._bq_connection_manager = clients.BqConnectionManager(
self.session.bqconnectionclient, self.session.resourcemanagerclient
)

connection_name = connection_name or self.session._bq_connection
self.connection_name = clients.get_connection_name_full(
self.connection_name = self._bq_connection_manager.resolve_full_connection_name(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea to keep it classmethod was to use BqConnectionManager.resolve_full_connection_name(...) directly. Using via instance feels a bit weird, but it works.

connection_name,
default_project=self.session._project,
default_location=self.session._location,
)

self._bq_connection_manager = clients.BqConnectionManager(
self.session.bqconnectionclient, self.session.resourcemanagerclient
)
self._bqml_model_factory = globals.bqml_model_factory()
self._bqml_model: core.BqmlModel = self._create_bqml_model()

Expand Down
2 changes: 1 addition & 1 deletion 2 bigframes/remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ def remote_function(
if not bigquery_connection:
bigquery_connection = session._bq_connection # type: ignore

bigquery_connection = clients.get_connection_name_full(
bigquery_connection = clients.BqConnectionManager.resolve_full_connection_name(
bigquery_connection,
default_project=dataset_ref.project,
default_location=bq_location,
Expand Down
4 changes: 3 additions & 1 deletion 4 bigframes/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@
_BIGQUERYCONNECTION_REGIONAL_ENDPOINT = "{location}-bigqueryconnection.googleapis.com"
_BIGQUERYSTORAGE_REGIONAL_ENDPOINT = "{location}-bigquerystorage.googleapis.com"

_BIGFRAMES_DEFAULT_CONNECTION_ID = "bigframes-default-connection"

_MAX_CLUSTER_COLUMNS = 4

# TODO(swast): Need to connect to regional endpoints when performing remote
Expand Down Expand Up @@ -321,7 +323,7 @@ def __init__(
),
)

self._bq_connection = context.bq_connection
self._bq_connection = context.bq_connection or _BIGFRAMES_DEFAULT_CONNECTION_ID

# Now that we're starting the session, don't allow the options to be
# changed.
Expand Down
16 changes: 4 additions & 12 deletions 16 tests/unit/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,22 @@
from bigframes import clients


def test_get_connection_name_full_none():
connection_name = clients.get_connection_name_full(
None, default_project="default-project", default_location="us"
)
assert connection_name == "default-project.us.bigframes-default-connection"


def test_get_connection_name_full_connection_id():
connection_name = clients.get_connection_name_full(
connection_name = clients.BqConnectionManager.resolve_full_connection_name(
"connection-id", default_project="default-project", default_location="us"
)
assert connection_name == "default-project.us.connection-id"


def test_get_connection_name_full_location_connection_id():
connection_name = clients.get_connection_name_full(
connection_name = clients.BqConnectionManager.resolve_full_connection_name(
"eu.connection-id", default_project="default-project", default_location="us"
)
assert connection_name == "default-project.eu.connection-id"


def test_get_connection_name_full_all():
connection_name = clients.get_connection_name_full(
connection_name = clients.BqConnectionManager.resolve_full_connection_name(
"my-project.eu.connection-id",
default_project="default-project",
default_location="us",
Expand All @@ -48,9 +41,8 @@ def test_get_connection_name_full_all():


def test_get_connection_name_full_raise_value_error():

with pytest.raises(ValueError):
clients.get_connection_name_full(
clients.BqConnectionManager.resolve_full_connection_name(
"my-project.eu.connection-id.extra_field",
default_project="default-project",
default_location="us",
Expand Down
Morty Proxy This is a proxified and sanitized view of the page, visit original site.