Add support for array API to RidgeCV #27961


Open · wants to merge 53 commits into base: main
Conversation

@jeromedockes (Contributor) commented Dec 14, 2023

Reference Issues/PRs

Towards #26024.

This PR extends the one for Ridge (still WIP, #27800) to use the array API in RidgeCV and RidgeClassifierCV (when cv="gcv").

What does this implement/fix? Explain your changes.

This could make these estimators faster, as an important part of their computational cost comes from computing either an eigendecomposition of XX^T or an SVD of X (see the sketch below).
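As a hedged illustration (not the actual scikit-learn code), the dominant computation looks roughly like the following sketch; the `_decompose` helper name and the `xp` namespace argument are assumptions for illustration:

```python
# Hypothetical sketch of the dominant cost in _RidgeGCV; `xp` is an
# array API namespace (e.g. as returned by array-api-compat).
def _decompose(xp, X):
    n_samples, n_features = X.shape
    if n_features > n_samples:
        # Gram route: eigendecomposition of X @ X.T, cubic in n_samples.
        eigvals, Q = xp.linalg.eigh(X @ X.T)
        return eigvals, Q
    # Otherwise a thin SVD of X; its squared singular values are the
    # nonzero eigenvalues of X @ X.T.
    U, singvals, _ = xp.linalg.svd(X, full_matrices=False)
    return singvals**2, U
```

Routing these decompositions through the array API namespace is what would let them run on accelerators such as GPUs.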

Any other comments?

_RidgeGCV has numerical precision issues when computations are done in float32, which is why the main branch currently always uses float64.
I'm not sure what should be done for array API inputs on devices that do not support float64.

Not handled yet:

  • RidgeClassifierCV

github-actions bot commented Dec 14, 2023

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 07af173.

@jeromedockes (Contributor, Author)

I think the test failures for Ridge and RidgeCV arise from r2_score and will be handled in #27904.
For RidgeClassifierCV we need to support the array API in LabelBinarizer.

@ogrisel (Member) commented Mar 14, 2024

While I am thinking about it, please don't forget to update:

```python
if sparse.issparse(X):
    dtype = np.float64
else:
    dtype = [xp.float64, xp.float32]
```
@ogrisel (Member) commented May 15, 2024
Contrary to what I said in this morning's meeting, I think we might want to implement the following logic:

  • if the input namespace/device supports xp.float64 upcasting, then do the upcast (as we currently do with NumPy)
  • if not (e.g. pytorch + MPS device combination), accept that we have degraded numerical performance, adjust the tolerance in the tests accordingly and document this limited numerical precision guarantee in our Array API doc.

I think this is the strategy we are leaning towards in the review of #27113. During the review of the r2_score PR, I believe that @adrinjalali preferred that approach.

In a future PR, we might decide to drop the float32 -> float64 upcast in general for this estimator (as it silently triggers a potentially very large and unexpected memory allocation which is a usability problem in itself, even with NumPy) but I would rather make this decision independently of Array API support.
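A minimal sketch of that dtype-selection logic, assuming a probe-based capability check (the helper name is hypothetical; #27113 discusses utilities such as max_precision_float_dtype for this purpose):

```python
# Hypothetical sketch: choose the working dtype based on whether the
# namespace/device combination supports float64 (e.g. PyTorch + MPS does not).
def _working_float_dtype(xp, device):
    try:
        # Probe float64 support by allocating a tiny array on the device.
        xp.zeros((1,), dtype=xp.float64, device=device)
        return xp.float64  # upcast, matching the current NumPy behaviour
    except Exception:
        # Degraded numerical precision: stay in float32, relax the test
        # tolerances and document the weaker guarantee.
        return xp.float32
```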

@jeromedockes (Contributor, Author)

How would you recommend I check whether the upcast is possible? Should I temporarily copy the max_precision_float_dtype and supported_float_dtypes changes from #27113 until it is merged, or is there already a utility in scikit-learn for checking that which I missed?

Member

Feel free to copy them, with a TODO comment to remove the redundant code once #27113 is merged, so the two reviews stay decoupled.

@jeromedockes (Contributor, Author)

When we do the upcast, with what precision should we store the coefficients and intercept? I guess for prediction we do not need the extra precision, so we should use X's original dtype?
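For illustration, a hedged sketch of that idea; the `_solve_gcv` call is a hypothetical stand-in for the actual solver:

```python
# Hypothetical sketch: solve in float64 for stability, then store the
# fitted attributes back in the input's original dtype for prediction.
input_dtype = X.dtype
X_64 = xp.astype(X, xp.float64)                      # upcast for the GCV solve
coef_64, intercept_64 = _solve_gcv(X_64, y, alphas)  # hypothetical solver
self.coef_ = xp.astype(coef_64, input_dtype)
self.intercept_ = xp.astype(intercept_64, input_dtype)
```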

@adrinjalali (Member) left a comment

This is neat! From my point of view LGTM, but I haven't checked the tests or the mathematical correctness.

```python
w = 1.0 / (eigvals + alpha)
if self.fit_intercept:
    # the vector containing the square roots of the sample weights (1
    # when no sample weights) is the eigenvector of XX^T which
    # corresponds to the intercept; we cancel the regularization on
    # this dimension. the corresponding eigenvalue is
    # sum(sample_weight).
    norm = xp.linalg.vector_norm if is_array_api else np.linalg.norm
    normalized_sw = sqrt_sw / norm(sqrt_sw)
```
Member: I think one could write a lightweight narwhals version for array API.

Member: We already have array_api_compat and our own numpy wrapper that we could leverage to help us do that.

Member: There is a new array-api-extra project that has started to implement missing utilities on top of what is standardized in the spec and therefore implemented in the array-api-compat project:

https://data-apis.org/array-api-extra/
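As a hedged illustration of what such a lightweight dispatch could look like with array-api-compat (this is not scikit-learn's actual helper):

```python
# Sketch: dispatch the vector norm through the input's array API namespace.
from array_api_compat import array_namespace

def vector_norm(x):
    # array_namespace returns the (compat-wrapped) namespace of `x`, so
    # linalg.vector_norm is available uniformly for NumPy, PyTorch, etc.
    xp = array_namespace(x)
    return xp.linalg.vector_norm(x)
```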

@github-actions github-actions bot removed the CUDA CI label Dec 11, 2024
@ogrisel (Member) left a comment

I pushed 84535ef to actually make the tests pass with PyTorch and the MPS device.

I also pushed b64baaa to revert most of the extra complexity in decision_function that did not seem justified when running the tests locally (including with PyTorch and MPS).

Apart from this, here is a new round of feedback.

@ogrisel ogrisel added the CUDA CI label Feb 6, 2025
@github-actions github-actions bot removed the CUDA CI label Feb 6, 2025
@ogrisel (Member) left a comment

LGTM once coverage is improved a bit (for the easy-to-cover lines).

```python
# The RidgeGCV is not very numerically stable in float32. It casts the
# input to float64 unless the device and array api combination makes it
# impossible.
tols = {"rtol": 1e-3, "atol": 1e-3}
```
@jeromedockes (Contributor, Author)

This can only be covered on devices where the maximum float precision is float32.

```python
        new_arrays.append(xp.asarray(array, device=device_))
        continue
    except Exception:
        # direct conversion to a different library may fail in which
```
@jeromedockes (Contributor, Author)

There is a test here to cover those lines (I checked that it does locally), but only when both torch and array-api-strict are installed. Is there such a configuration in one of the CI jobs?
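For context, a hedged sketch of the kind of fallback the snippet above implements, going through NumPy when direct cross-library conversion fails (the exact code in the PR may differ):

```python
import numpy

# Sketch: move `array` into the target namespace `xp` on device `device_`.
try:
    new_array = xp.asarray(array, device=device_)
except Exception:
    # Direct conversion between libraries (e.g. torch -> array-api-strict)
    # may fail; fall back to a NumPy array on the host as an intermediate.
    new_array = xp.asarray(numpy.asarray(array), device=device_)
```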

Projects
Status: In Progress

4 participants