-
Notifications
You must be signed in to change notification settings - Fork 50
feat: df.apply(axis=1)
to support remote function with multiple params
#851
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fdee925
5e60089
8b82a14
bf917af
c2d5681
c09a371
c133a5d
faeb9a1
d926cac
267d28b
ea5663d
fcd6f3a
f57e1e3
fd78a9e
be7988d
3f8ebcf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -191,19 +191,27 @@ def normalized_impl(args: typing.Sequence[ibis_types.Value], op: ops.RowOp): | |
|
||
return decorator | ||
|
||
def register_nary_op(self, op_ref: typing.Union[ops.NaryOp, type[ops.NaryOp]]): | ||
def register_nary_op( | ||
self, op_ref: typing.Union[ops.NaryOp, type[ops.NaryOp]], pass_op: bool = False | ||
): | ||
""" | ||
Decorator to register a nary op implementation. | ||
|
||
Args: | ||
op_ref (NaryOp or NaryOp type): | ||
Class or instance of operator that is implemented by the decorated function. | ||
pass_op (bool): | ||
Set to true if implementation takes the operator object as the last argument. | ||
This is needed for parameterized ops where parameters are part of op object. | ||
""" | ||
key = typing.cast(str, op_ref.name) | ||
|
||
def decorator(impl: typing.Callable[..., ibis_types.Value]): | ||
def normalized_impl(args: typing.Sequence[ibis_types.Value], op: ops.RowOp): | ||
return impl(*args) | ||
if pass_op: | ||
return impl(*args, op=op) | ||
else: | ||
return impl(*args) | ||
|
||
self._register(key, normalized_impl) | ||
return impl | ||
|
@@ -1444,6 +1452,7 @@ def clip_op( | |
) | ||
|
||
|
||
# N-ary Operations | ||
@scalar_op_compiler.register_nary_op(ops.case_when_op) | ||
def case_when_op(*cases_and_outputs: ibis_types.Value) -> ibis_types.Value: | ||
# ibis can handle most type coercions, but we need to force bool -> int | ||
|
@@ -1463,6 +1472,19 @@ def case_when_op(*cases_and_outputs: ibis_types.Value) -> ibis_types.Value: | |
return case_val.end() | ||
|
||
|
||
@scalar_op_compiler.register_nary_op(ops.NaryRemoteFunctionOp, pass_op=True) | ||
def nary_remote_function_op_impl( | ||
*operands: ibis_types.Value, op: ops.NaryRemoteFunctionOp | ||
): | ||
ibis_node = getattr(op.func, "ibis_node", None) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Don't need to do this in the current PR - but we need to move away from storing ibis values in the op definition. We will want to generate this at compile-time only to allow non-ibis compilation. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ack. b/356686746 |
||
if ibis_node is None: | ||
raise TypeError( | ||
f"only a bigframes remote function is supported as a callable. {constants.FEEDBACK_LINK}" | ||
) | ||
result = ibis_node(*operands) | ||
return result | ||
|
||
|
||
# Helpers | ||
def is_null(value) -> bool: | ||
# float NaN/inf should be treated as distinct from 'true' null values | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3433,9 +3433,9 @@ def map(self, func, na_action: Optional[str] = None) -> DataFrame: | |
raise ValueError(f"na_action={na_action} not supported") | ||
|
||
# TODO(shobs): Support **kwargs | ||
# Reproject as workaround to applying filter too late. This forces the filter | ||
# to be applied before passing data to remote function, protecting from bad | ||
# inputs causing errors. | ||
# Reproject as workaround to applying filter too late. This forces the | ||
# filter to be applied before passing data to remote function, | ||
# protecting from bad inputs causing errors. | ||
reprojected_df = DataFrame(self._block._force_reproject()) | ||
return reprojected_df._apply_unary_op( | ||
ops.RemoteFunctionOp(func=func, apply_on_null=(na_action is None)) | ||
|
@@ -3448,65 +3448,99 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs): | |
category=bigframes.exceptions.PreviewWarning, | ||
) | ||
|
||
# Early check whether the dataframe dtypes are currently supported | ||
# in the remote function | ||
# NOTE: Keep in sync with the value converters used in the gcf code | ||
# generated in remote_function_template.py | ||
remote_function_supported_dtypes = ( | ||
bigframes.dtypes.INT_DTYPE, | ||
bigframes.dtypes.FLOAT_DTYPE, | ||
bigframes.dtypes.BOOL_DTYPE, | ||
bigframes.dtypes.BYTES_DTYPE, | ||
bigframes.dtypes.STRING_DTYPE, | ||
) | ||
supported_dtypes_types = tuple( | ||
type(dtype) | ||
for dtype in remote_function_supported_dtypes | ||
if not isinstance(dtype, pandas.ArrowDtype) | ||
) | ||
# Check ArrowDtype separately since multiple BigQuery types map to | ||
# ArrowDtype, including BYTES and TIMESTAMP. | ||
supported_arrow_types = tuple( | ||
dtype.pyarrow_dtype | ||
for dtype in remote_function_supported_dtypes | ||
if isinstance(dtype, pandas.ArrowDtype) | ||
) | ||
supported_dtypes_hints = tuple( | ||
str(dtype) for dtype in remote_function_supported_dtypes | ||
) | ||
|
||
for dtype in self.dtypes: | ||
if ( | ||
# Not one of the pandas/numpy types. | ||
not isinstance(dtype, supported_dtypes_types) | ||
# And not one of the arrow types. | ||
and not ( | ||
isinstance(dtype, pandas.ArrowDtype) | ||
and any( | ||
dtype.pyarrow_dtype.equals(arrow_type) | ||
for arrow_type in supported_arrow_types | ||
) | ||
) | ||
): | ||
raise NotImplementedError( | ||
f"DataFrame has a column of dtype '{dtype}' which is not supported with axis=1." | ||
f" Supported dtypes are {supported_dtypes_hints}." | ||
) | ||
|
||
# Check if the function is a remote function | ||
if not hasattr(func, "bigframes_remote_function"): | ||
shobsi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
raise ValueError("For axis=1 a remote function must be used.") | ||
|
||
# Serialize the rows as json values | ||
block = self._get_block() | ||
rows_as_json_series = bigframes.series.Series( | ||
block._get_rows_as_json_values() | ||
) | ||
is_row_processor = getattr(func, "is_row_processor") | ||
if is_row_processor: | ||
# Early check whether the dataframe dtypes are currently supported | ||
# in the remote function | ||
# NOTE: Keep in sync with the value converters used in the gcf code | ||
# generated in remote_function_template.py | ||
remote_function_supported_dtypes = ( | ||
bigframes.dtypes.INT_DTYPE, | ||
bigframes.dtypes.FLOAT_DTYPE, | ||
bigframes.dtypes.BOOL_DTYPE, | ||
bigframes.dtypes.BYTES_DTYPE, | ||
bigframes.dtypes.STRING_DTYPE, | ||
) | ||
supported_dtypes_types = tuple( | ||
type(dtype) | ||
for dtype in remote_function_supported_dtypes | ||
if not isinstance(dtype, pandas.ArrowDtype) | ||
) | ||
# Check ArrowDtype separately since multiple BigQuery types map to | ||
# ArrowDtype, including BYTES and TIMESTAMP. | ||
supported_arrow_types = tuple( | ||
dtype.pyarrow_dtype | ||
for dtype in remote_function_supported_dtypes | ||
if isinstance(dtype, pandas.ArrowDtype) | ||
) | ||
supported_dtypes_hints = tuple( | ||
str(dtype) for dtype in remote_function_supported_dtypes | ||
) | ||
|
||
# Apply the function | ||
result_series = rows_as_json_series._apply_unary_op( | ||
ops.RemoteFunctionOp(func=func, apply_on_null=True) | ||
) | ||
for dtype in self.dtypes: | ||
if ( | ||
# Not one of the pandas/numpy types. | ||
not isinstance(dtype, supported_dtypes_types) | ||
# And not one of the arrow types. | ||
and not ( | ||
isinstance(dtype, pandas.ArrowDtype) | ||
and any( | ||
dtype.pyarrow_dtype.equals(arrow_type) | ||
for arrow_type in supported_arrow_types | ||
) | ||
) | ||
): | ||
raise NotImplementedError( | ||
f"DataFrame has a column of dtype '{dtype}' which is not supported with axis=1." | ||
f" Supported dtypes are {supported_dtypes_hints}." | ||
) | ||
|
||
# Serialize the rows as json values | ||
block = self._get_block() | ||
rows_as_json_series = bigframes.series.Series( | ||
block._get_rows_as_json_values() | ||
) | ||
|
||
# Apply the function | ||
result_series = rows_as_json_series._apply_unary_op( | ||
ops.RemoteFunctionOp(func=func, apply_on_null=True) | ||
) | ||
else: | ||
# This is a special case where we are providing not-pandas-like | ||
# extension. If the remote function can take one or more params | ||
# then we assume that here the user intention is to use the | ||
# column values of the dataframe as arguments to the function. | ||
# For this to work the following condition must be true: | ||
# 1. The number of input params in the function must be the same | ||
# as the number of columns in the dataframe | ||
# 2. The dtypes of the columns in the dataframe must be | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we want to accept compatible dtypes? e.g. the column is int, but the function takes decimal? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. function taking decimal is not something we support right now. There is a longer term desire to expand the datatype support.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Having said that, are there other places where we reconcile dtypes across dataframes or in operations? |
||
# compatible with the data types of the input params | ||
# 3. The order of the columns in the dataframe must correspond | ||
# to the order of the input params in the function | ||
udf_input_dtypes = getattr(func, "input_dtypes") | ||
if len(udf_input_dtypes) != len(self.columns): | ||
raise ValueError( | ||
f"Remote function takes {len(udf_input_dtypes)} arguments but DataFrame has {len(self.columns)} columns." | ||
) | ||
if udf_input_dtypes != tuple(self.dtypes.to_list()): | ||
raise ValueError( | ||
f"Remote function takes arguments of types {udf_input_dtypes} but DataFrame dtypes are {tuple(self.dtypes)}." | ||
) | ||
|
||
series_list = [self[col] for col in self.columns] | ||
# Reproject as workaround to applying filter too late. This forces the | ||
# filter to be applied before passing data to remote function, | ||
# protecting from bad inputs causing errors. | ||
reprojected_series = bigframes.series.Series( | ||
series_list[0]._block._force_reproject() | ||
) | ||
Comment on lines
+3538
to
+3540
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. why do we need a force_reproject? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am just copy-pasting the pattern introduced in this change: ff3bb89#diff-8718ceb6a8f6b68d7b06a15e84043fb866c500d5bfb1f33ad8c945f06815a140 Is the reasoning (got a bit detached unintentionally, sitting at the beginning of the function) still valid? # Reproject as workaround to applying filter too late. This forces the filter
# to be applied before passing data to remote function, protecting from bad
# inputs causing errors. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually it does make a difference, quickly tested in #874 and There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I moved the comment back close to reproject, PTAL |
||
result_series = reprojected_series._apply_nary_op( | ||
ops.NaryRemoteFunctionOp(func=func), series_list[1:] | ||
) | ||
result_series.name = None | ||
|
||
# Return Series with materialized result so that any error in the remote | ||
|
Uh oh!
There was an error while loading. Please reload this page.