Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 1813b4a

Browse filesBrowse files
authored
array API support for cosine_distances (#29265)
1 parent cc97b80 commit 1813b4a
Copy full SHA for 1813b4a

File tree

Expand file treeCollapse file tree

5 files changed

+22
-2
lines changed
Filter options
Expand file treeCollapse file tree

5 files changed

+22
-2
lines changed

‎doc/modules/array_api.rst

Copy file name to clipboardExpand all lines: doc/modules/array_api.rst
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ Metrics
123123
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel`
124124
- :func:`sklearn.metrics.pairwise.chi2_kernel`
125125
- :func:`sklearn.metrics.pairwise.cosine_similarity`
126+
- :func:`sklearn.metrics.pairwise.cosine_distances`
126127
- :func:`sklearn.metrics.pairwise.euclidean_distances` (see :ref:`device_support_for_float64`)
127128
- :func:`sklearn.metrics.pairwise.paired_cosine_distances`
128129
- :func:`sklearn.metrics.pairwise.rbf_kernel` (see :ref:`device_support_for_float64`)

‎doc/whats_new/v1.6.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.6.rst
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ See :ref:`array_api` for more details.
4343
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel` :pr:`29144` by :user:`Yaroslav Korobko <Tialo>`;
4444
- :func:`sklearn.metrics.pairwise.chi2_kernel` :pr:`29267` by :user:`Yaroslav Korobko <Tialo>`;
4545
- :func:`sklearn.metrics.pairwise.cosine_similarity` :pr:`29014` by :user:`Edoardo Abati <EdAbati>`;
46+
- :func:`sklearn.metrics.pairwise.cosine_distances` :pr:`29265` by :user:`Emily Chen <EmilyXinyi>`;
4647
- :func:`sklearn.metrics.pairwise.euclidean_distances` :pr:`29433` by :user:`Omar Salman <OmarManzoor>`;
4748
- :func:`sklearn.metrics.pairwise.paired_cosine_distances` :pr:`29112` by :user:`Edoardo Abati <EdAbati>`;
4849
- :func:`sklearn.metrics.pairwise.rbf_kernel` :pr:`29433` by :user:`Omar Salman <OmarManzoor>`.

‎sklearn/metrics/pairwise.py

Copy file name to clipboardExpand all lines: sklearn/metrics/pairwise.py
+5-2Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
gen_even_slices,
2323
)
2424
from ..utils._array_api import (
25+
_clip,
2526
_fill_or_add_to_diagonal,
2627
_find_matching_floating_dtype,
2728
_is_numpy_namespace,
@@ -1139,15 +1140,17 @@ def cosine_distances(X, Y=None):
11391140
array([[1. , 1. ],
11401141
[0.42..., 0.18...]])
11411142
"""
1143+
xp, _ = get_namespace(X, Y)
1144+
11421145
# 1.0 - cosine_similarity(X, Y) without copy
11431146
S = cosine_similarity(X, Y)
11441147
S *= -1
11451148
S += 1
1146-
np.clip(S, 0, 2, out=S)
1149+
S = _clip(S, 0, 2, xp)
11471150
if X is Y or Y is None:
11481151
# Ensure that distances between vectors and themselves are set to 0.0.
11491152
# This may not be the case due to floating point rounding errors.
1150-
np.fill_diagonal(S, 0.0)
1153+
_fill_or_add_to_diagonal(S, 0.0, xp, add_value=False)
11511154
return S
11521155

11531156

‎sklearn/metrics/tests/test_common.py

Copy file name to clipboardExpand all lines: sklearn/metrics/tests/test_common.py
+2Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from sklearn.metrics.pairwise import (
5555
additive_chi2_kernel,
5656
chi2_kernel,
57+
cosine_distances,
5758
cosine_similarity,
5859
euclidean_distances,
5960
paired_cosine_distances,
@@ -2016,6 +2017,7 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
20162017
mean_gamma_deviance: [check_array_api_regression_metric],
20172018
max_error: [check_array_api_regression_metric],
20182019
chi2_kernel: [check_array_api_metric_pairwise],
2020+
cosine_distances: [check_array_api_metric_pairwise],
20192021
euclidean_distances: [check_array_api_metric_pairwise],
20202022
rbf_kernel: [check_array_api_metric_pairwise],
20212023
}

‎sklearn/utils/_array_api.py

Copy file name to clipboardExpand all lines: sklearn/utils/_array_api.py
+13Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,19 @@ def _nanmax(X, axis=None, xp=None):
791791
return X
792792

793793

794+
def _clip(S, min_val, max_val, xp):
795+
# TODO: remove this method and change all usage once we move to array api 2023.12
796+
# https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.clip.html#clip
797+
if _is_numpy_namespace(xp):
798+
return numpy.clip(S, min_val, max_val)
799+
else:
800+
min_arr = xp.asarray(min_val, dtype=S.dtype)
801+
max_arr = xp.asarray(max_val, dtype=S.dtype)
802+
S = xp.where(S < min_arr, min_arr, S)
803+
S = xp.where(S > max_arr, max_arr, S)
804+
return S
805+
806+
794807
def _asarray_with_order(
795808
array, dtype=None, order=None, copy=None, *, xp=None, device=None
796809
):

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.