diff --git a/bigframes/pandas/__init__.py b/bigframes/pandas/__init__.py index 0c2c1f87aa..c9640abb23 100644 --- a/bigframes/pandas/__init__.py +++ b/bigframes/pandas/__init__.py @@ -59,6 +59,7 @@ import third_party.bigframes_vendored.pandas.core.reshape.encoding as vendored_pandas_encoding import third_party.bigframes_vendored.pandas.core.reshape.merge as vendored_pandas_merge import third_party.bigframes_vendored.pandas.core.reshape.tile as vendored_pandas_tile +import third_party.bigframes_vendored.pandas.io.gbq as vendored_pandas_gbq # Include method definition so that the method appears in our docs for @@ -486,6 +487,7 @@ def read_gbq( index_col: Iterable[str] | str = (), col_order: Iterable[str] = (), max_results: Optional[int] = None, + filters: vendored_pandas_gbq.FiltersType = (), use_cache: bool = True, ) -> bigframes.dataframe.DataFrame: _set_default_session_location_if_possible(query_or_table) @@ -495,6 +497,7 @@ def read_gbq( index_col=index_col, col_order=col_order, max_results=max_results, + filters=filters, use_cache=use_cache, ) diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index cebef532ad..5364060d1c 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -233,10 +233,13 @@ def read_gbq( index_col: Iterable[str] | str = (), col_order: Iterable[str] = (), max_results: Optional[int] = None, + filters: third_party_pandas_gbq.FiltersType = (), use_cache: bool = True, # Add a verify index argument that fails if the index is not unique. ) -> dataframe.DataFrame: # TODO(b/281571214): Generate prompt to show the progress of read_gbq. 
+ query_or_table = self._filters_to_query(query_or_table, col_order, filters) + if _is_query(query_or_table): return self._read_gbq_query( query_or_table, @@ -259,6 +262,80 @@ use_cache=use_cache, ) + def _filters_to_query(self, query_or_table, columns, filters): + """Convert filters to query""" + if len(filters) == 0: + return query_or_table + + sub_query = ( + f"({query_or_table})" if _is_query(query_or_table) else query_or_table + ) + + select_clause = "SELECT " + ( + ", ".join(f"`{column}`" for column in columns) if columns else "*" + ) + + where_clause = "" + if filters: + valid_operators = { + "in": "IN", + "not in": "NOT IN", + "==": "=", + ">": ">", + "<": "<", + ">=": ">=", + "<=": "<=", + "!=": "!=", + } + + if ( + isinstance(filters, Iterable) + and isinstance(filters[0], Tuple) + and (len(filters[0]) == 0 or not isinstance(filters[0][0], Tuple)) + ): + filters = [filters] + + or_expressions = [] + for group in filters: + if not isinstance(group, Iterable): + raise ValueError( + f"Filter group should be an iterable, {group} is not valid." + ) + + and_expressions = [] + for filter_item in group: + if not isinstance(filter_item, tuple) or (len(filter_item) != 3): + raise ValueError( + f"Filter condition should be a tuple of length 3, {filter_item} is not valid." + ) + + column, operator, value = filter_item + + if not isinstance(column, str): + raise ValueError( + f"Column name should be a string, but received '{column}' of type {type(column).__name__}." 
+ ) + + if operator not in valid_operators: + raise ValueError(f"Operator {operator} is not valid.") + + operator = valid_operators[operator] + + if operator in ["IN", "NOT IN"]: + value_list = ", ".join([repr(v) for v in value]) + expression = f"`{column}` {operator} ({value_list})" + else: + expression = f"`{column}` {operator} {repr(value)}" + and_expressions.append(expression) + + or_expressions.append(" AND ".join(and_expressions)) + + if or_expressions: + where_clause = " WHERE " + " OR ".join(or_expressions) + + full_query = f"{select_clause} FROM {sub_query} AS sub{where_clause}" + return full_query + def _query_to_destination( self, query: str, diff --git a/tests/unit/session/test_session.py b/tests/unit/session/test_session.py index 18fd42e0f3..d38a393f27 100644 --- a/tests/unit/session/test_session.py +++ b/tests/unit/session/test_session.py @@ -57,3 +57,60 @@ def test_session_init_fails_with_no_project(): credentials=mock.Mock(spec=google.auth.credentials.Credentials) ) ) + + +@pytest.mark.parametrize( + ("query_or_table", "columns", "filters", "expected_output"), + [ + pytest.param( + """SELECT + rowindex, + string_col, + FROM `test_table` AS t + """, + [], + [("rowindex", "<", 4), ("string_col", "==", "Hello, World!")], + """SELECT * FROM (SELECT + rowindex, + string_col, + FROM `test_table` AS t + ) AS sub WHERE `rowindex` < 4 AND `string_col` = 'Hello, World!'""", + id="query_input", + ), + pytest.param( + "test_table", + [], + [("date_col", ">", "2022-10-20")], + "SELECT * FROM test_table AS sub WHERE `date_col` > '2022-10-20'", + id="table_input", + ), + pytest.param( + "test_table", + ["row_index", "string_col"], + [ + (("rowindex", "not in", [0, 6]),), + (("string_col", "in", ["Hello, World!", "こんにちは"]),), + ], + ( + "SELECT `row_index`, `string_col` FROM test_table AS sub WHERE " + "`rowindex` NOT IN (0, 6) OR `string_col` IN ('Hello, World!', " + "'こんにちは')" + ), + id="or_operation", + ), + pytest.param( + "test_table", + [], + ["date_col", 
">", "2022-10-20"], + None, + marks=pytest.mark.xfail( + raises=ValueError, + ), + id="raise_error", + ), + ], +) +def test_read_gbq_with_filters(query_or_table, columns, filters, expected_output): + session = resources.create_bigquery_session() + query = session._filters_to_query(query_or_table, columns, filters) + assert query == expected_output diff --git a/third_party/bigframes_vendored/pandas/io/gbq.py b/third_party/bigframes_vendored/pandas/io/gbq.py index eabb48e600..dc8bcc1f77 100644 --- a/third_party/bigframes_vendored/pandas/io/gbq.py +++ b/third_party/bigframes_vendored/pandas/io/gbq.py @@ -3,10 +3,13 @@ from __future__ import annotations -from typing import Iterable, Optional +from typing import Any, Iterable, Literal, Optional, Tuple, Union from bigframes import constants +FilterType = Tuple[str, Literal["in", "not in", "<", "<=", "==", "!=", ">=", ">"], Any] +FiltersType = Iterable[Union[FilterType, Iterable[FilterType]]] + class GBQIOMixin: def read_gbq( @@ -16,6 +19,7 @@ def read_gbq( index_col: Iterable[str] | str = (), col_order: Iterable[str] = (), max_results: Optional[int] = None, + filters: FiltersType = (), use_cache: bool = True, ): """Loads a DataFrame from BigQuery. @@ -71,6 +75,21 @@ def read_gbq( [2 rows x 3 columns] + Reading data with `columns` and `filters` parameters: + + >>> col_order = ['pitcherFirstName', 'pitcherLastName', 'year', 'pitchSpeed'] + >>> filters = [('year', '==', 2016), ('pitcherFirstName', 'in', ['John', 'Doe']), ('pitcherLastName', 'in', ['Gant'])] + >>> df = bpd.read_gbq( + ... "bigquery-public-data.baseball.games_wide", + ... col_order=col_order, + ... filters=filters, + ... ) + >>> df.head(1) + pitcherFirstName pitcherLastName year pitchSpeed + 0 John Gant 2016 82 + + [1 rows x 4 columns] + Args: query_or_table (str): A SQL string to be executed or a BigQuery table to be read. 
The @@ -84,6 +103,14 @@ def read_gbq( max_results (Optional[int], default None): If set, limit the maximum number of rows to fetch from the query results. + filters (Iterable[Union[Tuple, Iterable[Tuple]]], default ()): To + filter out data. Filter syntax: [[(column, op, val), …],…] where + op is [==, >, >=, <, <=, !=, in, not in]. The innermost tuples + are transposed into a set of filters applied through an AND + operation. The outer Iterable combines these sets of filters + through an OR operation. A single Iterable of tuples can also + be used, meaning that no OR operation between set of filters + is to be conducted. use_cache (bool, default True): Whether to cache the query inputs. Default to True.