Commit 33e1ea3 (1 parent: aeef397)

ENH Improve the creation of KDTree and BallTree on their worst-case time complexity (scikit-learn#19473)

Committed by jjerphan
Co-authored-by: jiefangxuanyan <505745416@qq.com>
Co-authored-by: "Thomas J. Fan" <thomasjpfan@gmail.com>
5 files changed: +149 −67 lines
doc/whats_new/v1.0.rst (+10 −0)

@@ -278,6 +278,16 @@ Changelog
         Use ``var_`` instead.
         :pr:`18842` by :user:`Hong Shao Yang <hongshaoyang>`.

+:mod:`sklearn.neighbors`
+........................
+
+- |Enhancement| The creation of :class:`neighbors.KDTree` and
+  :class:`neighbors.BallTree` has been improved for their worst-case time
+  complexity from :math:`\mathcal{O}(n^2)` to :math:`\mathcal{O}(n)`.
+  :pr:`19473` by :user:`jiefangxuanyan <jiefangxuanyan>` and
+  :user:`Julien Jerphanion <jjerphan>`.
+
 :mod:`sklearn.pipeline`
 .......................
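The worst case this entry refers to arises when many points share the same coordinate along the split dimension: the former quicksort-style partition shrinks its search range by only one element per pass on all-equal keys, so construction degrades to quadratic time. A minimal sketch exercising that pathological input (assuming NumPy and scikit-learn are installed; the data shape is arbitrary):

```python
import numpy as np
from sklearn.neighbors import KDTree

# All-duplicate coordinates: the pathological input for the old
# quicksort-style partition; with the nth_element-based partition
# each split now runs in linear time.
X = np.ones((2000, 2))
tree = KDTree(X, leaf_size=2)

# The tree still answers queries correctly on degenerate data.
dist, ind = tree.query(X[:1], k=3)
assert dist.shape == (1, 3) and ind.shape == (1, 3)
```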
sklearn/neighbors/_binary_tree.pxi (+2 −67)

@@ -159,6 +159,8 @@
 from ._typedefs import DTYPE, ITYPE
 from ._dist_metrics cimport (DistanceMetric, euclidean_dist, euclidean_rdist,
                              euclidean_dist_to_rdist, euclidean_rdist_to_dist)

+from ._partition_nodes cimport partition_node_indices
+
 cdef extern from "numpy/arrayobject.h":
     void PyArray_ENABLEFLAGS(np.ndarray arr, int flags)

@@ -776,73 +778,6 @@
     return j_max


-cdef int partition_node_indices(DTYPE_t* data,
-                                ITYPE_t* node_indices,
-                                ITYPE_t split_dim,
-                                ITYPE_t split_index,
-                                ITYPE_t n_features,
-                                ITYPE_t n_points) except -1:
-    """Partition points in the node into two equal-sized groups.
-
-    Upon return, the values in node_indices will be rearranged such that
-    (assuming numpy-style indexing):
-
-        data[node_indices[0:split_index], split_dim]
-            <= data[node_indices[split_index], split_dim]
-
-    and
-
-        data[node_indices[split_index], split_dim]
-            <= data[node_indices[split_index:n_points], split_dim]
-
-    The algorithm is essentially a partial in-place quicksort around a
-    set pivot.
-
-    Parameters
-    ----------
-    data : double pointer
-        Pointer to a 2D array of the training data, of shape [N, n_features].
-        N must be greater than any of the values in node_indices.
-    node_indices : int pointer
-        Pointer to a 1D array of length n_points. This lists the indices of
-        each of the points within the current node. This will be modified
-        in-place.
-    split_dim : int
-        the dimension on which to split. This will usually be computed via
-        the routine ``find_node_split_dim``.
-    split_index : int
-        the index within node_indices around which to split the points.
-
-    Returns
-    -------
-    status : int
-        integer exit status. On return, the contents of node_indices are
-        modified as noted above.
-    """
-    cdef ITYPE_t left, right, midindex, i
-    cdef DTYPE_t d1, d2
-    left = 0
-    right = n_points - 1
-
-    while True:
-        midindex = left
-        for i in range(left, right):
-            d1 = data[node_indices[i] * n_features + split_dim]
-            d2 = data[node_indices[right] * n_features + split_dim]
-            if d1 < d2:
-                swap(node_indices, i, midindex)
-                midindex += 1
-        swap(node_indices, midindex, right)
-        if midindex == split_index:
-            break
-        elif midindex < split_index:
-            left = midindex + 1
-        else:
-            right = midindex - 1
-
-    return 0
-
-
 ######################################################################
 # NodeHeap : min-heap used to keep track of nodes during
 #            breadth-first query
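For reference, the removed routine can be sketched in pure Python/NumPy (an illustrative re-implementation, not scikit-learn API; the function name is hypothetical). It is a Lomuto-style partition around a fixed pivot position: when keys along the split dimension are distinct it establishes the documented invariant quickly, but when all keys are equal the `d1 < d2` test never fires and the range shrinks by one per pass, which is the quadratic worst case this commit removes:

```python
import numpy as np

def partition_node_indices_py(data, node_indices, split_dim, split_index):
    # Partial in-place quicksort around a fixed pivot position,
    # mirroring the removed Cython routine (illustrative sketch only).
    left, right = 0, len(node_indices) - 1
    while True:
        midindex = left
        for i in range(left, right):
            d1 = data[node_indices[i], split_dim]
            d2 = data[node_indices[right], split_dim]
            if d1 < d2:
                node_indices[i], node_indices[midindex] = \
                    node_indices[midindex], node_indices[i]
                midindex += 1
        node_indices[midindex], node_indices[right] = \
            node_indices[right], node_indices[midindex]
        if midindex == split_index:
            return
        elif midindex < split_index:
            left = midindex + 1
        else:
            right = midindex - 1

rng = np.random.RandomState(0)
data = rng.rand(10, 3)
idx = np.arange(10)
partition_node_indices_py(data, idx, split_dim=1, split_index=5)

# The invariant from the docstring holds along split_dim.
col = data[idx, 1]
assert col[:5].max() <= col[5] <= col[6:].min()
```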
sklearn/neighbors/_partition_nodes.pxd (new file, +9 lines)

from ._typedefs cimport DTYPE_t, ITYPE_t


cdef int partition_node_indices(
    DTYPE_t *data,
    ITYPE_t *node_indices,
    ITYPE_t split_dim,
    ITYPE_t split_index,
    ITYPE_t n_features,
    ITYPE_t n_points) except -1
sklearn/neighbors/_partition_nodes.pyx (new file, +122 lines)

# distutils: language = c++

# BinaryTrees rely on partial sorts to partition their nodes during their
# initialisation.
#
# The C++ std library exposes nth_element, an efficient partial sort for this
# situation which has a linear time complexity as well as the best performance.
#
# To use std::algorithm::nth_element, a few fixtures are defined using Cython:
# - partition_node_indices, a Cython function used in BinaryTrees, that calls
# - partition_node_indices_inner, a C++ function that wraps nth_element and uses
# - an IndexComparator to state how to compare KDTrees' indices
#
# IndexComparator has been defined so that partial sorts are stable with
# respect to the nodes' initial indices.
#
# See for reference:
#  - https://en.cppreference.com/w/cpp/algorithm/nth_element
#  - https://github.com/scikit-learn/scikit-learn/pull/11103
#  - https://github.com/scikit-learn/scikit-learn/pull/19473

cdef extern from *:
    """
    #include <algorithm>

    template<class D, class I>
    class IndexComparator {
    private:
        const D *data;
        I split_dim, n_features;
    public:
        IndexComparator(const D *data, const I &split_dim, const I &n_features):
            data(data), split_dim(split_dim), n_features(n_features) {}

        bool operator()(const I &a, const I &b) const {
            D a_value = data[a * n_features + split_dim];
            D b_value = data[b * n_features + split_dim];
            return a_value == b_value ? a < b : a_value < b_value;
        }
    };

    template<class D, class I>
    void partition_node_indices_inner(
        const D *data,
        I *node_indices,
        const I &split_dim,
        const I &split_index,
        const I &n_features,
        const I &n_points) {
        IndexComparator<D, I> index_comparator(data, split_dim, n_features);
        std::nth_element(
            node_indices,
            node_indices + split_index,
            node_indices + n_points,
            index_comparator);
    }
    """
    void partition_node_indices_inner[D, I](
        D *data,
        I *node_indices,
        I split_dim,
        I split_index,
        I n_features,
        I n_points) except +


cdef int partition_node_indices(
        DTYPE_t *data,
        ITYPE_t *node_indices,
        ITYPE_t split_dim,
        ITYPE_t split_index,
        ITYPE_t n_features,
        ITYPE_t n_points) except -1:
    """Partition points in the node into two equal-sized groups.

    Upon return, the values in node_indices will be rearranged such that
    (assuming numpy-style indexing):

        data[node_indices[0:split_index], split_dim]
            <= data[node_indices[split_index], split_dim]

    and

        data[node_indices[split_index], split_dim]
            <= data[node_indices[split_index:n_points], split_dim]

    The algorithm is essentially a partial in-place quicksort around a
    set pivot.

    Parameters
    ----------
    data : double pointer
        Pointer to a 2D array of the training data, of shape [N, n_features].
        N must be greater than any of the values in node_indices.
    node_indices : int pointer
        Pointer to a 1D array of length n_points. This lists the indices of
        each of the points within the current node. This will be modified
        in-place.
    split_dim : int
        the dimension on which to split. This will usually be computed via
        the routine ``find_node_split_dim``.
    split_index : int
        the index within node_indices around which to split the points.
    n_features : int
        the number of features (i.e. columns) in the 2D array pointed by data.
    n_points : int
        the length of node_indices. This is also the number of points in
        the original dataset.

    Returns
    -------
    status : int
        integer exit status. On return, the contents of node_indices are
        modified as noted above.
    """
    partition_node_indices_inner(
        data,
        node_indices,
        split_dim,
        split_index,
        n_features,
        n_points)
    return 0
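A NumPy analogue of what this file does: `np.argpartition` uses introselect, the same selection idea as `std::nth_element`, so it places the element at `split_index` in its sorted position along the split dimension in linear average time (unlike `IndexComparator`, though, it does not tie-break on the index itself, so it is not stable on equal keys). An illustrative sketch, with arbitrary shapes and split parameters:

```python
import numpy as np

rng = np.random.RandomState(42)
data = rng.rand(100, 4)
node_indices = np.arange(100)
split_dim, split_index = 2, 50

# Rearrange node_indices so that the index landing at split_index is in
# its sorted position along split_dim, analogous to std::nth_element.
keys = data[node_indices, split_dim]
node_indices = node_indices[np.argpartition(keys, split_index)]

# Same invariant as partition_node_indices documents.
col = data[node_indices, split_dim]
assert col[:split_index].max() <= col[split_index]
assert col[split_index] <= col[split_index + 1:].min()
```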

sklearn/neighbors/setup.py (+6 −0)

@@ -20,6 +20,12 @@ def configuration(parent_package='', top_path=None):
                          include_dirs=[numpy.get_include()],
                          libraries=libraries)

+    config.add_extension('_partition_nodes',
+                         sources=['_partition_nodes.pyx'],
+                         include_dirs=[numpy.get_include()],
+                         language="c++",
+                         libraries=libraries)
+
     config.add_extension('_dist_metrics',
                          sources=['_dist_metrics.pyx'],
                          include_dirs=[numpy.get_include(),
