From 4b562bc51c602e607a83482bdcf4f5d3046317b9 Mon Sep 17 00:00:00 2001 From: Paul Rhodes Date: Thu, 10 Apr 2025 13:13:02 +0100 Subject: [PATCH 1/2] feat: Allow jobs to be run in a different project --- sqlalchemy_bigquery/base.py | 48 +++++--- .../system/test_sqlalchemy_bigquery_remote.py | 107 ++++++++++++++++++ tests/unit/fauxdbi.py | 21 ++-- tests/unit/test_engine.py | 6 + 4 files changed, 157 insertions(+), 25 deletions(-) create mode 100644 tests/system/test_sqlalchemy_bigquery_remote.py diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index 0204bc92..4008a7e1 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -27,7 +27,7 @@ from google import auth import google.api_core.exceptions -from google.cloud.bigquery import dbapi +from google.cloud.bigquery import dbapi, ConnectionProperty from google.cloud.bigquery.table import ( RangePartitioning, TableReference, @@ -61,6 +61,7 @@ from .parse_url import parse_url from . import _helpers, _struct, _types import sqlalchemy_bigquery_vendored.sqlalchemy.postgresql.base as vendored_postgresql +from google.cloud.bigquery import QueryJobConfig # Illegal characters is intended to be all characters that are not explicitly # allowed as part of the flexible column names. @@ -1080,6 +1081,7 @@ def __init__( self, arraysize=5000, credentials_path=None, + billing_project_id=None, location=None, credentials_info=None, credentials_base64=None, @@ -1092,6 +1094,8 @@ def __init__( self.credentials_path = credentials_path self.credentials_info = credentials_info self.credentials_base64 = credentials_base64 + self.project_id = None + self.billing_project_id = billing_project_id self.location = location self.identifier_preparer = self.preparer(self) self.dataset_id = None @@ -1114,15 +1118,20 @@ def _build_formatted_table_id(table): """Build '.' string using given table.""" return "{}.{}".format(table.reference.dataset_id, table.table_id) - @staticmethod - def _add_default_dataset_to_job_config(job_config, project_id, dataset_id): - # If dataset_id is set, then we know the job_config isn't None - if dataset_id: - # If project_id is missing, use default project_id for the current environment + def create_job_config(self, provided_config: QueryJobConfig): + project_id = self.project_id + if self.dataset_id is None and project_id == self.billing_project_id: + return provided_config + job_config = provided_config or QueryJobConfig() + if project_id != self.billing_project_id: + job_config.connection_properties = [ + ConnectionProperty(key="dataset_project_id", value=project_id) + ] + if self.dataset_id: if not project_id: _, project_id = auth.default() - - job_config.default_dataset = "{}.{}".format(project_id, dataset_id) + job_config.default_dataset = "{}.{}".format(project_id, self.dataset_id) + return job_config def do_execute(self, cursor, statement, parameters, context=None): kwargs = {} @@ -1132,13 +1141,13 @@ def do_execute(self, cursor, statement, parameters, context=None): def create_connect_args(self, url): ( - project_id, + self.project_id, location, dataset_id, arraysize, credentials_path, credentials_base64, - default_query_job_config, + provided_job_config, list_tables_page_size, user_supplied_client, ) = parse_url(url) @@ -1149,9 +1158,9 @@ def create_connect_args(self, url): self.credentials_path = credentials_path or self.credentials_path self.credentials_base64 = credentials_base64 or self.credentials_base64 self.dataset_id = dataset_id - self._add_default_dataset_to_job_config( - default_query_job_config, project_id, dataset_id - ) + self.billing_project_id = self.billing_project_id or self.project_id + + default_query_job_config = self.create_job_config(provided_job_config) if user_supplied_client: # The user is expected to supply a client with @@ -1162,10 +1171,14 @@ def create_connect_args(self, url): credentials_path=self.credentials_path, credentials_info=self.credentials_info, credentials_base64=self.credentials_base64, - project_id=project_id, + project_id=self.billing_project_id, location=self.location, default_query_job_config=default_query_job_config, ) + # If the user specified `bigquery://` we need to set the project_id + # from the client + self.project_id = self.project_id or client.project + self.billing_project_id = self.billing_project_id or client.project return ([], {"client": client}) def _get_table_or_view_names(self, connection, item_types, schema=None): @@ -1177,7 +1190,7 @@ def _get_table_or_view_names(self, connection, item_types, schema=None): ) client = connection.connection._client - datasets = client.list_datasets() + datasets = client.list_datasets(self.project_id) result = [] for dataset in datasets: @@ -1278,7 +1291,8 @@ def _get_table(self, connection, table_name, schema=None): client = connection.connection._client - table_ref = self._table_reference(schema, table_name, client.project) + # table_ref = self._table_reference(schema, table_name, client.project) + table_ref = self._table_reference(schema, table_name, self.project_id) try: table = client.get_table(table_ref) except NotFound: @@ -1332,7 +1346,7 @@ def get_schema_names(self, connection, **kw): if isinstance(connection, Engine): connection = connection.connect() - datasets = connection.connection._client.list_datasets() + datasets = connection.connection._client.list_datasets(self.project_id) return [d.dataset_id for d in datasets] def get_table_names(self, connection, schema=None, **kw): diff --git a/tests/system/test_sqlalchemy_bigquery_remote.py b/tests/system/test_sqlalchemy_bigquery_remote.py new file mode 100644 index 00000000..6f1c969b --- /dev/null +++ b/tests/system/test_sqlalchemy_bigquery_remote.py @@ -0,0 +1,107 @@ +# Copyright (c) 2017 The sqlalchemy-bigquery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +# -*- coding: utf-8 -*- + +from sqlalchemy.engine import create_engine +from sqlalchemy.exc import DatabaseError +from sqlalchemy.schema import Table, MetaData +import pytest +import sqlalchemy +import google.api_core.exceptions as core_exceptions + + +EXPECTED_STATES = ["AL", "CA", "FL", "KY"] + +REMOTE_TESTS = [ + ("bigquery-public-data", "bigquery-public-data.usa_names.usa_1910_2013"), + ("bigquery-public-data", "usa_names.usa_1910_2013"), + ("bigquery-public-data/usa_names", "bigquery-public-data.usa_names.usa_1910_2013"), + ("bigquery-public-data/usa_names", "usa_1910_2013"), + ("bigquery-public-data/usa_names", "usa_names.usa_1910_2013"), +] + + +@pytest.fixture(scope="session") +def engine_using_remote_dataset(bigquery_client): + engine = create_engine( + "bigquery://bigquery-public-data/usa_names", + billing_project_id=bigquery_client.project, + echo=True, + ) + return engine + + +def test_remote_tables_list(engine_using_remote_dataset): + tables = sqlalchemy.inspect(engine_using_remote_dataset).get_table_names() + assert "usa_1910_2013" in tables + + +@pytest.mark.parametrize( + ["urlpath", "table_name"], + REMOTE_TESTS, + ids=[f"test_engine_remote_sql_{x}" for x in range(len(REMOTE_TESTS))], +) +def test_engine_remote_sql(bigquery_client, urlpath, table_name): + engine = create_engine( + f"bigquery://{urlpath}", billing_project_id=bigquery_client.project, echo=True + ) + with engine.connect() as conn: + rows = conn.execute( + sqlalchemy.text(f"SELECT DISTINCT(state) FROM `{table_name}`") + ).fetchall() + states = set(map(lambda row: row[0], rows)) + assert set(EXPECTED_STATES).issubset(states) + + +@pytest.mark.parametrize( + ["urlpath", "table_name"], + REMOTE_TESTS, + ids=[f"test_engine_remote_table_{x}" for x in range(len(REMOTE_TESTS))], +) +def test_engine_remote_table(bigquery_client, urlpath, table_name): + engine = create_engine( + f"bigquery://{urlpath}", billing_project_id=bigquery_client.project, echo=True + ) + with engine.connect() as conn: + table = Table(table_name, MetaData(), autoload_with=engine) + prepared = sqlalchemy.select( + sqlalchemy.distinct(table.c.state) + ).set_label_style(sqlalchemy.LABEL_STYLE_TABLENAME_PLUS_COL) + rows = conn.execute(prepared).fetchall() + states = set(map(lambda row: row[0], rows)) + assert set(EXPECTED_STATES).issubset(states) + + +@pytest.mark.parametrize( + ["urlpath", "table_name"], + REMOTE_TESTS, + ids=[f"test_engine_remote_table_fail_{x}" for x in range(len(REMOTE_TESTS))], +) +def test_engine_remote_table_fail(urlpath, table_name): + engine = create_engine(f"bigquery://{urlpath}", echo=True) + with pytest.raises( + (DatabaseError, core_exceptions.Forbidden), match="Access Denied" + ): + with engine.connect() as conn: + table = Table(table_name, MetaData(), autoload_with=engine) + prepared = sqlalchemy.select( + sqlalchemy.distinct(table.c.state) + ).set_label_style(sqlalchemy.LABEL_STYLE_TABLENAME_PLUS_COL) + _rows = conn.execute(prepared).fetchall() diff --git a/tests/unit/fauxdbi.py b/tests/unit/fauxdbi.py index 4d8f02b6..c1249c09 100644 --- a/tests/unit/fauxdbi.py +++ b/tests/unit/fauxdbi.py @@ -327,10 +327,12 @@ def _fix_pickled(self, row): pickle.loads(v.encode("latin1")) # \x80\x04 is latin-1 encoded prefix for Pickle protocol 4. if isinstance(v, str) and v[:2] == "\x80\x04" and v[-1] == "." - else pickle.loads(base64.b16decode(v)) - # 8004 is base64 encoded prefix for Pickle protocol 4. - if isinstance(v, str) and v[:4] == "8004" and v[-2:] == "2E" - else v + else ( + pickle.loads(base64.b16decode(v)) + # 8004 is base64 encoded prefix for Pickle protocol 4. + if isinstance(v, str) and v[:4] == "8004" and v[-2:] == "2E" + else v + ) ) for d, v in zip(self.description, row) ] @@ -355,7 +357,10 @@ def __getattr__(self, name): class FauxClient: def __init__(self, project_id=None, default_query_job_config=None, *args, **kw): if project_id is None: - if default_query_job_config is not None: + if ( + default_query_job_config is not None + and default_query_job_config.default_dataset + ): project_id = default_query_job_config.default_dataset.project else: project_id = "authproj" # we would still have gotten it from auth. @@ -469,10 +474,10 @@ def get_table(self, table_ref): else: raise google.api_core.exceptions.NotFound(table_ref) - def list_datasets(self): + def list_datasets(self, project="myproject"): return [ - google.cloud.bigquery.Dataset("myproject.mydataset"), - google.cloud.bigquery.Dataset("myproject.yourdataset"), + google.cloud.bigquery.Dataset(f"{project}.mydataset"), + google.cloud.bigquery.Dataset(f"{project}.yourdataset"), ] def list_tables(self, dataset, page_size): diff --git a/tests/unit/test_engine.py b/tests/unit/test_engine.py index 59481baa..67265b5a 100644 --- a/tests/unit/test_engine.py +++ b/tests/unit/test_engine.py @@ -27,6 +27,12 @@ def test_engine_dataset_but_no_project(faux_conn): assert conn.connection._client.project == "authproj" +def test_engine_dataset_with_billing_project(faux_conn): + engine = sqlalchemy.create_engine("bigquery://foo", billing_project_id="bar") + conn = engine.connect() + assert conn.connection._client.project == "bar" + + def test_engine_no_dataset_no_project(faux_conn): engine = sqlalchemy.create_engine("bigquery://") conn = engine.connect() From ac97534a6737c7bc9d60882eef2960a5053038c1 Mon Sep 17 00:00:00 2001 From: Lingqing Gan Date: Tue, 22 Apr 2025 10:40:50 -0700 Subject: [PATCH 2/2] Update test_sqlalchemy_bigquery_remote.py --- tests/system/test_sqlalchemy_bigquery_remote.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/system/test_sqlalchemy_bigquery_remote.py b/tests/system/test_sqlalchemy_bigquery_remote.py index 6f1c969b..eb98feaa 100644 --- a/tests/system/test_sqlalchemy_bigquery_remote.py +++ b/tests/system/test_sqlalchemy_bigquery_remote.py @@ -104,4 +104,4 @@ def test_engine_remote_table_fail(urlpath, table_name): prepared = sqlalchemy.select( sqlalchemy.distinct(table.c.state) ).set_label_style(sqlalchemy.LABEL_STYLE_TABLENAME_PLUS_COL) - _rows = conn.execute(prepared).fetchall() + conn.execute(prepared).fetchall()