Automatically move y (and sample_weight) to the same device and namespace as X #28668

Opened by @betatim

The proposal/idea is to allow y to not be on the same device (and namespace?) as X when using Array API inputs. Currently we require/assume that they are on the same device and namespace, it is a requirement. However pipelines can't modify y which means it is not possible to move from CPU to GPU as one of the steps of the pipeline, the whole pipeline has to stay on one device. The below details some example and code to motivate allowing X ad y being on different devices (and namespaces).


Suppose we have:

>>> import torch
>>> from sklearn import set_config
>>> from sklearn.datasets import make_regression
>>> from sklearn.linear_model import Ridge
>>> set_config(array_api_dispatch=True)
>>> X, y = make_regression(n_samples=int(1e5), n_features=int(1e3), random_state=0)
>>> X_torch_cuda = torch.tensor(X).to("cuda")
>>> y_torch_cuda = torch.tensor(y).to("cuda")

I did a quick benchmark with IPython's %time on a host with a 32-core CPU and an A100 GPU: we get a bit more than a 10x speed-up (which is in the range of what I would have expected):

>>> %time Ridge(solver="svd").fit(X, y)
CPU times: user 1min 29s, sys: 1min 4s, total: 2min 34s
Wall time: 6.18 s
Ridge(solver='svd')
>>> %time Ridge(solver="svd").fit(X_torch_cuda, y_torch_cuda)
CPU times: user 398 ms, sys: 2.74 ms, total: 401 ms
Wall time: 402 ms
Ridge(solver='svd')

I also tried the following:

>>> Ridge(solver="svd").fit(X_torch_cuda, y)
Traceback (most recent call last):
  Cell In[36], line 1
    Ridge(solver="svd").fit(X_torch_cuda, y)
  File ~/code/scikit-learn/sklearn/base.py:1194 in wrapper
    return fit_method(estimator, *args, **kwargs)
  File ~/code/scikit-learn/sklearn/linear_model/_ridge.py:1197 in fit
    device_ = device(*input_arrays)
  File ~/code/scikit-learn/sklearn/utils/_array_api.py:104 in device
    raise ValueError("Input arrays use different devices.")
ValueError: Input arrays use different devices.

I think it might be reasonable to expect this pattern to fail in general and to explicitly ask the user to provide inputs with consistently allocated data buffers.

However, we might want to be more lenient for the particular case of y (and sample_weight) and change Ridge.fit to automatically move y to the same namespace and device as X:

y = xp.asarray(y, device=device(X))
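
A minimal sketch of where such a conversion could live in an estimator's fit method, assuming the private helpers get_namespace and device from sklearn.utils._array_api (these are internal, so names and signatures may differ; this is an illustration, not the actual implementation):

from sklearn.utils._array_api import get_namespace, device

def fit(self, X, y, sample_weight=None):
    xp, _ = get_namespace(X)   # the namespace of X drives the computation
    device_ = device(X)
    # Hypothetical: move y (and sample_weight) to X's namespace and device
    # instead of raising when the devices differ.
    y = xp.asarray(y, device=device_)
    if sample_weight is not None:
        sample_weight = xp.asarray(sample_weight, device=device_)
    ...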

The reason would be to improve the usability of the Array API for pipelines like the following:

X_pandas_df, y_pandas_series = fetch_some_pandas_data()

pipeline = make_pipeline(
    some_column_transformer(),  # works on CPU on the input dataframe
    FunctionTransformer(func=lambda X: torch.tensor(X).to(torch.float32).to("cuda")),  # moves X to a float32 CUDA tensor
    Ridge(solver="svd"),
)
pipeline.fit(X_pandas_df, y_pandas_series)

Pipeline steps can only transform X, not y (or sample_weight). This means the user would instead have to call:

pipeline.fit(X_pandas_df, torch.tensor(y_pandas_series).to("cuda"))

which feels a bit weird/cumbersome to me.

This might not be a big deal though, and I don't want to delay this PR to get a consensus on this particular point: I think it's fine the way it is for now, but we might want to come back to the UX of pipelines with Array API steps later. I will open a dedicated issue.


If we take action, there are a few estimators that need double-checking. For example, LinearDiscriminantAnalysis currently raises low-level PyTorch exceptions when X and y are on different devices.
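
For reference, a hedged sketch of how the LinearDiscriminantAnalysis case can be reproduced, reusing the torch/set_config setup from above (the exact exception type and message are not reproduced here; it surfaces as a low-level PyTorch error rather than the clean ValueError shown earlier):

>>> from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
>>> from sklearn.datasets import make_classification
>>> X_clf, y_clf = make_classification(n_samples=1000, n_features=20, random_state=0)
>>> X_clf_cuda = torch.tensor(X_clf).to("cuda")
>>> # y_clf stays a NumPy array on the CPU
>>> LinearDiscriminantAnalysis().fit(X_clf_cuda, y_clf)  # raises a low-level PyTorch exception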
