diff --git a/bigframes/functions/remote_function.py b/bigframes/functions/remote_function.py index 4bb667ccc7..f7237c564c 100644 --- a/bigframes/functions/remote_function.py +++ b/bigframes/functions/remote_function.py @@ -24,7 +24,7 @@ import sys import tempfile import textwrap -from typing import List, NamedTuple, Optional, Sequence, TYPE_CHECKING +from typing import List, NamedTuple, Optional, Sequence, TYPE_CHECKING, Union import ibis import requests @@ -623,7 +623,7 @@ def get_routine_reference( # which has moved as @js to the ibis package # https://github.com/ibis-project/ibis/blob/master/ibis/backends/bigquery/udf/__init__.py def remote_function( - input_types: Sequence[type], + input_types: Union[type, Sequence[type]], output_type: type, session: Optional[Session] = None, bigquery_client: Optional[bigquery.Client] = None, @@ -686,9 +686,10 @@ def remote_function( `$ gcloud projects add-iam-policy-binding PROJECT_ID --member="serviceAccount:CONNECTION_SERVICE_ACCOUNT_ID" --role="roles/run.invoker"`. Args: - input_types list(type): - List of input data types in the user defined function. - output_type type: + input_types (type or sequence(type)): + Input data type, or sequence of input data types in the user + defined function. + output_type (type): Data type of the output in the user defined function. session (bigframes.Session, Optional): BigQuery DataFrames session to use for getting default project, @@ -778,6 +779,9 @@ def remote_function( By default BigQuery DataFrames uses a 10 minute timeout. `None` can be passed to let the cloud functions default timeout take effect. """ + if isinstance(input_types, type): + input_types = [input_types] + import bigframes.pandas as bpd session = session or bpd.get_global_session() diff --git a/bigframes/pandas/__init__.py b/bigframes/pandas/__init__.py index 71ef4e609e..48a4b0f68d 100644 --- a/bigframes/pandas/__init__.py +++ b/bigframes/pandas/__init__.py @@ -633,7 +633,7 @@ def read_parquet( def remote_function( - input_types: List[type], + input_types: Union[type, Sequence[type]], output_type: type, dataset: Optional[str] = None, bigquery_connection: Optional[str] = None, diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index 34047ff155..79febcc5d9 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -1538,7 +1538,7 @@ def _ibis_to_temp_table( def remote_function( self, - input_types: List[type], + input_types: Union[type, Sequence[type]], output_type: type, dataset: Optional[str] = None, bigquery_connection: Optional[str] = None, @@ -1592,8 +1592,9 @@ def remote_function( `$ gcloud projects add-iam-policy-binding PROJECT_ID --member="serviceAccount:CONNECTION_SERVICE_ACCOUNT_ID" --role="roles/run.invoker"`. Args: - input_types (list(type)): - List of input data types in the user defined function. + input_types (type or sequence(type)): + Input data type, or sequence of input data types in the user + defined function. output_type (type): Data type of the output in the user defined function. dataset (str, Optional): diff --git a/samples/snippets/remote_function.py b/samples/snippets/remote_function.py index 61b7dc092a..4db4e67619 100644 --- a/samples/snippets/remote_function.py +++ b/samples/snippets/remote_function.py @@ -47,7 +47,7 @@ def run_remote_function_and_read_gbq_function(project_id: str): # of the penguins, which is a real number, into a category, which is a # string. @bpd.remote_function( - [float], + float, str, reuse=False, ) @@ -91,7 +91,7 @@ def get_bucket(num): # as a remote function. The custom function in this example has external # package dependency, which can be specified via `packages` parameter. @bpd.remote_function( - [str], + str, str, reuse=False, packages=["cryptography"], diff --git a/tests/system/large/test_remote_function.py b/tests/system/large/test_remote_function.py index 6cae893f9c..eb7cb8308b 100644 --- a/tests/system/large/test_remote_function.py +++ b/tests/system/large/test_remote_function.py @@ -310,6 +310,35 @@ def add_one(x): ) +@pytest.mark.parametrize( + ("input_types"), + [ + pytest.param([int], id="list-of-int"), + pytest.param(int, id="int"), + ], +) +@pytest.mark.flaky(retries=2, delay=120) +def test_remote_function_input_types(session, scalars_dfs, input_types): + try: + + def add_one(x): + return x + 1 + + remote_add_one = session.remote_function(input_types, int)(add_one) + + scalars_df, scalars_pandas_df = scalars_dfs + + bf_result = scalars_df.int64_too.map(remote_add_one).to_pandas() + pd_result = scalars_pandas_df.int64_too.map(add_one) + + pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False) + finally: + # clean up the gcp assets created for the remote function + cleanup_remote_function_assets( + session.bqclient, session.cloudfunctionsclient, remote_add_one + ) + + @pytest.mark.flaky(retries=2, delay=120) def test_remote_function_explicit_dataset_not_created( session, diff --git a/third_party/bigframes_vendored/pandas/core/frame.py b/third_party/bigframes_vendored/pandas/core/frame.py index 1669a291c9..c5168cd160 100644 --- a/third_party/bigframes_vendored/pandas/core/frame.py +++ b/third_party/bigframes_vendored/pandas/core/frame.py @@ -3892,7 +3892,7 @@ def map(self, func, na_action: Optional[str] = None) -> DataFrame: to potentially reuse a previously deployed ``remote_function`` from the same user defined function. - >>> @bpd.remote_function([int], float, reuse=False) + >>> @bpd.remote_function(int, float, reuse=False) ... def minutes_to_hours(x): ... return x/60 diff --git a/third_party/bigframes_vendored/pandas/core/series.py b/third_party/bigframes_vendored/pandas/core/series.py index 0c5b8d4521..4833c41ff7 100644 --- a/third_party/bigframes_vendored/pandas/core/series.py +++ b/third_party/bigframes_vendored/pandas/core/series.py @@ -1181,7 +1181,7 @@ def apply( to potentially reuse a previously deployed `remote_function` from the same user defined function. - >>> @bpd.remote_function([int], float, reuse=False) + >>> @bpd.remote_function(int, float, reuse=False) ... def minutes_to_hours(x): ... return x/60 @@ -1208,7 +1208,7 @@ def apply( `packages` param. >>> @bpd.remote_function( - ... [str], + ... str, ... str, ... reuse=False, ... packages=["cryptography"], @@ -3341,7 +3341,7 @@ def mask(self, cond, other): condition is evaluated based on a complicated business logic which cannot be expressed in form of a Series. - >>> @bpd.remote_function([str], bool, reuse=False) + >>> @bpd.remote_function(str, bool, reuse=False) ... def should_mask(name): ... hash = 0 ... for char_ in name: @@ -3860,7 +3860,7 @@ def map( It also accepts a remote function: - >>> @bpd.remote_function([str], str) + >>> @bpd.remote_function(str, str) ... def my_mapper(val): ... vowels = ["a", "e", "i", "o", "u"] ... if val: