Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit d42e6bd

Browse files
jjerphanogriseljeremiedbbthomasjpfan
committed
MAINT Create private _pairwise_distances_reductions submodule (scikit-learn#23724)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Jérémie du Boisberranger <jeremiedbb@users.noreply.github.com> Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent 7285e5c commit d42e6bd
Copy full SHA for d42e6bd
Expand file tree / Collapse file tree

18 files changed

+2593
-2186
lines changed

‎sklearn/metrics/_dist_metrics.pxd.tp

Copy file name to clipboardExpand all lines: sklearn/metrics/_dist_metrics.pxd.tp
-20Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -101,23 +101,3 @@ cdef class DistanceMetric{{name_suffix}}:
101101
cdef DTYPE_t _dist_to_rdist(self, {{DTYPE_t}} dist) nogil except -1
102102

103103
{{endfor}}
104-
105-
######################################################################
106-
# DatasetsPair base class
107-
cdef class DatasetsPair:
108-
cdef DistanceMetric distance_metric
109-
110-
cdef ITYPE_t n_samples_X(self) nogil
111-
112-
cdef ITYPE_t n_samples_Y(self) nogil
113-
114-
cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil
115-
116-
cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil
117-
118-
119-
cdef class DenseDenseDatasetsPair(DatasetsPair):
120-
cdef:
121-
const DTYPE_t[:, ::1] X
122-
const DTYPE_t[:, ::1] Y
123-
ITYPE_t d

‎sklearn/metrics/_dist_metrics.pyx.tp

Copy file name to clipboardExpand all lines: sklearn/metrics/_dist_metrics.pyx.tp
-161Lines changed: 0 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ implementation_specific_values = [
3232

3333
import numpy as np
3434
cimport numpy as cnp
35-
from cython cimport final
3635

3736
cnp.import_array() # required in order to use C-API
3837

@@ -1171,163 +1170,3 @@ cdef class PyFuncDistance{{name_suffix}}(DistanceMetric{{name_suffix}}):
11711170
"vectors and return a float.")
11721171

11731172
{{endfor}}
1174-
1175-
######################################################################
1176-
# Datasets Pair Classes
1177-
cdef class DatasetsPair:
1178-
"""Abstract class which wraps a pair of datasets (X, Y).
1179-
1180-
This class allows computing distances between a single pair of rows of
1181-
X and Y at a time given the pair of their indices (i, j). This class is
1182-
specialized for each metric thanks to the :func:`get_for` factory classmethod.
1183-
1184-
The handling of parallelization over chunks to compute the distances
1185-
and aggregation for several rows at a time is done in dedicated
1186-
subclasses of PairwiseDistancesReduction that in-turn rely on
1187-
subclasses of DatasetsPair for each pair of rows in the data. The goal
1188-
is to make it possible to decouple the generic parallelization and
1189-
aggregation logic from metric-specific computation as much as
1190-
possible.
1191-
1192-
X and Y can be stored as C-contiguous np.ndarrays or CSR matrices
1193-
in subclasses.
1194-
1195-
This class avoids the overhead of dispatching distance computations
1196-
to :class:`sklearn.metrics.DistanceMetric` based on the physical
1197-
representation of the vectors (sparse vs. dense). It makes use of
1198-
cython.final to remove the overhead of dispatching method calls.
1199-
1200-
Parameters
1201-
----------
1202-
distance_metric: DistanceMetric
1203-
The distance metric responsible for computing distances
1204-
between two vectors of (X, Y).
1205-
"""
1206-
1207-
@classmethod
1208-
def get_for(
1209-
cls,
1210-
X,
1211-
Y,
1212-
str metric="euclidean",
1213-
dict metric_kwargs=None,
1214-
) -> DatasetsPair:
1215-
"""Return the DatasetsPair implementation for the given arguments.
1216-
1217-
Parameters
1218-
----------
1219-
X : {ndarray, sparse matrix} of shape (n_samples_X, n_features)
1220-
Input data.
1221-
If provided as a ndarray, it must be C-contiguous.
1222-
If provided as a sparse matrix, it must be in CSR format.
1223-
1224-
Y : {ndarray, sparse matrix} of shape (n_samples_Y, n_features)
1225-
Input data.
1226-
If provided as a ndarray, it must be C-contiguous.
1227-
If provided as a sparse matrix, it must be in CSR format.
1228-
1229-
metric : str, default='euclidean'
1230-
The distance metric to compute between rows of X and Y.
1231-
The default metric is a fast implementation of the Euclidean
1232-
metric. For a list of available metrics, see the documentation
1233-
of :class:`~sklearn.metrics.DistanceMetric`.
1234-
1235-
metric_kwargs : dict, default=None
1236-
Keyword arguments to pass to specified metric function.
1237-
1238-
Returns
1239-
-------
1240-
datasets_pair: DatasetsPair
1241-
The suited DatasetsPair implementation.
1242-
"""
1243-
cdef:
1244-
DistanceMetric distance_metric = DistanceMetric.get_metric(
1245-
metric,
1246-
**(metric_kwargs or {})
1247-
)
1248-
1249-
if not(X.dtype == Y.dtype == np.float64):
1250-
raise ValueError(
1251-
f"Only 64bit float datasets are supported at this time, "
1252-
f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}."
1253-
)
1254-
1255-
# Metric-specific checks that do not replace nor duplicate `check_array`.
1256-
distance_metric._validate_data(X)
1257-
distance_metric._validate_data(Y)
1258-
1259-
# TODO: dispatch to other dataset pairs for sparse support once available:
1260-
if issparse(X) or issparse(Y):
1261-
raise ValueError("Only dense datasets are supported for X and Y.")
1262-
1263-
return DenseDenseDatasetsPair(X, Y, distance_metric)
1264-
1265-
def __init__(self, DistanceMetric distance_metric):
1266-
self.distance_metric = distance_metric
1267-
1268-
cdef ITYPE_t n_samples_X(self) nogil:
1269-
"""Number of samples in X."""
1270-
# This is an abstract method.
1271-
# This _must_ always be overwritten in subclasses.
1272-
# TODO: add "with gil: raise" here when supporting Cython 3.0
1273-
return -999
1274-
1275-
cdef ITYPE_t n_samples_Y(self) nogil:
1276-
"""Number of samples in Y."""
1277-
# This is an abstract method.
1278-
# This _must_ always be overwritten in subclasses.
1279-
# TODO: add "with gil: raise" here when supporting Cython 3.0
1280-
return -999
1281-
1282-
cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil:
1283-
return self.dist(i, j)
1284-
1285-
cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil:
1286-
# This is an abstract method.
1287-
# This _must_ always be overwritten in subclasses.
1288-
# TODO: add "with gil: raise" here when supporting Cython 3.0
1289-
return -1
1290-
1291-
@final
1292-
cdef class DenseDenseDatasetsPair(DatasetsPair):
1293-
"""Compute distances between row vectors of two arrays.
1294-
1295-
Parameters
1296-
----------
1297-
X: ndarray of shape (n_samples_X, n_features)
1298-
Rows represent vectors. Must be C-contiguous.
1299-
1300-
Y: ndarray of shape (n_samples_Y, n_features)
1301-
Rows represent vectors. Must be C-contiguous.
1302-
1303-
distance_metric: DistanceMetric
1304-
The distance metric responsible for computing distances
1305-
between two row vectors of (X, Y).
1306-
"""
1307-
1308-
def __init__(self, X, Y, DistanceMetric distance_metric):
1309-
super().__init__(distance_metric)
1310-
# Arrays have already been checked
1311-
self.X = X
1312-
self.Y = Y
1313-
self.d = X.shape[1]
1314-
1315-
@final
1316-
cdef ITYPE_t n_samples_X(self) nogil:
1317-
return self.X.shape[0]
1318-
1319-
@final
1320-
cdef ITYPE_t n_samples_Y(self) nogil:
1321-
return self.Y.shape[0]
1322-
1323-
@final
1324-
cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil:
1325-
return self.distance_metric.rdist(&self.X[i, 0],
1326-
&self.Y[j, 0],
1327-
self.d)
1328-
1329-
@final
1330-
cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil:
1331-
return self.distance_metric.dist(&self.X[i, 0],
1332-
&self.Y[j, 0],
1333-
self.d)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.