@@ -215,9 +215,20 @@ cdef class PairwiseDistancesReduction:
215
215
not issparse(Y ) and Y.dtype == np.float64 and
216
216
metric in cls.valid_metrics())
217
217
218
+ # About __cinit__ and __init__ signatures:
219
+ #
220
+ # - __cinit__ is responsible for C-level allocations and initializations
221
+ # - __init__ is responsible for PyObject initialization
222
+ # - for a given class, __cinit__ and __init__ must have matching signatures
223
+ # (up to *args and **kwargs)
224
+ # - for a given class hierarchy, all __cinit__ methods must have matching signatures
225
+ # (up to *args and **kwargs)
226
+ #
227
+ # See: https://cython.readthedocs.io/en/latest/src/userguide/special_methods.html#initialisation-methods-cinit-and-init #noqa
218
228
def __cinit__(
219
229
self ,
220
- DatasetsPair datasets_pair ,
230
+ n_samples_X ,
231
+ n_samples_Y ,
221
232
chunk_size = None ,
222
233
n_threads = None ,
223
234
strategy = None ,
@@ -234,9 +245,7 @@ cdef class PairwiseDistancesReduction:
234
245
235
246
self .effective_n_threads = _openmp_effective_n_threads(n_threads)
236
247
237
- self .datasets_pair = datasets_pair
238
-
239
- self .n_samples_X = datasets_pair.n_samples_X()
248
+ self .n_samples_X = n_samples_X
240
249
self .X_n_samples_chunk = min (self .n_samples_X, self .chunk_size)
241
250
X_n_full_chunks = self .n_samples_X // self .X_n_samples_chunk
242
251
X_n_samples_remainder = self .n_samples_X % self .X_n_samples_chunk
@@ -247,7 +256,7 @@ cdef class PairwiseDistancesReduction:
247
256
else :
248
257
self .X_n_samples_last_chunk = self .X_n_samples_chunk
249
258
250
- self .n_samples_Y = datasets_pair. n_samples_Y()
259
+ self .n_samples_Y = n_samples_Y
251
260
self .Y_n_samples_chunk = min (self .n_samples_Y, self .chunk_size)
252
261
Y_n_full_chunks = self .n_samples_Y // self .Y_n_samples_chunk
253
262
Y_n_samples_remainder = self .n_samples_Y % self .Y_n_samples_chunk
@@ -281,6 +290,17 @@ cdef class PairwiseDistancesReduction:
281
290
self .effective_n_threads,
282
291
)
283
292
293
+ def __init__ (
294
+ self ,
295
+ n_samples_X ,
296
+ n_samples_Y ,
297
+ chunk_size = None ,
298
+ n_threads = None ,
299
+ strategy = None ,
300
+ DatasetsPair datasets_pair = None ,
301
+ ):
302
+ self .datasets_pair = datasets_pair
303
+
284
304
@final
285
305
cdef void _parallel_on_X(self ) nogil:
286
306
""" Compute the pairwise distances of each row vector of X on Y
@@ -647,12 +667,30 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction):
647
667
# for various back-end and/or hardware and/or datatypes, and/or fused
648
668
# {sparse, dense}-datasetspair etc.
649
669
650
- pda = PairwiseDistancesArgKmin(
651
- datasets_pair = DatasetsPair.get_for(X, Y, metric, metric_kwargs),
652
- k = k,
653
- chunk_size = chunk_size,
654
- strategy = strategy,
655
- )
670
+ if metric in (" euclidean" , " sqeuclidean" ) and not issparse(X) and not issparse(Y):
671
+ use_squared_distances = metric == " sqeuclidean"
672
+ pda = FastEuclideanPairwiseDistancesArgKmin(
673
+ n_samples_X = X.shape[0 ],
674
+ n_samples_Y = Y.shape[0 ],
675
+ chunk_size = chunk_size,
676
+ n_threads = n_threads,
677
+ strategy = strategy,
678
+ X = X,
679
+ Y = Y,
680
+ k = k,
681
+ use_squared_distances = use_squared_distances,
682
+ metric_kwargs = metric_kwargs,
683
+ )
684
+ else :
685
+ pda = PairwiseDistancesArgKmin(
686
+ n_samples_X = X.shape[0 ],
687
+ n_samples_Y = Y.shape[0 ],
688
+ chunk_size = chunk_size,
689
+ n_threads = n_threads,
690
+ strategy = strategy,
691
+ k = k,
692
+ datasets_pair = DatasetsPair.get_for(X, Y, metric, metric_kwargs),
693
+ )
656
694
657
695
# Limit the number of threads in second level of nested parallelism for BLAS
658
696
# to avoid threads over-subscription (in GEMM for instance).
@@ -664,16 +702,16 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction):
664
702
665
703
return pda._finalize_results(return_distance)
666
704
667
- def __init__ (
705
+ def __cinit__ (
668
706
self ,
669
- DatasetsPair datasets_pair ,
707
+ n_samples_X ,
708
+ n_samples_Y ,
670
709
chunk_size = None ,
671
710
n_threads = None ,
672
711
strategy = None ,
673
- ITYPE_t k = 1 ,
674
- ):
675
- self .k = check_scalar(k, " k" , Integral, min_val = 1 )
676
-
712
+ *args ,
713
+ **kwargs ,
714
+ ):
677
715
# Allocating pointers to datastructures but not the datastructures themselves.
678
716
# There are as many pointers as effective threads.
679
717
#
@@ -690,6 +728,26 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction):
690
728
sizeof(ITYPE_t * ) * self .chunks_n_threads
691
729
)
692
730
731
+ def __init__ (
732
+ self ,
733
+ n_samples_X ,
734
+ n_samples_Y ,
735
+ chunk_size = None ,
736
+ n_threads = None ,
737
+ strategy = None ,
738
+ DatasetsPair datasets_pair = None ,
739
+ ITYPE_t k = 1 ,
740
+ ):
741
+ super ().__init__(
742
+ n_samples_X = n_samples_X,
743
+ n_samples_Y = n_samples_Y,
744
+ chunk_size = chunk_size,
745
+ n_threads = n_threads,
746
+ strategy = strategy,
747
+ datasets_pair = datasets_pair,
748
+ )
749
+ self .k = check_scalar(k, " k" , Integral, min_val = 1 )
750
+
693
751
# Main heaps which will be returned as results by `PairwiseDistancesArgKmin.compute`.
694
752
self .argkmin_indices = np.full((self .n_samples_X, self .k), 0 , dtype = ITYPE)
695
753
self .argkmin_distances = np.full((self .n_samples_X, self .k), DBL_MAX, dtype = DTYPE)
@@ -900,14 +958,32 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin):
900
958
return (PairwiseDistancesArgKmin.is_usable_for(X , Y , metric ) and
901
959
not _in_unstable_openblas_configuration())
902
960
961
+ def __cinit__(
962
+ self ,
963
+ n_samples_X ,
964
+ n_samples_Y ,
965
+ chunk_size = None ,
966
+ n_threads = None ,
967
+ strategy = None ,
968
+ *args ,
969
+ **kwargs ,
970
+ ):
971
+ # Temporary datastructures used in threads
972
+ self .dist_middle_terms_chunks = < DTYPE_t ** > malloc(
973
+ sizeof(DTYPE_t * ) * self .chunks_n_threads
974
+ )
975
+
903
976
def __init__ (
904
977
self ,
905
- X ,
906
- Y ,
907
- ITYPE_t k ,
908
- bint use_squared_distances = False ,
978
+ n_samples_X ,
979
+ n_samples_Y ,
909
980
chunk_size = None ,
981
+ n_threads = None ,
910
982
strategy = None ,
983
+ X = None ,
984
+ Y = None ,
985
+ ITYPE_t k = 1 ,
986
+ bint use_squared_distances = False ,
911
987
metric_kwargs = None ,
912
988
):
913
989
if metric_kwargs is not None and len (metric_kwargs) > 0 :
@@ -919,10 +995,14 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin):
919
995
)
920
996
921
997
super ().__init__(
998
+ n_samples_X = n_samples_X,
999
+ n_samples_Y = n_samples_Y,
1000
+ chunk_size = chunk_size,
1001
+ n_threads = n_threads,
1002
+ strategy = strategy,
922
1003
# The datasets pair here is used for exact distances computations
923
1004
datasets_pair = DatasetsPair.get_for(X, Y, metric = " euclidean" ),
924
1005
k = k,
925
- chunk_size = chunk_size,
926
1006
)
927
1007
# X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair
928
1008
cdef:
@@ -941,11 +1021,6 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin):
941
1021
)
942
1022
self .use_squared_distances = use_squared_distances
943
1023
944
- # Temporary datastructures used in threads
945
- self .dist_middle_terms_chunks = < DTYPE_t ** > malloc(
946
- sizeof(DTYPE_t * ) * self .chunks_n_threads
947
- )
948
-
949
1024
def __dealloc__ (self ):
950
1025
if self .dist_middle_terms_chunks is not NULL :
951
1026
free(self .dist_middle_terms_chunks)
0 commit comments