File tree 2 files changed +33
-2
lines changed
Filter options
_pairwise_distances_reduction 2 files changed +33
-2
lines changed
Original file line number Diff line number Diff line change @@ -110,7 +110,7 @@ def is_valid_sparse_matrix(X):
110
110
X .indices .dtype == X .indptr .dtype == np .int32
111
111
)
112
112
113
- return (
113
+ is_usable = (
114
114
get_config ().get ("enable_cython_pairwise_dist" , True )
115
115
and (is_numpy_c_ordered (X ) or is_valid_sparse_matrix (X ))
116
116
and (is_numpy_c_ordered (Y ) or is_valid_sparse_matrix (Y ))
@@ -119,6 +119,24 @@ def is_valid_sparse_matrix(X):
119
119
and metric in cls .valid_metrics ()
120
120
)
121
121
122
+ # The other joblib-based back-end might be more efficient on fused sparse-dense
123
+ # datasets' pairs on metric="(sq)euclidean" for some configurations because it
124
+ # uses the Squared Euclidean matrix decomposition, i.e.:
125
+ #
126
+ # ||X_c_i - Y_c_j||² = ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||²
127
+ #
128
+ # calling efficient sparse-dense routines for matrix and vectors multiplication
129
+ # implemented in SciPy we do not use yet here.
130
+ # See: https://github.com/scikit-learn/scikit-learn/pull/23585#issuecomment-1247996669 # noqa
131
+ # TODO: implement specialisation for (sq)euclidean on fused sparse-dense
132
+ # using sparse-dense routines for matrix-vector multiplications.
133
+ fused_sparse_dense_euclidean_case_guard = not (
134
+ (is_valid_sparse_matrix (X ) or is_valid_sparse_matrix (Y ))
135
+ and "euclidean" in metric
136
+ )
137
+
138
+ return is_usable and fused_sparse_dense_euclidean_case_guard
139
+
122
140
@classmethod
123
141
@abstractmethod
124
142
def compute (
Original file line number Diff line number Diff line change @@ -518,7 +518,7 @@ def test_pairwise_distances_reduction_is_usable_for():
518
518
Y = rng .rand (100 , 10 )
519
519
X_csr = csr_matrix (X )
520
520
Y_csr = csr_matrix (Y )
521
- metric = "euclidean "
521
+ metric = "manhattan "
522
522
523
523
# Must be usable for all possible pair of {dense, sparse} datasets
524
524
assert BaseDistanceReductionDispatcher .is_usable_for (X , Y , metric )
@@ -551,6 +551,19 @@ def test_pairwise_distances_reduction_is_usable_for():
551
551
np .asfortranarray (X ), Y , metric
552
552
)
553
553
554
+ # We prefer not to use those implementations for fused sparse-dense when
555
+ # metric="(sq)euclidean" because it's not yet the most efficient one on
556
+ # all configurations of datasets.
557
+ # See: https://github.com/scikit-learn/scikit-learn/pull/23585#issuecomment-1247996669 # noqa
558
+ # TODO: implement specialisation for (sq)euclidean on fused sparse-dense
559
+ # using sparse-dense routines for matrix-vector multiplications.
560
+ assert not BaseDistanceReductionDispatcher .is_usable_for (
561
+ X_csr , Y , metric = "euclidean"
562
+ )
563
+ assert not BaseDistanceReductionDispatcher .is_usable_for (
564
+ X_csr , Y_csr , metric = "sqeuclidean"
565
+ )
566
+
554
567
# CSR matrices without non-zeros elements aren't currently supported
555
568
# TODO: support CSR matrices without non-zeros elements
556
569
X_csr_0_nnz = csr_matrix (X * 0 )
You can’t perform that action at this time.
0 commit comments