Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Add array api support for jaccard score #31204

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 9, 2025

Conversation

OmarManzoor
Copy link
Contributor

Reference Issues/PRs

Towards #26024

What does this implement/fix? Explain your changes.

  • Adds array api support for jaccard score

Any other comments?

Copy link

github-actions bot commented Apr 15, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: c403043. Link to the linter CI: here

@OmarManzoor
Copy link
Contributor Author

OmarManzoor commented Apr 15, 2025

Some benchmarks

data size = 1e7
dtype = np.int64

average Orignal flow Pytorch CPU Pytorch CUDA
micro 1.0681 2.63489 0.08644
macro 2.12834 5.26227 0.15068
weighted 3.19726 7.89072 0.21451

@OmarManzoor
Copy link
Contributor Author

CC: @ogrisel @betatim @adrinjalali for reviews.

It seems like Pytorch CPU degrades the performance.

Copy link
Member

@virchan virchan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR, @OmarManzoor!

Just one small nitpick from my side—otherwise, LGTM!

Edit: I forget we haven't checked with MPS. 😅

sklearn/metrics/_classification.py Outdated Show resolved Hide resolved
@adrinjalali
Copy link
Member

I guess we're okay that PyTorch CPU has a subpar implementation?

@OmarManzoor
Copy link
Contributor Author

I guess we're okay that PyTorch CPU has a subpar implementation?

I think so. Just to be clear I tested on a kaggle kernel which sometimes involves differing and conflicting versions of packages but I think that is the best we can get for publicly available free gpus. Since metrics do not involve too much computation using Pytorch with CPU doesn't really offer much benefit and we are better off with the original numpy implementation when we are using a CPU.

But I think let's get an opinion from @ogrisel as well.

@OmarManzoor
Copy link
Contributor Author

@ogrisel Do you think this can be merged?

@lesteve
Copy link
Member

lesteve commented May 9, 2025

Let's merge this one, thanks!

I had a closer look at the performance difference between numpy and PyTorch CPU. Looking a bit at it, this is due to at least 2 things:

  • torch.unique being ~10 times slower than numpy.unique (20ms vs 200ms on an array of 1e7 integers)
  • being able to use sparse matrices in the numpy case

Quick benchmark code

import numpy as np
import torch

from sklearn.metrics import jaccard_score
import sklearn

y_true = np.ones((10_000, 1000), dtype=np.int64)
y_pred = np.ones((10_000, 1000), dtype=np.int64)
y_true_torch = torch.asarray(y_true)
y_pred_torch = torch.asarray(y_pred)

print('numpy')
%timeit jaccard_score(y_true, y_pred, average='macro')

print('torch')
with sklearn.config_context(array_api_dispatch=True):
    %timeit jaccard_score(y_true_torch, y_pred_torch, average='macro')

print('numpy unique')
%timeit np.unique(y_true)

print('torch unique')
%timeit torch.unique(y_true_torch)

print('numpy')
%prun -s cumulative jaccard_score(y_true, y_pred, average='macro')

print('torch')
with sklearn.config_context(array_api_dispatch=True):
    %prun -s cumulative jaccard_score(y_true_torch, y_pred_torch, average='macro')

Output:

