Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

feat: use default session connection #87

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Oct 11, 2023
4 changes: 3 additions & 1 deletion 4 bigframes/_config/bigquery_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,14 @@ def project(self, value: Optional[str]):

@property
def bq_connection(self) -> Optional[str]:
"""Name of the BigQuery connection to use.
"""Name of the BigQuery connection to use. Should be of the form <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.

You should either have the connection already created in the
<code>location</code> you have chosen, or you should have the Project IAM
Admin role to enable the service to create the connection for you if you
need it.

If this option isn't provided, or project or location aren't provided, session will use its default project/location/connection_id as default connection.
"""
return self._bq_connection

Expand Down
24 changes: 24 additions & 0 deletions 24 bigframes/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
)
logger = logging.getLogger(__name__)

_BIGFRAMES_DEFAULT_CONNECTION_ID = "bigframes-default-connection"


class BqConnectionManager:
"""Manager to handle operations with BQ connections."""
Expand Down Expand Up @@ -162,3 +164,25 @@ def _get_service_account_if_connection_exists(
pass

return service_account


def get_connection_name_full(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be tidier to make it a class method in BqConnectionManager?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the function is only a helper that does string manipulation, whereas BqConnectionManager holds the state of the clients, so I kept them separate.

Copy link
Contributor

@shobsi shobsi Oct 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO it would still make sense. A @classmethod would mean that it does not depend on the instance level state. But conceptually it would fit in nicely - BqConnectionManager sounds like the entity which is supposed to know the low level details of a connection and can provide helper functions about the same.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That could work. Will address in another PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] since it takes a connection and returns a connection after resolving the format, would prefer naming it resolve_full_connection_name

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will address in another PR.

connection_name: Optional[str], default_project: str, default_location: str
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awkward to have optional argument followed by mandatory arguments, would be nice to have it in the end

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional only means it can be value None. It is OK to put it at the front, as it is the most important param of the function.

Not really "optional", which means it has a default value. Then it has to be at the end.

) -> str:
"""Retrieve the full connection name of the form <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
Use default project, location or connection_id when any of them are missing."""
if connection_name is None:
return (
f"{default_project}.{default_location}.{_BIGFRAMES_DEFAULT_CONNECTION_ID}"
)

if connection_name.count(".") == 2:
return connection_name

if connection_name.count(".") == 1:
return f"{default_project}.{connection_name}"

if connection_name.count(".") == 0:
return f"{default_project}.{default_location}.{connection_name}"

raise ValueError(f"Invalid connection name format: {connection_name}.")
23 changes: 19 additions & 4 deletions 23 bigframes/ml/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ class PaLM2TextGenerator(base.Predictor):
session (bigframes.Session or None):
BQ session to create the model. If None, use the global default session.
connection_name (str or None):
connection to connect with remote service. str of the format <PROJECT_NUMBER/PROJECT_ID>.<REGION>.<CONNECTION_NAME>.
if None, use default connection in session context.
connection to connect with remote service. str of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
if None, use default connection in session context. BigQuery DataFrame will try to create the connection and attach
permission if the connection isn't fully setup.
"""

def __init__(
Expand All @@ -48,7 +49,14 @@ def __init__(
connection_name: Optional[str] = None,
):
self.session = session or bpd.get_global_session()
self.connection_name = connection_name or self.session._bq_connection

connection_name = connection_name or self.session._bq_connection
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about this a bit, we are resolving between 3 potential sources of connection:

  1. provided by the user to llm, i.e. connection_name arg
  2. defined by the user in the session, i.e. self.session._bq_connection
  3. the default defined in clients.py

How about we set the default in Session.__init__ like below?

self._bq_connection = context.bq_connection or "bigframes-default-connection"

We are assuming other session defaults there, such as self._location = "US".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yea, I did this way at first place, but then I found remote_function doesn't always have a session object. So I had to put the default away from session. Then we added default session...

We can move it to session.

self.connection_name = clients.get_connection_name_full(
connection_name,
default_project=self.session._project,
default_location=self.session._location,
)

self._bq_connection_manager = clients.BqConnectionManager(
self.session.bqconnectionclient, self.session.resourcemanagerclient
)
Expand Down Expand Up @@ -180,7 +188,14 @@ def __init__(
connection_name: Optional[str] = None,
):
self.session = session or bpd.get_global_session()
self.connection_name = connection_name or self.session._bq_connection

connection_name = connection_name or self.session._bq_connection
self.connection_name = clients.get_connection_name_full(
connection_name,
default_project=self.session._project,
default_location=self.session._location,
)

self._bq_connection_manager = clients.BqConnectionManager(
self.session.bqconnectionclient, self.session.resourcemanagerclient
)
Expand Down
74 changes: 30 additions & 44 deletions 74 bigframes/remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,9 +695,12 @@ def remote_function(
persistent name.

"""
shobsi marked this conversation as resolved.
Show resolved Hide resolved
import bigframes.pandas as bpd

session = session or bpd.get_global_session()

# A BigQuery client is required to perform BQ operations
if not bigquery_client and session:
if not bigquery_client:
bigquery_client = session.bqclient
if not bigquery_client:
raise ValueError(
Expand All @@ -706,7 +709,7 @@ def remote_function(
)

# A BigQuery connection client is required to perform BQ connection operations
if not bigquery_connection_client and session:
if not bigquery_connection_client:
bigquery_connection_client = session.bqconnectionclient
if not bigquery_connection_client:
raise ValueError(
Expand All @@ -716,8 +719,7 @@ def remote_function(

# A cloud functions client is required to perform cloud functions operations
if not cloud_functions_client:
if session:
cloud_functions_client = session.cloudfunctionsclient
cloud_functions_client = session.cloudfunctionsclient
if not cloud_functions_client:
raise ValueError(
"A cloud functions client must be provided, either directly or via session. "
Expand All @@ -726,8 +728,7 @@ def remote_function(

# A resource manager client is required to get/set IAM operations
if not resource_manager_client:
if session:
resource_manager_client = session.resourcemanagerclient
resource_manager_client = session.resourcemanagerclient
if not resource_manager_client:
raise ValueError(
"A resource manager client must be provided, either directly or via session. "
Expand All @@ -740,56 +741,41 @@ def remote_function(
dataset_ref = bigquery.DatasetReference.from_string(
dataset, default_project=bigquery_client.project
)
elif session:
else:
dataset_ref = bigquery.DatasetReference.from_string(
session._session_dataset_id, default_project=bigquery_client.project
)
else:
raise ValueError(
"Project and dataset must be provided, either directly or via session. "
f"{constants.FEEDBACK_LINK}"
)

bq_location, cloud_function_region = get_remote_function_locations(
bigquery_client.location
)

# A connection is required for BQ remote function
# https://cloud.google.com/bigquery/docs/reference/standard-sql/remote-functions#create_a_remote_function
if not bigquery_connection and session:
bigquery_connection = session._bq_connection # type: ignore
if not bigquery_connection:
bigquery_connection = session._bq_connection # type: ignore

bigquery_connection = clients.get_connection_name_full(
bigquery_connection,
default_project=dataset_ref.project,
default_location=bq_location,
)
# Guaranteed to be the form of <project>.<location>.<connection_id>
(
gcp_project_id,
bq_connection_location,
bq_connection_id,
) = bigquery_connection.split(".")
if gcp_project_id.casefold() != dataset_ref.project.casefold():
raise ValueError(
"BigQuery connection must be provided, either directly or via session. "
f"{constants.FEEDBACK_LINK}"
"The project_id does not match BigQuery connection gcp_project_id: "
f"{dataset_ref.project}."
)
if bq_connection_location.casefold() != bq_location.casefold():
raise ValueError(
"The location does not match BigQuery connection location: "
f"{bq_location}."
)

# Check connection_id with `LOCATION.CONNECTION_ID` or `PROJECT_ID.LOCATION.CONNECTION_ID` format.
if bigquery_connection.count(".") == 1:
bq_connection_location, bq_connection_id = bigquery_connection.split(".")
if bq_connection_location.casefold() != bq_location.casefold():
raise ValueError(
"The location does not match BigQuery connection location: "
f"{bq_location}."
)
bigquery_connection = bq_connection_id
elif bigquery_connection.count(".") == 2:
(
gcp_project_id,
bq_connection_location,
bq_connection_id,
) = bigquery_connection.split(".")
if gcp_project_id.casefold() != dataset_ref.project.casefold():
raise ValueError(
"The project_id does not match BigQuery connection gcp_project_id: "
f"{dataset_ref.project}."
)
if bq_connection_location.casefold() != bq_location.casefold():
raise ValueError(
"The location does not match BigQuery connection location: "
f"{bq_location}."
)
bigquery_connection = bq_connection_id

def wrapper(f):
if not callable(f):
Expand All @@ -808,7 +794,7 @@ def wrapper(f):
dataset_ref.dataset_id,
bigquery_client,
bigquery_connection_client,
bigquery_connection,
bq_connection_id,
resource_manager_client,
)

Expand Down
6 changes: 5 additions & 1 deletion 6 bigframes/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,10 +350,14 @@ def resourcemanagerclient(self):
@property
def _session_dataset_id(self):
"""A dataset for storing temporary objects local to the session
This is a workaround for BQML models and remote functions that do not
This is a workaround for remote functions that do not
yet support session-temporary instances."""
return self._session_dataset.dataset_id

@property
def _project(self):
    """Return the ``project`` attribute of this object's ``bqclient``."""
    bq_client = self.bqclient
    return bq_client.project

def _create_and_bind_bq_session(self):
"""Create a BQ session and bind the session id with clients to capture BQ activities:
go/bigframes-transient-data"""
Expand Down
37 changes: 36 additions & 1 deletion 37 tests/system/small/ml/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def test_create_text_generator_model(palm2_text_generator_model):
assert palm2_text_generator_model._bqml_model is not None


def test_create_text_generator_model_defaults(bq_connection):
@pytest.mark.flaky(retries=2, delay=120)
def test_create_text_generator_model_default_session(bq_connection, llm_text_pandas_df):
import bigframes.pandas as bpd

bpd.reset_session()
Expand All @@ -36,6 +37,40 @@ def test_create_text_generator_model_defaults(bq_connection):
model = llm.PaLM2TextGenerator()
assert model is not None
assert model._bqml_model is not None
assert model.connection_name.casefold() == "bigframes-dev.us.bigframes-rf-conn"

llm_text_df = bpd.read_pandas(llm_text_pandas_df)

df = model.predict(llm_text_df).to_pandas()
TestCase().assertSequenceEqual(df.shape, (3, 1))
assert "ml_generate_text_llm_result" in df.columns
series = df["ml_generate_text_llm_result"]
assert all(series.str.len() > 20)


@pytest.mark.flaky(retries=2, delay=120)
def test_create_text_generator_model_default_connection(llm_text_pandas_df):
    """Check that a text generator built with reset options reports the default connection.

    The global session and the config options are reset first, so no
    connection is configured by the test itself before the model is created.
    """
    from bigframes import _config
    import bigframes.pandas as bpd

    expected_connection = "bigframes-dev.us.bigframes-default-connection"
    result_column = "ml_generate_text_llm_result"

    bpd.reset_session()
    _config.options = _config.Options()  # reset configs

    # Kept before model construction, exactly as in the original ordering.
    input_df = bpd.read_pandas(llm_text_pandas_df)

    text_generator = llm.PaLM2TextGenerator()
    assert text_generator is not None
    assert text_generator._bqml_model is not None
    assert text_generator.connection_name.casefold() == expected_connection

    predictions = text_generator.predict(input_df).to_pandas()
    TestCase().assertSequenceEqual(predictions.shape, (3, 1))
    assert result_column in predictions.columns
    generated_text = predictions[result_column]
    assert all(generated_text.str.len() > 20)


# Marked as flaky only because BQML LLM is in preview, the service only has limited capacity, not stable enough.
Expand Down
31 changes: 31 additions & 0 deletions 31 tests/system/small/test_remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import bigframes
from bigframes import remote_function as rf
import bigframes.pandas as bpd
from tests.system.utils import assert_pandas_df_equal_ignore_ordering


Expand Down Expand Up @@ -465,6 +466,36 @@ def square(x):
assert_pandas_df_equal_ignore_ordering(bf_result, pd_result)


@pytest.mark.flaky(retries=2, delay=120)
def test_remote_function_default_connection(scalars_dfs, dataset_id):
    """Apply a remote function declared without an explicit connection or session.

    The result of applying the remote ``square`` to the non-null ``int64_col``
    values is compared with the same computation done in pandas.
    """

    @bpd.remote_function([int], int, dataset=dataset_id)
    def square(x):
        return x * x

    bf_df, pandas_df = scalars_dfs

    # BigQuery DataFrames side: drop nulls, then apply the remote function.
    bf_column = bf_df["int64_col"]
    bf_non_null = bf_column[bf_column.notnull()]
    bf_squared = bf_non_null.apply(square)
    bf_result = bf_non_null.to_frame().assign(result=bf_squared).to_pandas()

    # pandas side: the same filter and an equivalent local lambda.
    pd_column = pandas_df["int64_col"]
    pd_non_null = pd_column[pd_column.notnull()]
    pd_squared = pd_non_null.apply(lambda x: x * x)
    # TODO(shobs): Figure why pandas .apply() changes the dtype, i.e.
    # pd_int64_col_filtered.dtype is Int64Dtype()
    # pd_int64_col_filtered.apply(lambda x: x * x).dtype is int64.
    # For this test let's force the pandas dtype to be same as bigframes' dtype.
    pd_squared = pd_squared.astype(pd.Int64Dtype())
    pd_result = pd_non_null.to_frame().assign(result=pd_squared)

    assert_pandas_df_equal_ignore_ordering(bf_result, pd_result)


@pytest.mark.flaky(retries=2, delay=120)
def test_dataframe_applymap(session_with_bq_connection, scalars_dfs):
def add_one(x):
Expand Down
57 changes: 57 additions & 0 deletions 57 tests/unit/test_clients.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from bigframes import clients


def test_get_connection_name_full_none():
    """A ``None`` name resolves to the default connection id under the default project and location."""
    expected = "default-project.us.bigframes-default-connection"
    resolved = clients.get_connection_name_full(
        None,
        default_project="default-project",
        default_location="us",
    )
    assert resolved == expected


def test_get_connection_name_full_connection_id():
    """A bare connection id is prefixed with the default project and location."""
    expected = "default-project.us.connection-id"
    resolved = clients.get_connection_name_full(
        "connection-id",
        default_project="default-project",
        default_location="us",
    )
    assert resolved == expected


def test_get_connection_name_full_location_connection_id():
    """A ``<location>.<connection_id>`` name keeps its location and gains the default project."""
    expected = "default-project.eu.connection-id"
    resolved = clients.get_connection_name_full(
        "eu.connection-id",
        default_project="default-project",
        default_location="us",
    )
    assert resolved == expected


def test_get_connection_name_full_all():
    """A fully qualified name is returned unchanged; the defaults are ignored."""
    fully_qualified = "my-project.eu.connection-id"
    resolved = clients.get_connection_name_full(
        fully_qualified,
        default_project="default-project",
        default_location="us",
    )
    assert resolved == "my-project.eu.connection-id"


def test_get_connection_name_full_raise_value_error():
    """A name with more than three dot-separated fields is rejected with ``ValueError``."""
    malformed_name = "my-project.eu.connection-id.extra_field"
    with pytest.raises(ValueError):
        clients.get_connection_name_full(
            malformed_name,
            default_project="default-project",
            default_location="us",
        )
Morty Proxy This is a proxified and sanitized view of the page, visit original site.