Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit ffcd361

Browse filesBrowse files
authored
FEA Add array api support for jaccard score (scikit-learn#31204)
1 parent 27f2af3 commit ffcd361
Copy full SHA for ffcd361

File tree

4 files changed

+13
-4
lines changed
Filter options

4 files changed

+13
-4
lines changed

‎doc/modules/array_api.rst

Copy file name to clipboardExpand all lines: doc/modules/array_api.rst
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ Metrics
139139
- :func:`sklearn.metrics.f1_score`
140140
- :func:`sklearn.metrics.fbeta_score`
141141
- :func:`sklearn.metrics.hamming_loss`
142+
- :func:`sklearn.metrics.jaccard_score`
142143
- :func:`sklearn.metrics.max_error`
143144
- :func:`sklearn.metrics.mean_absolute_error`
144145
- :func:`sklearn.metrics.mean_absolute_percentage_error`
+2Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
- :func:`sklearn.metrics.jaccard_score` now supports Array API compatible inputs.
2+
By :user:`Omar Salman <OmarManzoor>`

‎sklearn/metrics/_classification.py

Copy file name to clipboardExpand all lines: sklearn/metrics/_classification.py
+5-4Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,9 +1071,10 @@ def jaccard_score(
10711071
numerator = MCM[:, 1, 1]
10721072
denominator = MCM[:, 1, 1] + MCM[:, 0, 1] + MCM[:, 1, 0]
10731073

1074+
xp, _, device_ = get_namespace_and_device(y_true, y_pred)
10741075
if average == "micro":
1075-
numerator = np.array([numerator.sum()])
1076-
denominator = np.array([denominator.sum()])
1076+
numerator = xp.asarray(xp.sum(numerator, keepdims=True), device=device_)
1077+
denominator = xp.asarray(xp.sum(denominator, keepdims=True), device=device_)
10771078

10781079
jaccard = _prf_divide(
10791080
numerator,
@@ -1088,14 +1089,14 @@ def jaccard_score(
10881089
return jaccard
10891090
if average == "weighted":
10901091
weights = MCM[:, 1, 0] + MCM[:, 1, 1]
1091-
if not np.any(weights):
1092+
if not xp.any(weights):
10921093
# numerator is 0, and warning should have already been issued
10931094
weights = None
10941095
elif average == "samples" and sample_weight is not None:
10951096
weights = sample_weight
10961097
else:
10971098
weights = None
1098-
return float(np.average(jaccard, weights=weights))
1099+
return float(_average(jaccard, weights=weights, xp=xp))
10991100

11001101

11011102
@validate_params(

‎sklearn/metrics/tests/test_common.py

Copy file name to clipboardExpand all lines: sklearn/metrics/tests/test_common.py
+5Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2147,6 +2147,11 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
21472147
check_array_api_multiclass_classification_metric,
21482148
check_array_api_multilabel_classification_metric,
21492149
],
2150+
jaccard_score: [
2151+
check_array_api_binary_classification_metric,
2152+
check_array_api_multiclass_classification_metric,
2153+
check_array_api_multilabel_classification_metric,
2154+
],
21502155
multilabel_confusion_matrix: [
21512156
check_array_api_binary_classification_metric,
21522157
check_array_api_multiclass_classification_metric,

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.