numpy
1.04 s ± 17.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
torch
1.74 s ± 117 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
numpy unique
23.5 ms ± 857 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
torch unique
209 ms ± 7.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
numpy
         7508 function calls (7507 primitive calls) in 1.026 seconds
 
   Ordered by: cumulative time
 
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    1.027    1.027 {built-in method builtins.exec}
        1    0.000    0.000    1.027    1.027 <string>:1(<module>)
      2/1    0.001    0.000    1.027    1.027 _param_validation.py:187(wrapper)
        1    0.001    0.001    1.026    1.026 _classification.py:903(jaccard_score)
        2    0.002    0.001    0.743    0.372 _classification.py:70(_check_targets)
        7    0.002    0.000    0.741    0.106 _compressed.py:29(__init__)
        4    0.173    0.043    0.560    0.140 _coo.py:30(__init__)
        1    0.000    0.000    0.559    0.559 _classification.py:547(multilabel_confusion_matrix)
        1    0.000    0.000    0.466    0.466 _classification.py:1729(_check_set_wise_labels)
        4    0.261    0.065    0.261    0.065 {method 'nonzero' of 'numpy.ndarray' objects}
        8    0.003    0.000    0.187    0.023 _arraysetops_impl.py:145(unique)
        8    0.069    0.009    0.183    0.023 _arraysetops_impl.py:339(_unique1d)
        4    0.000    0.000    0.180    0.045 _coo.py:381(_coo_to_compressed)
        4    0.179    0.045    0.179    0.045 {built-in method scipy.sparse._sparsetools.coo_tocsr}
       12    0.000    0.000    0.099    0.008 _coo.py:93(<genexpr>)
       16    0.099    0.006    0.099    0.006 {method 'astype' of 'numpy.ndarray' objects}
        2    0.001    0.000    0.095    0.048 multiclass.py:41(unique_labels)
        4    0.000    0.000    0.094    0.023 _unique.py:28(attach_unique)
       12    0.000    0.000    0.094    0.008 _unique.py:56(<genexpr>)
        8    0.000    0.000    0.094    0.012 _unique.py:9(_attach_unique)
        8    0.000    0.000    0.094    0.012 multiclass.py:228(type_of_target)
        8    0.000    0.000    0.094    0.012 multiclass.py:129(is_multilabel)
        6    0.000    0.000    0.093    0.016 multiclass.py:79(<genexpr>)
        9    0.000    0.000    0.093    0.010 _internal.py:26(wrapped_f)
        4    0.000    0.000    0.093    0.023 _aliases.py:226(unique_values)
        8    0.082    0.010    0.082    0.010 {method 'flatten' of 'numpy.ndarray' objects}
        3    0.000    0.000    0.063    0.021 _array_api.py:951(_count_nonzero)
        3    0.063    0.021    0.063    0.021 sparsefuncs.py:602(count_nonzero)
        8    0.033    0.004    0.033    0.004 {method 'sort' of 'numpy.ndarray' objects}
        1    0.000    0.000    0.028    0.028 _compressed.py:391(multiply)
        1    0.000    0.000    0.028    0.028 _compressed.py:1355(_binopt)
        1    0.028    0.028    0.028    0.028 {built-in method scipy.sparse._sparsetools.csr_elmul_csr}
        4    0.000    0.000    0.027    0.007 _coo.py:202(_check)
       19    0.026    0.001    0.026    0.001 {method 'reduce' of 'numpy.ufunc' objects}
        8    0.000    0.000    0.013    0.002 {method 'min' of 'numpy.ndarray' objects}
        8    0.000    0.000    0.013    0.002 _methods.py:46(_amin)
        8    0.000    0.000    0.013    0.002 {method 'max' of 'numpy.ndarray' objects}
        8    0.000    0.000    0.013    0.002 _methods.py:42(_amax)
       12    0.000    0.000    0.001    0.000 validation.py:734(check_array)
        5    0.000    0.000    0.001    0.000 inspect.py:3347(signature)
        5    0.000    0.000    0.001    0.000 inspect.py:3068(from_callable)
        5    0.000    0.000    0.000    0.000 inspect.py:2479(_signature_from_callable)
       14    0.000    0.000    0.000    0.000 _base.py:1369(_get_index_dtype)
     2002    0.000    0.000    0.000    0.000 multiclass.py:116(<genexpr>)
       14    0.000    0.000    0.000    0.000 _sputils.py:264(get_index_dtype)
        5    0.000    0.000    0.000    0.000 inspect.py:2375(_signature_from_function)
        4    0.000    0.000    0.000    0.000 _aliases.py:169(_unique_kwargs)
     2415    0.000    0.000    0.000    0.000 {built-in method builtins.isinstance}
        6    0.000    0.000    0.000    0.000 multiclass.py:92(<genexpr>)
        6    0.000    0.000    0.000    0.000 multiclass.py:113(<genexpr>)
        4    0.000    0.000    0.000    0.000 multiclass.py:27(_unique_indicator)
       15    0.000    0.000    0.000    0.000 {built-in method numpy.empty}
       20    0.000    0.000    0.000    0.000 _aliases.py:78(asarray)
        8    0.000    0.000    0.000    0.000 validation.py:534(_ensure_sparse_format)
        7    0.000    0.000    0.000    0.000 _compressed.py:165(check_format)
       14    0.000    0.000    0.000    0.000 validation.py:381(_num_samples)
        1    0.000    0.000    0.000    0.000 _param_validation.py:28(validate_parameter_constraints)
       27    0.000    0.000    0.000    0.000 {built-in method numpy.array}
       56    0.000    0.000    0.000    0.000 _array_api.py:385(get_namespace)
        8    0.000    0.000    0.000    0.000 validation.py:90(_assert_all_finite)
       67    0.000    0.000    0.000    0.000 _config.py:35(get_config)
 torch
         7569 function calls (7526 primitive calls) in 1.787 seconds
 
   Ordered by: cumulative time
 
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    1.787    1.787 {built-in method builtins.exec}
        1    0.000    0.000    1.787    1.787 <string>:1(<module>)
      2/1    0.001    0.000    1.787    1.787 _param_validation.py:187(wrapper)
        1    0.000    0.000    1.786    1.786 _classification.py:903(jaccard_score)
        8    0.000    0.000    1.644    0.206 multiclass.py:228(type_of_target)
        8    0.000    0.000    1.643    0.205 multiclass.py:129(is_multilabel)
       10    0.000    0.000    1.642    0.164 _aliases.py:725(unique_values)
    20/10    0.000    0.000    1.642    0.164 _jit_internal.py:614(fn)
       10    0.000    0.000    1.641    0.164 functional.py:1068(_return_output)
       10    0.000    0.000    1.641    0.164 functional.py:813(_unique_impl)
       10    1.641    0.164    1.641    0.164 {built-in method torch._unique2}
        8    0.000    0.000    1.641    0.205 _unique.py:81(cached_unique)
       16    0.000    0.000    1.641    0.103 _unique.py:105(<genexpr>)
        8    0.000    0.000    1.641    0.205 _unique.py:62(_cached_unique)
        1    0.002    0.002    0.997    0.997 _classification.py:547(multilabel_confusion_matrix)
        2    0.000    0.000    0.826    0.413 _classification.py:70(_check_targets)
        2    0.000    0.000    0.822    0.411 multiclass.py:41(unique_labels)
        6    0.000    0.000    0.819    0.136 multiclass.py:79(<genexpr>)
        1    0.000    0.000    0.787    0.787 _classification.py:1729(_check_set_wise_labels)
        3    0.020    0.007    0.114    0.038 _array_api.py:951(_count_nonzero)
        3    0.000    0.000    0.047    0.016 _aliases.py:561(where)
        3    0.047    0.016    0.047    0.016 {built-in method torch.where}
        3    0.033    0.011    0.033    0.011 {built-in method torch.ones_like}
        2    0.000    0.000    0.018    0.009 _aliases.py:94(_f)
        1    0.018    0.018    0.018    0.018 {built-in method torch.multiply}
        3    0.000    0.000    0.013    0.004 _aliases.py:319(sum)
        3    0.013    0.004    0.013    0.004 {built-in method torch.sum}
       80    0.000    0.000    0.005    0.000 _array_api.py:385(get_namespace)
       50    0.000    0.000    0.003    0.000 _array_api.py:125(_check_array_api_dispatch)
       16    0.000    0.000    0.003    0.000 validation.py:734(check_array)
      100    0.000    0.000    0.003    0.000 version.py:65(parse)
      100    0.001    0.000    0.003    0.000 version.py:292(__init__)
        1    0.001    0.001    0.002    0.002 _array_api.py:999(_tolist)
        6    0.000    0.000    0.001    0.000 multiclass.py:92(<genexpr>)
        4    0.000    0.000    0.001    0.000 multiclass.py:27(_unique_indicator)
     1000    0.001    0.000    0.001    0.000 {method 'item' of 'numpy.generic' objects}
       49    0.000    0.000    0.001    0.000 _helpers.py:442(array_namespace)
        8    0.000    0.000    0.001    0.000 validation.py:90(_assert_all_finite)
       39    0.001    0.000    0.001    0.000 {built-in method torch.asarray}
      100    0.000    0.000    0.000    0.000 version.py:475(_cmpkey)
        2    0.000    0.000    0.000    0.000 _array_api.py:471(get_namespace_and_device)
      100    0.000    0.000    0.000    0.000 {method 'search' of 're.Pattern' objects}
      749    0.000    0.000    0.000    0.000 {built-in method builtins.isinstance}
       96    0.000    0.000    0.000    0.000 _config.py:35(get_config)
        1    0.000    0.000    0.000    0.000 _classification.py:1674(_prf_divide)
       50    0.000    0.000    0.000    0.000 <frozen _collections_abc>:808(get)
       14    0.000    0.000    0.000    0.000 validation.py:381(_num_samples)
        1    0.000    0.000    0.000    0.000 _param_validation.py:28(validate_parameter_constraints)
     1000    0.000    0.000    0.000    0.000 {method 'group' of 're.Match' objects}
       50    0.000    0.000    0.000    0.000 <frozen os>:711(__getitem__)
       54    0.000    0.000    0.000    0.000 _array_api.py:349(_remove_non_arrays)
        1    0.000    0.000    0.000    0.000 inspect.py:3347(signature)
    52/20    0.000    0.000    0.000    0.000 _aliases.py:748(isdtype)
        1    0.000    0.000    0.000    0.000 inspect.py:3068(from_callable)
        5    0.000    0.000    0.000    0.000 _aliases.py:108(_fix_promotion)
        1    0.000    0.000    0.000    0.000 inspect.py:2479(_signature_from_callable)
       16    0.000    0.000    0.000    0.000 _array_api.py:734(_asarray_with_order)
        1    0.000    0.000    0.000    0.000 _array_api.py:604(_average)
      380    0.000    0.000    0.000    0.000 {built-in method builtins.hasattr}
       24    0.000    0.000    0.000    0.000 warnings.py:170(simplefilter)
      400    0.000    0.000    0.000    0.000 version.py:302(<genexpr>)

@lesteve lesteve merged commit ffcd361 into scikit-learn:main May 9, 2025
36 checks passed
@OmarManzoor OmarManzoor deleted the array-api-jaccard branch May 9, 2025 09:37
@OmarManzoor
Copy link
Contributor Author

@lesteve Thanks for the benchmarks.

@lesteve
Copy link
Member

lesteve commented May 9, 2025

Maybe another part of the performance difference: the magic of computing unique values once and saving the result in the dtype metadata is numpy-specific. If arr is a torch array, arr.dtype has no attribute metadata.

Having said that I guess computing metrics is unlikely to be the bottleneck so I would say this is low priority for now.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants
Morty Proxy This is a proxified and sanitized view of the page, visit original site.