Add array API support for Ridge(solver='cholesky') #29318
Conversation
Nice!
Apologies, but I still didn't have time to get the code in shape for the PR. It's better this way so I can learn from your code :)
Regarding the benchmarks, out of curiosity: are you tracking the performance before and after the array API changes, and/or comparing the different array libraries?
I just wanted to compare the same workflow on different libraries/devices.
The CUDA failures seem unrelated to this PR. I suspect they also happen on
```python
        f"reshape with copy=False is not compatible with shape {shape} "
        "for the memory layout of the input array."
    )
    return output
```
Note to reviewers: while this is not strictly needed for this PR, this is something that should be implemented as per the array API spec:

> copy (Optional[bool]) – whether or not to copy the input array. If True, the function must always copy. If False, the function must never copy. If None, the function must avoid copying, if possible, and may copy otherwise. Default: None.

Being able to call reshape with `copy=False` would have helped me spot differences between `A.flat` and `A_flat = xp.reshape(A, (-1,), copy=False)` while debugging early test failures when developing this PR.
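To illustrate why a strict `copy=False` is helpful for debugging: NumPy's `ndarray.reshape` silently makes a copy when the requested shape is incompatible with the array's memory layout, so a mutation that was expected to alias the original can go unnoticed. A minimal sketch:

```python
import numpy as np

A = np.arange(6).reshape(2, 3)
B = A.T                # non-contiguous view of A
flat = B.reshape(-1)   # silently makes a copy: B cannot be flattened in place

flat[0] = 99           # mutating the "flattened" array...
assert A[0, 0] == 0    # ...does not touch A, because we got a copy

# A reshape honoring the array API's copy=False would raise here instead of
# silently copying, surfacing the aliasing issue early.
```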
I pushed a new commit (5bb5ac7) to trigger rank-deficiency related exceptions. There is no standard
```python
    device : device
        `device` object (see the "Device Support" section of the array API spec).
    """
```
Note to reviewers: for some reason, the numpydoc docstring test started to complain about the missing parameters/returns sections in this docstring only on this PR, not on the PR that introduced it.
I am not sure why, but I decided to fix it here. We can split it out into a dedicated PR if needed.
In case this is somewhat useful: maybe you want to add a comment in
@ogrisel I think there has been considerable work done in this PR. Is there something remaining?
Note: this PR uses `xp.linalg.solve` without being able to pass `assume_a="pos"` as is possible in scipy. As a result, the precise nature of the square matrix decomposition used to solve the linear system is not necessarily a Cholesky decomposition anymore ({S,D}GESV based on LU instead of {S,D}POSV based on Cholesky in LAPACK). However, this is still very fast in practice (e.g. on PyTorch), so I think it's not a problem.

TODO:

- `array-api-strict`
- `sample_weight != None`
- `xp.linalg.LinalgError` exception...
- `mps` device support in `entropy` (#29321)
- `xp.linalg` calls as expected.

Benchmark results
Benchmark script:
So a typical 10 to 15x speed-up for the GPU (NVIDIA A100 in this case) over the CPU (Intel(R) Xeon(R) Silver 4214R CPU @ 2.40GHz with 20 physical cores).
Note that as soon as one of the two data dimensions is lower than 1e3 or so, the intermediate data structure `X.T @ X` or `X @ X.T` fits in CPU cache, and the difference between CPU and GPU is not as dramatic, e.g.:

Profiling results
I adapted the benchmark script to use the PyTorch profiler as follows:
Here are some results for various data shapes:
Conclusions:

Most of the time is spent in the matmul (`X.T @ X` or `X @ X.T` depending on the shape) and in the LU factorization used to solve the resulting square linear system, as expected.

So all in all, I think the profiling results show that it's behaving as expected.
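The profiled hot spots (Gram matmul plus a square solve) correspond to a computation of roughly the following shape. This is a hedged NumPy/SciPy sketch, not the PR's actual code, and `ridge_cholesky_sketch` is a made-up name; it also shows the `assume_a="pos"` point from the description, comparing scipy's positive-definite solve with the generic (LU-based) solve that the array API exposes:

```python
import numpy as np
from scipy import linalg

def ridge_cholesky_sketch(X, y, alpha):
    """Rough shape of the computation: Gram matmul, then a square solve."""
    n_samples, n_features = X.shape
    if n_features <= n_samples:
        # Primal system: (X.T X + alpha I) w = X.T y
        A = X.T @ X + alpha * np.eye(n_features)
        return np.linalg.solve(A, X.T @ y)  # generic LU solve (GESV)
    # Dual system: w = X.T (X X.T + alpha I)^{-1} y
    K = X @ X.T + alpha * np.eye(n_samples)
    return X.T @ np.linalg.solve(K, y)

rng = np.random.default_rng(0)
X = rng.standard_normal((50, 5))
y = rng.standard_normal(50)
w = ridge_cholesky_sketch(X, y, alpha=1.0)

# scipy can exploit that the Gram system is symmetric positive definite
# (POSV / Cholesky); the array API's xp.linalg.solve has no assume_a, hence
# LU. Both solve the same system, so the coefficients agree closely.
A = X.T @ X + 1.0 * np.eye(5)
w_pos = linalg.solve(A, X.T @ y, assume_a="pos")
assert np.allclose(w, w_pos)
```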
/cc @EdAbati (I gave it a try finally).