Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 03b6c79

Browse filesBrowse files
jjerphanthomasjpfanogrisel
committed
MAINT Reorder initializations to move allocations in __cinit__
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 2f61278 commit 03b6c79
Copy full SHA for 03b6c79

File tree

Expand file treeCollapse file tree

1 file changed

+102
-27
lines changed
Filter options
Expand file treeCollapse file tree

1 file changed

+102
-27
lines changed

‎sklearn/metrics/_pairwise_distances_reduction.pyx

Copy file name to clipboardExpand all lines: sklearn/metrics/_pairwise_distances_reduction.pyx
+102-27Lines changed: 102 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -215,9 +215,20 @@ cdef class PairwiseDistancesReduction:
215215
not issparse(Y) and Y.dtype == np.float64 and
216216
metric in cls.valid_metrics())
217217

218+
# About __cinit__ and __init__ signatures:
219+
#
220+
# - __cinit__ is responsible for C-level allocations and initializations
221+
# - __init__ is responsible for PyObject initialization
222+
# - for a given class, __cinit__ and __init__ must have matching signatures
223+
# (up to *args and **kwargs)
224+
# - for a given class hierarchy, __cinit__'s must have matching signatures
225+
# (up to *args and **kwargs)
226+
#
227+
# See: https://cython.readthedocs.io/en/latest/src/userguide/special_methods.html#initialisation-methods-cinit-and-init #noqa
218228
def __cinit__(
219229
self,
220-
DatasetsPair datasets_pair,
230+
n_samples_X,
231+
n_samples_Y,
221232
chunk_size=None,
222233
n_threads=None,
223234
strategy=None,
@@ -234,9 +245,7 @@ cdef class PairwiseDistancesReduction:
234245

235246
self.effective_n_threads = _openmp_effective_n_threads(n_threads)
236247

237-
self.datasets_pair = datasets_pair
238-
239-
self.n_samples_X = datasets_pair.n_samples_X()
248+
self.n_samples_X = n_samples_X
240249
self.X_n_samples_chunk = min(self.n_samples_X, self.chunk_size)
241250
X_n_full_chunks = self.n_samples_X // self.X_n_samples_chunk
242251
X_n_samples_remainder = self.n_samples_X % self.X_n_samples_chunk
@@ -247,7 +256,7 @@ cdef class PairwiseDistancesReduction:
247256
else:
248257
self.X_n_samples_last_chunk = self.X_n_samples_chunk
249258

250-
self.n_samples_Y = datasets_pair.n_samples_Y()
259+
self.n_samples_Y = n_samples_Y
251260
self.Y_n_samples_chunk = min(self.n_samples_Y, self.chunk_size)
252261
Y_n_full_chunks = self.n_samples_Y // self.Y_n_samples_chunk
253262
Y_n_samples_remainder = self.n_samples_Y % self.Y_n_samples_chunk
@@ -281,6 +290,17 @@ cdef class PairwiseDistancesReduction:
281290
self.effective_n_threads,
282291
)
283292

293+
def __init__(
294+
self,
295+
n_samples_X,
296+
n_samples_Y,
297+
chunk_size=None,
298+
n_threads=None,
299+
strategy=None,
300+
DatasetsPair datasets_pair=None,
301+
):
302+
self.datasets_pair = datasets_pair
303+
284304
@final
285305
cdef void _parallel_on_X(self) nogil:
286306
"""Compute the pairwise distances of each row vector of X on Y
@@ -647,12 +667,30 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction):
647667
# for various back-end and/or hardware and/or datatypes, and/or fused
648668
# {sparse, dense}-datasetspair etc.
649669

650-
pda = PairwiseDistancesArgKmin(
651-
datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs),
652-
k=k,
653-
chunk_size=chunk_size,
654-
strategy=strategy,
655-
)
670+
if metric in ("euclidean", "sqeuclidean") and not issparse(X) and not issparse(Y):
671+
use_squared_distances = metric == "sqeuclidean"
672+
pda = FastEuclideanPairwiseDistancesArgKmin(
673+
n_samples_X=X.shape[0],
674+
n_samples_Y=Y.shape[0],
675+
chunk_size=chunk_size,
676+
n_threads=n_threads,
677+
strategy=strategy,
678+
X=X,
679+
Y=Y,
680+
k=k,
681+
use_squared_distances=use_squared_distances,
682+
metric_kwargs=metric_kwargs,
683+
)
684+
else:
685+
pda = PairwiseDistancesArgKmin(
686+
n_samples_X=X.shape[0],
687+
n_samples_Y=Y.shape[0],
688+
chunk_size=chunk_size,
689+
n_threads=n_threads,
690+
strategy=strategy,
691+
k=k,
692+
datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs),
693+
)
656694

657695
# Limit the number of threads in second level of nested parallelism for BLAS
658696
# to avoid threads over-subscription (in GEMM for instance).
@@ -664,16 +702,16 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction):
664702

665703
return pda._finalize_results(return_distance)
666704

667-
def __init__(
705+
def __cinit__(
668706
self,
669-
DatasetsPair datasets_pair,
707+
n_samples_X,
708+
n_samples_Y,
670709
chunk_size=None,
671710
n_threads=None,
672711
strategy=None,
673-
ITYPE_t k=1,
674-
):
675-
self.k = check_scalar(k, "k", Integral, min_val=1)
676-
712+
*args,
713+
**kwargs,
714+
):
677715
# Allocating pointers to datastructures but not the datastructures themselves.
678716
# There are as many pointers as effective threads.
679717
#
@@ -690,6 +728,26 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction):
690728
sizeof(ITYPE_t *) * self.chunks_n_threads
691729
)
692730

731+
def __init__(
732+
self,
733+
n_samples_X,
734+
n_samples_Y,
735+
chunk_size=None,
736+
n_threads=None,
737+
strategy=None,
738+
DatasetsPair datasets_pair=None,
739+
ITYPE_t k=1,
740+
):
741+
super().__init__(
742+
n_samples_X=n_samples_X,
743+
n_samples_Y=n_samples_Y,
744+
chunk_size=chunk_size,
745+
n_threads=n_threads,
746+
strategy=strategy,
747+
datasets_pair=datasets_pair,
748+
)
749+
self.k = check_scalar(k, "k", Integral, min_val=1)
750+
693751
# Main heaps which will be returned as results by `PairwiseDistancesArgKmin.compute`.
694752
self.argkmin_indices = np.full((self.n_samples_X, self.k), 0, dtype=ITYPE)
695753
self.argkmin_distances = np.full((self.n_samples_X, self.k), DBL_MAX, dtype=DTYPE)
@@ -900,14 +958,32 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin):
900958
return (PairwiseDistancesArgKmin.is_usable_for(X, Y, metric) and
901959
not _in_unstable_openblas_configuration())
902960

961+
def __cinit__(
962+
self,
963+
n_samples_X,
964+
n_samples_Y,
965+
chunk_size=None,
966+
n_threads=None,
967+
strategy=None,
968+
*args,
969+
**kwargs,
970+
):
971+
# Temporary datastructures used in threads
972+
self.dist_middle_terms_chunks = <DTYPE_t **> malloc(
973+
sizeof(DTYPE_t *) * self.chunks_n_threads
974+
)
975+
903976
def __init__(
904977
self,
905-
X,
906-
Y,
907-
ITYPE_t k,
908-
bint use_squared_distances=False,
978+
n_samples_X,
979+
n_samples_Y,
909980
chunk_size=None,
981+
n_threads=None,
910982
strategy=None,
983+
X=None,
984+
Y=None,
985+
ITYPE_t k=1,
986+
bint use_squared_distances=False,
911987
metric_kwargs=None,
912988
):
913989
if metric_kwargs is not None and len(metric_kwargs) > 0:
@@ -919,10 +995,14 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin):
919995
)
920996

921997
super().__init__(
998+
n_samples_X=n_samples_X,
999+
n_samples_Y=n_samples_Y,
1000+
chunk_size=chunk_size,
1001+
n_threads=n_threads,
1002+
strategy=strategy,
9221003
# The datasets pair here is used for exact distances computations
9231004
datasets_pair=DatasetsPair.get_for(X, Y, metric="euclidean"),
9241005
k=k,
925-
chunk_size=chunk_size,
9261006
)
9271007
# X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair
9281008
cdef:
@@ -941,11 +1021,6 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin):
9411021
)
9421022
self.use_squared_distances = use_squared_distances
9431023

944-
# Temporary datastructures used in threads
945-
self.dist_middle_terms_chunks = <DTYPE_t **> malloc(
946-
sizeof(DTYPE_t *) * self.chunks_n_threads
947-
)
948-
9491024
def __dealloc__(self):
9501025
if self.dist_middle_terms_chunks is not NULL:
9511026
free(self.dist_middle_terms_chunks)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.