diff --git a/bigframes/functions/remote_function.py b/bigframes/functions/remote_function.py index fb4e3f2f36..7be252406c 100644 --- a/bigframes/functions/remote_function.py +++ b/bigframes/functions/remote_function.py @@ -14,6 +14,7 @@ from __future__ import annotations +import collections.abc import hashlib import inspect import logging @@ -1043,6 +1044,8 @@ def wrapper(func): "Types are required to use @remote_function." ) input_types.append(param_type) + elif not isinstance(input_types, collections.abc.Sequence): + input_types = [input_types] if output_type is None: if (output_type := signature.return_annotation) is inspect.Signature.empty: @@ -1055,9 +1058,12 @@ def wrapper(func): # The function will actually be receiving a pandas Series, but allow both # BigQuery DataFrames and pandas object types for compatibility. is_row_processor = False - if input_types == bigframes.series.Series or input_types == pandas.Series: + if len(input_types) == 1 and ( + (input_type := input_types[0]) == bigframes.series.Series + or input_type == pandas.Series + ): warnings.warn( - "input_types=Series scenario is in preview.", + "input_types=Series is in preview.", stacklevel=1, category=bigframes.exceptions.PreviewWarning, ) diff --git a/tests/system/small/test_remote_function.py b/tests/system/small/test_remote_function.py index 096a268441..d2ee4411f4 100644 --- a/tests/system/small/test_remote_function.py +++ b/tests/system/small/test_remote_function.py @@ -727,7 +727,7 @@ def add_ints(row): with pytest.warns( bigframes.exceptions.PreviewWarning, - match="input_types=Series scenario is in preview.", + match="input_types=Series is in preview.", ): add_ints_remote = session.remote_function(bigframes.series.Series, int)( add_ints diff --git a/tests/unit/resources.py b/tests/unit/resources.py index 4d7998903c..84699459e6 100644 --- a/tests/unit/resources.py +++ b/tests/unit/resources.py @@ -23,6 +23,7 @@ import pytest import bigframes +import bigframes.clients import bigframes.core as core import bigframes.core.ordering import bigframes.dataframe @@ -97,6 +98,9 @@ def query_mock(query, *args, **kwargs): bqoptions = bigframes.BigQueryOptions(credentials=credentials, location=location) session = bigframes.Session(context=bqoptions, clients_provider=clients_provider) + session._bq_connection_manager = mock.create_autospec( + bigframes.clients.BqConnectionManager, instance=True + ) return session diff --git a/tests/unit/test_remote_function.py b/tests/unit/test_remote_function.py index ae9ab296c5..1bd3f3b14f 100644 --- a/tests/unit/test_remote_function.py +++ b/tests/unit/test_remote_function.py @@ -12,15 +12,53 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re + import bigframes_vendored.ibis.backends.bigquery.datatypes as third_party_ibis_bqtypes from ibis.expr import datatypes as ibis_types +import pandas import pytest import bigframes.dtypes import bigframes.functions.remote_function +import bigframes.series from tests.unit import resources +@pytest.mark.parametrize( + "series_type", + ( + pytest.param( + pandas.Series, + id="pandas.Series", + ), + pytest.param( + bigframes.series.Series, + id="bigframes.series.Series", + ), + ), +) +def test_series_input_types_to_str(series_type): + """Check that is_row_processor=True uses str as the input type to serialize a row.""" + session = resources.create_bigquery_session() + remote_function_decorator = bigframes.functions.remote_function.remote_function( + session=session + ) + + with pytest.warns( + bigframes.exceptions.PreviewWarning, + match=re.escape("input_types=Series is in preview."), + ): + + @remote_function_decorator + def axis_1_function(myparam: series_type) -> str: # type: ignore + return "Hello, " + myparam["str_col"] + "!" # type: ignore + + # Still works as a normal function. + assert axis_1_function(pandas.Series({"str_col": "World"})) == "Hello, World!" + assert axis_1_function.ibis_node is not None + + def test_supported_types_correspond(): # The same types should be representable by the supported Python and BigQuery types. ibis_types_from_python = { diff --git a/third_party/bigframes_vendored/pandas/core/series.py b/third_party/bigframes_vendored/pandas/core/series.py index 6a7a815ed9..a430c3375f 100644 --- a/third_party/bigframes_vendored/pandas/core/series.py +++ b/third_party/bigframes_vendored/pandas/core/series.py @@ -3969,7 +3969,7 @@ def map( It also accepts a remote function: - >>> @bpd.remote_function + >>> @bpd.remote_function() ... def my_mapper(val: str) -> str: ... vowels = ["a", "e", "i", "o", "u"] ... if val: