Add support for array API to RidgeCV #27961
Conversation
I think the test failures for Ridge and RidgeCV arise from r2_score and will be handled in #27904.
While I am thinking about it, please don't forget to update:
(branch force-pushed from 6be83f7 to 54dbac2)
sklearn/linear_model/_ridge.py (outdated):

```python
if sparse.issparse(X):
    dtype = np.float64
else:
    dtype = [xp.float64, xp.float32]
```
Contrary to what I said in this morning's meeting, I think we might want to implement the following logic:
- if the input namespace/device supports xp.float64, do the upcast (as we currently do with NumPy);
- if not (e.g. the PyTorch + MPS device combination), accept the degraded numerical performance, adjust the tolerance in the tests accordingly, and document this limited numerical precision guarantee in our Array API doc.
I think this is the strategy we are leaning towards in the review of #27113. During the review of the r2_score PR, I believe that @adrinjalali preferred that approach.
In a future PR, we might decide to drop the float32 -> float64 upcast in general for this estimator (as it silently triggers a potentially very large and unexpected memory allocation which is a usability problem in itself, even with NumPy) but I would rather make this decision independently of Array API support.
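The branching described above could be sketched like this. The helper name (borrowed from the #27113 discussion) and the MPS check are illustrative assumptions, not scikit-learn API:

```python
import numpy as np

def max_precision_float_dtype(xp, device=None):
    # Hypothetical helper: return float64 when the namespace/device pair
    # supports it, else fall back to float32 (e.g. PyTorch on an MPS
    # device has no float64 support).
    if getattr(device, "type", None) == "mps":
        return xp.float32
    return xp.float64

# With plain NumPy on CPU, upcasting keeps the current float64 behavior.
X = np.asarray([[1.0, 2.0]], dtype=np.float32)
X_up = X.astype(max_precision_float_dtype(np))
```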
How would you recommend I check whether the upcast is possible? Should I temporarily copy the max_precision_float_dtype and supported_float_dtypes changes from #27113 until it is merged? Or is there already a utility in scikit-learn for checking that which I missed?
Feel free to copy, with a TODO comment to remove the redundant code once #27113 is merged, so we can decouple the two reviews.
When we do the upcast, with what precision should we store the coefficients and intercept? I guess for prediction we do not need the extra precision, so we should use X's original dtype?
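A minimal sketch of the idea in the question above, assuming a closed-form ridge solve (this is not the scikit-learn implementation, and the function name is made up): compute in float64 for stability, then store the coefficients back in X's original dtype.

```python
import numpy as np

def fit_ridge_upcast(X, y, alpha=1.0):
    input_dtype = X.dtype
    # Upcast to float64 for the numerically sensitive part of the fit.
    X64 = X.astype(np.float64)
    y64 = y.astype(np.float64)
    n_features = X64.shape[1]
    # Closed-form ridge solution: (X^T X + alpha I) w = X^T y.
    coef = np.linalg.solve(
        X64.T @ X64 + alpha * np.eye(n_features), X64.T @ y64
    )
    # Prediction does not need the extra precision, so downcast the
    # stored coefficients to the input dtype.
    return coef.astype(input_dtype)

X = np.asarray([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=np.float32)
y = np.asarray([1.0, 2.0, 3.0], dtype=np.float32)
coef = fit_ridge_upcast(X, y)
```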
This is neat! From my point of view LGTM. But I haven't checked the tests or mathematical correctness.
```python
w = 1.0 / (eigvals + alpha)
if self.fit_intercept:
    # the vector containing the square roots of the sample weights (1
    # when no sample weights) is the eigenvector of XX^T which
    # corresponds to the intercept; we cancel the regularization on
    # this dimension. the corresponding eigenvalue is
    # sum(sample_weight).
    normalized_sw = sqrt_sw / np.linalg.norm(sqrt_sw)
    norm = xp.linalg.vector_norm if is_array_api else np.linalg.norm
```
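For context on the dispatch in the diff above: the array API spec names the function linalg.vector_norm, while NumPy's historical name is linalg.norm, so a small indirection is needed. A self-contained sketch (the wrapper function is illustrative, not scikit-learn code):

```python
import numpy as np

def vector_norm(x, xp=np, is_array_api=False):
    # Pick the spec name for array API namespaces, NumPy's legacy name
    # otherwise.
    norm = xp.linalg.vector_norm if is_array_api else np.linalg.norm
    return norm(x)

v = np.asarray([3.0, 4.0])
```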
I think one could write a lightweight narwhals version for array API
We already have array_api_compat and our own NumPy wrapper that we could leverage to help us do that.
There is a new array-api-extra project that has started to implement missing utilities on top of what is standardized in the spec and therefore implemented in the array-api-compat project:
I pushed 84535ef to actually make the tests pass with PyTorch and the MPS device.
I also pushed b64baaa to revert most of the extra complexity in decision_function that did not seem justified when running the tests locally (including with PyTorch and MPS).
Apart from this, here is a new round of feedback.
LGTM once coverage is improved a bit (for the easy-to-cover lines).
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
```python
# The RidgeGCV is not very numerically stable in float32. It casts the
# input to float64 unless the device and array api combination makes it
# impossible.
tols = {"rtol": 1e-3, "atol": 1e-3}
```
This can only be covered on devices where the maximum float precision is float32.
```python
        new_arrays.append(xp.asarray(array, device=device_))
        continue
    except Exception:
        # direct conversion to a different library may fail in which
```
There is a test here to cover those lines (I checked it does locally), but only when both torch and array-api-strict are installed. Is there such a configuration in one of the CI jobs?
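The fallback pattern in the diff above can be sketched as follows. The helper name is illustrative; the point is that converting an array from one library to another directly may fail, so the code falls back to round-tripping through NumPy on CPU:

```python
import numpy as np

def asarray_with_fallback(xp, array, device=None):
    # Only pass device when given, so this also works with namespaces
    # whose asarray does not accept a device keyword.
    kwargs = {} if device is None else {"device": device}
    try:
        # Direct conversion to the target namespace/device.
        return xp.asarray(array, **kwargs)
    except Exception:
        # Direct conversion to a different library may fail; go through
        # a NumPy array on CPU first, then move to the target namespace.
        return xp.asarray(np.asarray(array), **kwargs)

a = asarray_with_fallback(np, [1.0, 2.0, 3.0])
```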
Reference Issues/PRs
Towards #26024.
This PR extends the one for Ridge (still WIP, #27800) to use the array API in RidgeCV and RidgeClassifierCV (when cv="gcv").

What does this implement/fix? Explain your changes.
This could make those estimators faster, as an important part of their computational cost comes from computing either an eigendecomposition of XX^T or an SVD of X.
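The two decompositions in question can be illustrated with NumPy (shapes chosen arbitrarily). The nonzero eigenvalues of X X^T are the squared singular values of X, so both routes expose the same spectrum, and both are the kind of dense linear algebra that accelerator backends speed up:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((5, 3))

# Gram path: eigendecomposition of X @ X.T (eigenvalues in ascending order).
eigvals = np.linalg.eigvalsh(X @ X.T)

# SVD path: singular values of X (returned in descending order).
s = np.linalg.svd(X, compute_uv=False)
```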
Any other comments?
The _RidgeGCV has numerical precision issues when computations are done in float32, which is why, at the moment, the main branch always uses float64. I'm not sure what should be done for array API inputs on devices that do not have float64.
Not handled yet: