diff --git a/README.rst b/README.rst index 0a426ac3..ca98bba0 100644 --- a/README.rst +++ b/README.rst @@ -234,6 +234,25 @@ To create the base64 encoded string you can use the command line tool ``base64`` Alternatively, you can use an online generator like `www.base64encode.org _` to paste your credentials JSON file to be encoded. + +Supplying Your Own BigQuery Client +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The above connection string parameters allow you to influence how the BigQuery client used to execute your queries will be instantiated. +If you need additional control, you can supply a BigQuery client of your own: + +.. code-block:: python + + from google.cloud import bigquery + + custom_bq_client = bigquery.Client(...) + + engine = create_engine( + 'bigquery://some-project/some-dataset?user_supplied_client=True', + connect_args={'client': custom_bq_client}, + ) + + Creating tables ^^^^^^^^^^^^^^^ diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index 48455836..3a26b330 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -801,6 +801,7 @@ def create_connect_args(self, url): credentials_base64, default_query_job_config, list_tables_page_size, + user_supplied_client, ) = parse_url(url) self.arraysize = arraysize or self.arraysize @@ -812,15 +813,21 @@ def create_connect_args(self, url): self._add_default_dataset_to_job_config( default_query_job_config, project_id, dataset_id ) - client = _helpers.create_bigquery_client( - credentials_path=self.credentials_path, - credentials_info=self.credentials_info, - credentials_base64=self.credentials_base64, - project_id=project_id, - location=self.location, - default_query_job_config=default_query_job_config, - ) - return ([client], {}) + + if user_supplied_client: + # The user is expected to supply a client with + # create_engine('...', connect_args={'client': bq_client}) + return ([], {}) + else: + client = _helpers.create_bigquery_client( + credentials_path=self.credentials_path, + credentials_info=self.credentials_info, + credentials_base64=self.credentials_base64, + project_id=project_id, + location=self.location, + default_query_job_config=default_query_job_config, + ) + return ([], {"client": client}) def _get_table_or_view_names(self, connection, item_types, schema=None): current_schema = schema or self.dataset_id diff --git a/sqlalchemy_bigquery/parse_url.py b/sqlalchemy_bigquery/parse_url.py index b1d4b589..7bf6d415 100644 --- a/sqlalchemy_bigquery/parse_url.py +++ b/sqlalchemy_bigquery/parse_url.py @@ -70,6 +70,7 @@ def parse_url(url): # noqa: C901 credentials_path = None credentials_base64 = None list_tables_page_size = None + user_supplied_client = False # location if "location" in query: @@ -101,6 +102,10 @@ def parse_url(url): # noqa: C901 + str_list_tables_page_size ) + # user_supplied_client + if "user_supplied_client" in query: + user_supplied_client = query.pop("user_supplied_client").lower() == "true" + # if only these "non-config" values were present, the dict will now be empty if not query: # if a dataset_id exists, we need to return a job_config that isn't None @@ -115,6 +120,7 @@ def parse_url(url): # noqa: C901 credentials_base64, QueryJobConfig(), list_tables_page_size, + user_supplied_client, ) else: return ( @@ -126,6 +132,7 @@ def parse_url(url): # noqa: C901 credentials_base64, None, list_tables_page_size, + user_supplied_client, ) job_config = QueryJobConfig() @@ -275,4 +282,5 @@ def parse_url(url): # noqa: C901 credentials_base64, job_config, list_tables_page_size, + user_supplied_client, ) diff --git a/tests/unit/test_parse_url.py b/tests/unit/test_parse_url.py index 9f080933..8c0274d2 100644 --- a/tests/unit/test_parse_url.py +++ b/tests/unit/test_parse_url.py @@ -63,6 +63,7 @@ def url_with_everything(): "&schema_update_options=ALLOW_FIELD_ADDITION,ALLOW_FIELD_RELAXATION" "&use_query_cache=true" "&write_disposition=WRITE_APPEND" + "&user_supplied_client=true" ) @@ -76,6 +77,7 @@ def test_basic(url_with_everything): credentials_base64, job_config, list_tables_page_size, + user_supplied_client, ) = parse_url(url_with_everything) assert project_id == "some-project" @@ -86,6 +88,7 @@ def test_basic(url_with_everything): assert credentials_path == "/some/path/to.json" assert credentials_base64 == "eyJrZXkiOiJ2YWx1ZSJ9Cg==" assert isinstance(job_config, QueryJobConfig) + assert user_supplied_client @pytest.mark.parametrize( @@ -161,11 +164,15 @@ def test_bad_values(param, value): def test_empty_url(): - for value in parse_url(make_url("bigquery://")): + values = parse_url(make_url("bigquery://")) + for value in values[:-1]: assert value is None + assert not values[-1] - for value in parse_url(make_url("bigquery:///")): + values = parse_url(make_url("bigquery:///")) + for value in values[:-1]: assert value is None + assert not values[-1] def test_empty_with_non_config(): @@ -183,6 +190,7 @@ def test_empty_with_non_config(): credentials_base64, job_config, list_tables_page_size, + user_supplied_credentials, ) = url assert project_id is None @@ -193,6 +201,7 @@ def test_empty_with_non_config(): assert credentials_base64 is None assert job_config is None assert list_tables_page_size is None + assert not user_supplied_credentials def test_only_dataset(): @@ -206,6 +215,7 @@ def test_only_dataset(): credentials_base64, job_config, list_tables_page_size, + user_supplied_credentials, ) = url assert project_id is None @@ -216,6 +226,7 @@ def test_only_dataset(): assert credentials_base64 is None assert list_tables_page_size is None assert isinstance(job_config, QueryJobConfig) + assert not user_supplied_credentials # we can't actually test that the dataset is on the job_config, # since we take care of that afterwards, when we have a client to fill in the project diff --git a/tests/unit/test_sqlalchemy_bigquery.py b/tests/unit/test_sqlalchemy_bigquery.py index 53c49bf5..06ef79d2 100644 --- a/tests/unit/test_sqlalchemy_bigquery.py +++ b/tests/unit/test_sqlalchemy_bigquery.py @@ -233,3 +233,16 @@ def test_unnest_function(args, kw): assert isinstance( sqlalchemy.select([f]).subquery().c.unnest.type, sqlalchemy.String ) + + +@mock.patch("sqlalchemy_bigquery._helpers.create_bigquery_client") +def test_setting_user_supplied_client_skips_creating_client( + mock_create_bigquery_client, +): + import sqlalchemy_bigquery # noqa + + result = sqlalchemy_bigquery.BigQueryDialect().create_connect_args( + mock.MagicMock(database=None, query={"user_supplied_client": "true"}) + ) + assert result == ([], {}) + assert not mock_create_bigquery_client.called