Add array API support for Ridge(solver='cholesky') #29318


Draft: wants to merge 23 commits into main

Conversation

@ogrisel (Member) commented Jun 20, 2024

Note: this PR uses xp.linalg.solve, which does not accept assume_a="pos" the way scipy.linalg.solve does. As a result, the decomposition used to solve the square linear system is not necessarily a Cholesky decomposition anymore ({S,D}GESV based on LU instead of {S,D}POSV based on Cholesky in LAPACK). However, this is still very fast in practice (e.g. on PyTorch), so I think it's not a problem.
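
For illustration, here is a minimal sketch of the difference (not code from this PR; NumPy and SciPy stand in for the array namespace):

import numpy as np
import scipy.linalg

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 10))
y = rng.standard_normal(100)
alpha = 1.0

# Symmetric positive definite system (X.T X + alpha I) w = X.T y.
A = X.T @ X + alpha * np.eye(X.shape[1])
b = X.T @ y

# SciPy path: assume_a="pos" lets LAPACK use a Cholesky-based routine (?POSV).
w_pos = scipy.linalg.solve(A, b, assume_a="pos")

# Array API path: xp.linalg.solve has no assume_a, so a generic LU-based
# routine (?GESV) is typically used instead (NumPy shown here as the namespace).
w_lu = np.linalg.solve(A, b)

assert np.allclose(w_pos, w_lu)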

TODO:

  • Make sure CPU tests pass with good coverage
    • debug the multi-target failure with array-api-strict
    • add array API tests for sample_weight != None
    • add tests to check the fallback behavior to SVD in the absence of a standard xp.linalg.LinalgError exception
  • Test with CUDA: triggered a run here (the 8 failures are known and are being fixed independently in fix: mps device support in entropy #29321)
  • Run some benchmarks
  • Profile with py-spy, viztracer and/or the PyTorch profiler to check that 90%+ of the time is spent in xp.linalg calls as expected.

Benchmark results

Benchmark script:

import torch
from sklearn.linear_model import Ridge
import numpy as np
from time import perf_counter
from sklearn import set_config

set_config(array_api_dispatch=True)


n_samples, n_features = int(5e4), int(1e4)
ridge = Ridge(alpha=1.0, solver="cholesky")

print(f"Generating data with shape {(n_samples, n_features)}...")
X_cuda = torch.randn(n_samples, n_features, device="cuda")
w = torch.randn(n_features, device="cuda")
y_cuda = X_cuda @ w + 0.1 * torch.randn(n_samples, device="cuda")
X_cpu, y_cpu = X_cuda.cpu(), y_cuda.cpu()
X_np, y_np = X_cpu.numpy(), y_cpu.numpy()
print(f"Data size: {X_np.nbytes / 1e6:.1f} MB")

tic = perf_counter()
ridge_cuda = ridge.fit(X_cuda, y_cuda)
print(ridge_cuda.coef_[:5])
toc = perf_counter()
print(f"PyTorch GPU Ridge: {toc - tic:.2f} s")

tic = perf_counter()
ridge_cpu = ridge.fit(X_cpu, y_cpu)
print(ridge_cpu.coef_[:5])
toc = perf_counter()
print(f"PyTorch CPU Ridge: {toc - tic:.2f} s")

tic = perf_counter()
ridge_np = ridge.fit(X_np, y_np)
print(ridge_np.coef_[:5])
toc = perf_counter()
print(f"NumPy Ridge: {toc - tic:.2f} s")

Output:

Generating data with shape (50000, 10000)...
Data size: 2000.0 MB
tensor([-1.5775,  0.5475, -1.4703, -0.2500,  0.9747], device='cuda:0')
PyTorch GPU Ridge: 0.89 s
tensor([-1.5775,  0.5475, -1.4703, -0.2500,  0.9747])
PyTorch CPU Ridge: 18.50 s
[-1.5774863   0.5474608  -1.4703354  -0.25000554  0.97471786]
NumPy Ridge: 12.38 s

So a typical 10 to 15x speed-up of the GPU (an NVIDIA A100 in this case) over the CPU (Intel(R) Xeon(R) Silver 4214R CPU @ 2.40GHz with 20 physical cores).

Note that as soon as one of the 2 data dimensions is below 1e3 or so, the intermediate data structure X.T @ X or X @ X.T fits in CPU cache and the difference between CPU and GPU is not as dramatic, e.g.:

Generating data with shape (1000000, 1000)...
Data size: 4000.0 MB
tensor([ 0.4550,  0.7365,  0.3112, -0.1312, -0.0632], device='cuda:0')
PyTorch GPU Ridge: 0.62 s
tensor([ 0.4550,  0.7364,  0.3112, -0.1312, -0.0632])
PyTorch CPU Ridge: 4.46 s
[ 0.4550036   0.7364387   0.31120822 -0.13120973 -0.06319333]
NumPy Ridge: 6.30 s
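
For context, here is a rough sketch of the kind of square system being solved in each regime (an illustration of the usual primal/dual formulation of ridge regression, not a verbatim excerpt of the scikit-learn kernel; the names are made up):

import numpy as np

def ridge_solve_sketch(X, y, alpha):
    # Illustrative only: the actual scikit-learn code also handles sample
    # weights, multiple targets and the intercept.
    n_samples, n_features = X.shape
    if n_samples >= n_features:
        # Primal: solve the n_features x n_features system
        # (X.T X + alpha I) w = X.T y.
        A = X.T @ X + alpha * np.eye(n_features)
        return np.linalg.solve(A, X.T @ y)
    else:
        # Dual: solve the n_samples x n_samples system (X X.T + alpha I) a = y,
        # then map back to feature space with w = X.T a.
        K = X @ X.T + alpha * np.eye(n_samples)
        return X.T @ np.linalg.solve(K, y)

Either way, the dominant costs are the matmul that builds the square matrix and the factorization inside the solve, which is what the profiling below confirms.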

Profiling results

I adapted the benchmark script to use the PyTorch profiler as follows:

import torch
from torch.profiler import profile, ProfilerActivity
from sklearn.linear_model import Ridge
from sklearn import set_config

set_config(array_api_dispatch=True)


n_samples, n_features = int(1e7), int(1e2)
ridge = Ridge(alpha=1.0, solver="cholesky")

print(f"Generating data with shape {(n_samples, n_features)}...")
X_cuda = torch.randn(n_samples, n_features, device="cuda")
w = torch.randn(n_features, device="cuda")
y_cuda = X_cuda @ w + 0.1 * torch.randn(n_samples, device="cuda")
print(f"Data size: {X_cuda.nbytes / 1e6:.1f} MB")


with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    with_stack=True,
) as prof:
    ridge_cuda = ridge.fit(X_cuda, y_cuda)


print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

Here are some results for various data shapes:

Generating data with shape (10000000, 100)...
Data size: 4000.0 MB
[W kineto_shim.cpp:362] Adding profiling metadata requires using torch.profiler with Kineto support (USE_KINETO=1)
------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                          Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  aten::matmul         0.03%      37.000us        30.95%      37.319ms      12.440ms      31.000us         0.03%      37.379ms      12.460ms             3  
                      aten::mm        30.71%      37.026ms        30.71%      37.026ms      18.513ms      37.088ms        30.56%      37.088ms      18.544ms             2  
            aten::linalg_solve         0.03%      37.000us        26.02%      31.378ms      31.378ms      23.000us         0.02%      31.383ms      31.383ms             1  
         aten::linalg_solve_ex         0.01%      17.000us        20.88%      25.175ms      25.175ms      17.000us         0.01%      25.193ms      25.193ms             1  
        aten::_linalg_solve_ex         3.71%       4.475ms        20.86%      25.158ms      25.158ms      57.000us         0.05%      25.176ms      25.176ms             1  
     aten::linalg_lu_factor_ex        15.98%      19.270ms        15.99%      19.280ms      19.280ms      23.676ms        19.51%      23.690ms      23.690ms             1  
                   aten::copy_         5.13%       6.187ms         5.13%       6.187ms     562.455us      19.314ms        15.92%      19.314ms       1.756ms            11  
                      aten::to         0.04%      43.000us         8.06%       9.715ms     883.182us      75.000us         0.06%      16.340ms       1.485ms            11  
                aten::_to_copy         0.06%      73.000us         8.02%       9.672ms       2.418ms      61.000us         0.05%      16.265ms       4.066ms             4  
                     aten::sum         4.30%       5.188ms         4.31%       5.197ms       1.299ms      10.801ms         8.90%      10.821ms       2.705ms             4  
------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 120.581ms
Self CUDA time total: 121.356ms
Generating data with shape (100000, 10000)...
Data size: 4000.0 MB
[W kineto_shim.cpp:362] Adding profiling metadata requires using torch.profiler with Kineto support (USE_KINETO=1)
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     aten::matmul         0.00%      52.000us        86.99%        1.440s     479.850ms      33.000us         0.00%        1.440s     479.859ms             3  
                         aten::mm        86.97%        1.439s        86.97%        1.439s     719.622ms        1.439s        86.93%        1.439s     719.644ms             2  
               aten::linalg_solve         0.00%      32.000us         9.88%     163.571ms     163.571ms      28.000us         0.00%     163.571ms     163.571ms             1  
            aten::linalg_solve_ex         0.00%      17.000us         9.75%     161.333ms     161.333ms      17.000us         0.00%     161.335ms     161.335ms             1  
           aten::_linalg_solve_ex         0.28%       4.586ms         9.75%     161.316ms     161.316ms      51.000us         0.00%     161.318ms     161.318ms             1  
        aten::linalg_lu_factor_ex         8.70%     144.047ms         8.71%     144.067ms     144.067ms     149.330ms         9.02%     149.948ms     149.948ms             1  
                      aten::copy_         0.14%       2.246ms         0.14%       2.246ms     187.167us      15.813ms         0.96%      15.813ms       1.318ms            12  
                         aten::to         0.00%      35.000us         0.35%       5.715ms     519.545us      63.000us         0.00%      12.262ms       1.115ms            11  
                   aten::_to_copy         0.00%      78.000us         0.34%       5.680ms       1.420ms      62.000us         0.00%      12.199ms       3.050ms             4  
            aten::linalg_lu_solve         0.28%       4.689ms         0.76%      12.640ms      12.640ms       5.931ms         0.36%      11.293ms      11.293ms             1  
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.655s
Self CUDA time total: 1.656s
Generating data with shape (10000, 100000)...
Data size: 4000.0 MB
[W kineto_shim.cpp:362] Adding profiling metadata requires using torch.profiler with Kineto support (USE_KINETO=1)
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     aten::matmul         0.00%      46.000us         0.28%       5.246ms       1.749ms      37.000us         0.00%        1.664s     554.582ms             3  
                         aten::mm         0.27%       4.967ms         0.27%       4.967ms       2.483ms        1.663s        88.70%        1.663s     831.736ms             2  
               aten::linalg_solve         0.00%      33.000us         8.70%     162.983ms     162.983ms      32.000us         0.00%     162.985ms     162.985ms             1  
            aten::linalg_solve_ex         0.00%      18.000us         8.57%     160.581ms     160.581ms      17.000us         0.00%     160.582ms     160.582ms             1  
           aten::_linalg_solve_ex         0.25%       4.691ms         8.57%     160.563ms     160.563ms      67.000us         0.00%     160.565ms     160.565ms             1  
        aten::linalg_lu_factor_ex         7.60%     142.498ms         7.61%     142.529ms     142.529ms     147.736ms         7.88%     148.362ms     148.362ms             1  
                      aten::copy_         0.13%       2.411ms         0.13%       2.411ms     185.462us      15.968ms         0.85%      15.968ms       1.228ms            13  
                         aten::to         0.00%      38.000us         0.32%       5.923ms     538.455us      65.000us         0.00%      12.466ms       1.133ms            11  
                   aten::_to_copy         0.00%      76.000us         0.31%       5.885ms       1.471ms      62.000us         0.00%      12.401ms       3.100ms             4  
            aten::linalg_lu_solve         0.25%       4.628ms         0.71%      13.323ms      13.323ms       5.872ms         0.31%      12.114ms      12.114ms             1  
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.874s
Self CUDA time total: 1.875s

Conclusions:

  • For very rectangular data, a significant fraction of the time is spent in data copies and auxiliary GPU operations. The linear system solve is only a small fraction of the total time, but it is very fast, so this is not a big problem.
  • For more square data, most (98%) of the time is spent computing the matmuls (i.e. X.T @ X or X @ X.T depending on the shape) and in the LU factorization to solve the resulting square linear system, as expected.

So all in all, I think the profiling results show that it's behaving as expected.

/cc @EdAbati (I finally gave it a try).


github-actions bot commented Jun 20, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 5bb5ac7. Link to the linter CI: here

@EdAbati (Contributor) left a comment

Nice!
Apologies, but I still haven't had time to get the code into shape for the PR. It's better this way, so I can learn from your code :)

Regarding the benchmarks, out of curiosity: are you tracking the performance before and after the array API change, and/or comparing the different array libraries?

@ogrisel (Member, Author) commented Jun 20, 2024

Regarding the benchmarks, out of curiosity: are you tracking the performance before and after the array API change, and/or comparing the different array libraries?

I just wanted to compare the same workflow on different libraries/devices.

@ogrisel (Member, Author) commented Jun 20, 2024

The CUDA failures seem unrelated to this PR. I suspect they also happen on main.

f"reshape with copy=False is not compatible with shape {shape} "
"for the memory layout of the input array."
)
return output
@ogrisel (Member, Author) left a comment

Note to reviewers: while this is not strictly needed for this PR, this is something that should be implemented as per the array API spec:

copy (Optional[bool]) – whether or not to copy the input array. If True, the function must always copy. If False, the function must never copy. If None, the function must avoid copying, if possible, and may copy otherwise. Default: None.

Being able to call reshape with copy=False would have helped me a lot to spot differences between A.flat and A_flat = xp.reshape(A, (-1,), copy=False) while debugging early test failures when developing this PR.
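
For illustration, here is a NumPy-only sketch of what the copy=False semantics would give us (a hypothetical helper, not part of this PR):

import numpy as np

def reshape_no_copy(a, shape):
    # Mimic the spec's copy=False behavior: return the reshaped array only if
    # it still shares memory with the input, otherwise raise instead of
    # silently copying.
    reshaped = np.reshape(a, shape)
    if not np.shares_memory(reshaped, a):
        raise ValueError(
            f"reshape with copy=False is not compatible with shape {shape} "
            "for the memory layout of the input array."
        )
    return reshaped

For instance, reshape_no_copy(np.ones((3, 4)), (-1,)) returns a view, while reshape_no_copy(np.ones((3, 4)).T, (-1,)) raises because flattening the transposed array requires a copy.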

@ogrisel (Member, Author) commented Jun 28, 2024

I pushed a new commit (5bb5ac7) to trigger the rank-deficiency-related exceptions. There is no standard xp.linalg.LinalgError exception to catch in the spec, so let me run the CI (both regular and CUDA) to see whether we can catch them all.
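
Since the spec defines no common linalg exception, one possible pattern (a sketch, not necessarily what this PR will end up doing) is to catch whatever the namespace's solver raises and fall back to an SVD-based solution:

def solve_with_svd_fallback(xp, A, b):
    # Sketch only: without a standard xp.linalg.LinalgError we conservatively
    # catch any exception raised by the solver and fall back to a
    # pseudo-inverse (SVD-based) solution, which also handles singular /
    # rank-deficient systems.
    try:
        return xp.linalg.solve(A, b)
    except Exception:
        return xp.linalg.pinv(A) @ b

scikit-learn's existing NumPy/SciPy code path falls back to the SVD solver when a LinAlgError is raised, which is what the fallback-test TODO item above refers to.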


device : device
`device` object (see the "Device Support" section of the array API spec).
"""
@ogrisel (Member, Author) left a comment

Note to reviewers: for some reason, the numpy docstring test started to complain about the missing params / returns in this docstring only in this PR, rather than in the PR that introduced it.

I am not sure why, but I decided to fix it here. We can split it out into a dedicated PR if needed.

@StefanieSenger (Contributor) commented Jul 29, 2024

Note: this PR uses xp.linalg.solve, which does not accept assume_a="pos" the way scipy.linalg.solve does. As a result, the decomposition used to solve the square linear system is not necessarily a Cholesky decomposition anymore ({S,D}GESV based on LU instead of {S,D}POSV based on Cholesky in LAPACK). However, this is still very fast in practice (e.g. on PyTorch), so I think it's not a problem.

In case this is somewhat useful: maybe you want to add a comment in _solve_cholesky_kernel() about the different default solver, so that if this part needs to be debugged at some later time, it helps the developer working on it then.
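
For illustration, such a comment could look something like this (hypothetical wording):

# NOTE: when array API dispatch is enabled, xp.linalg.solve does not expose
# scipy's assume_a="pos", so the square system is typically solved with an
# LU-based LAPACK routine (?GESV) rather than a Cholesky-based one (?POSV),
# even though the solver is named "cholesky".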

@OmarManzoor (Contributor) left a comment

@ogrisel I think there has been considerable work done in this PR. Is there something remaining?
