PERF Implement PairwiseDistancesReduction backend for KNeighbors.predict_proba #24076
Merged
73 commits
8a99217 Partial implementation, still broken (Micky774)
c071338 Major update, mostly working (Micky774)
cbe5347 Merge branch 'main' into pwd_kncp (Micky774)
a995220 Completed X-parallel implementation (Micky774)
052e59b Improved documentation (Micky774)
fc8d495 First batch of review feedback (Micky774)
6334581 Created inline helper function for weighted mode (Micky774)
6f45f79 Merge branch 'main' into pwd_kncp (Micky774)
db31c78 Multioutput support with probabilities output refactor (Micky774)
b96cab4 Merge branch 'main' into pwd_kncp (Micky774)
59f60ed Code simplification and cleanup (Micky774)
a343b58 Code simplification (Micky774)
e419f96 Merge branch 'main' into pwd_kncp (Micky774)
c01d37f Update sklearn/metrics/_pairwise_distances_reduction/_argkminlabels.pyx (Micky774)
2b128ea Merge branch 'main' into pwd_kncp (Micky774)
efaa053 Merge branch 'pwd_kncp' of https://github.com/Micky774/scikit-learn i… (Micky774)
7aff02b Fixed implementation (Micky774)
91a5b25 Removed extraneous extern statements (Micky774)
2fdbd35 Merge branch 'main' into pwd_kncp (Micky774)
941d42c Updated to adopt templating for float32 (Micky774)
57ffa8b Removed now-generated file (Micky774)
883152e Merge branch 'main' into pwd_kncp (Micky774)
003e1bf Merge branch 'main' into pwd_kncp (Micky774)
0354d73 Merge branch 'main' into pwd_kncp (Micky774)
6070a3d Merge branch 'main' into pwd_kncp (jjerphan)
c5ce440 Merge branch 'main' into pwd_kncp (Micky774)
7b6eae7 Added Euclidean specialization (Micky774)
8dcd3c9 Updated typing for `labels` and associated variables (Micky774)
3a2564a Merge branch 'main' into pwd_kncp (Micky774)
5717e34 Merge branch 'main' into pwd_kncp (jjerphan)
f556084 Reduced labels/indices from int64 to int32 (Micky774)
9c35e6e Merge branch 'pwd_kncp' of https://github.com/Micky774/scikit-learn i… (Micky774)
73012f4 Merge branch 'main' into pwd_kncp (Micky774)
f47bab2 Revert "Reduced labels/indices from int64 to int32" (Micky774)
9c2b7b4 Updated metric resolution and fixed dtype (Micky774)
7a8fc6c Corrected dtype of probability array (Micky774)
e20146d Apply suggestions from code review (Micky774)
0053299 Updated with remaining feedback (Micky774)
3fc5733 Reverted for multi-output (Micky774)
47a5e43 Changed `is_usable_for` to disable in euclidean case (Micky774)
dcdeb4e Merge branch 'main' into pwd_kncp (Micky774)
67ce048 TST Improve implementations and test coverage (jjerphan)
08b79d8 Merge branch 'main' into pwd_kncp (Micky774)
bb17a45 Updated with feedback (Micky774)
4d41a55 Merge pull request #9 from jjerphan/pwd_kncp (Micky774)
ac0bb5d Updated with `unique_labels` argument (Micky774)
df7f8c1 Corrected `labels` dtype in private test on `ArgKminLabels` directly (Micky774)
446e156 Moved casting to `compute` method (Micky774)
e5c9890 Rename ArgKminLabels to ArgKminClassMode (jjerphan)
52452d0 Use self._fit_method instead of self.algorithm (jjerphan)
caf8504 Remove the old comment (jjerphan)
f321fe5 Merge pull request #10 from jjerphan/pwd_kncp (Micky774)
8a4b5a3 Fix typo in module import (ogrisel)
e48a5a4 FIX need to call check_is_fitted in the predict method (ogrisel)
7536888 Implemented feedback and removed extraneous label mapping (Micky774)
bb5a09b Removed secondary validation (Micky774)
15bf712 Added changelog (Micky774)
95239d9 Merge branch 'main' into pwd_kncp (Micky774)
025116a Updated variable names per feedback (Micky774)
1d40084 Merge branch 'main' into pwd_kncp (Micky774)
9de740b Removed extraneous sort (Micky774)
11b3106 Update doc/whats_new/v1.3.rst (Micky774)
3bc9e2b Fixed inconsistency between strategies (Micky774)
ca438ee Altered strategy in response to benchmarks (Micky774)
8f4b371 Fixed when validation occurs, avoiding accidental double validation (Micky774)
a1370fd Merge branch 'main' into pwd_kncp (Micky774)
6551d86 Apply suggestions from code review (Micky774)
f52468d Updated with review feedback (Micky774)
e403681 Update sklearn/neighbors/_classification.py (Micky774)
25d142f Updated .gitignore (Micky774)
1c0d1e5 Merge branch 'main' into pwd_kncp (Micky774)
6f4433d Updated to reconcile new name (Micky774)
d13793d Qualify cdef methods with "noexcept nogil" (jjerphan)
208 changes: 208 additions & 0 deletions
sklearn/metrics/_pairwise_distances_reduction/_argkmin_classmode.pyx.tp
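The `{{py: ...}}` header and the `{{for ...}}` / `{{endfor}}` pair in this file are template directives: the loop body is emitted once per tuple in `implementation_specific_values`, giving a float64 and a float32 variant of the class. The sketch below only imitates that expansion with plain string substitution; it is not the real template engine, and `LOOP_BODY` and `expand` are hypothetical names introduced for illustration.

```python
# Hypothetical stand-in for the template engine: plain string substitution.
implementation_specific_values = [
    # name_suffix, INPUT_DTYPE_t, INPUT_DTYPE
    ("64", "DTYPE_t", "DTYPE"),
    ("32", "cnp.float32_t", "np.float32"),
]

# Body of the `{{for ...}}` loop, reduced to the class declaration line.
LOOP_BODY = "cdef class ArgKminClassMode{{name_suffix}}(ArgKmin{{name_suffix}}):"


def expand(body, values):
    """Emit one copy of `body` per tuple, substituting `{{name_suffix}}`."""
    return [
        body.replace("{{name_suffix}}", name_suffix)
        for name_suffix, _input_dtype_t, _input_dtype in values
    ]


generated = expand(LOOP_BODY, implementation_specific_values)
for line in generated:
    print(line)
```

Under these assumptions the loop yields two declarations, `ArgKminClassMode64` and `ArgKminClassMode32`, each deriving from the `ArgKmin` class with the matching suffix.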
@@ -0,0 +1,208 @@
{{py:

implementation_specific_values = [
    # Values are the following ones:
    #
    #       name_suffix, INPUT_DTYPE_t, INPUT_DTYPE
    #
    # We also use the float64 dtype and C-type names as defined in
    # `sklearn.utils._typedefs` to maintain consistency.
    #
    ('64', 'DTYPE_t', 'DTYPE'),
    ('32', 'cnp.float32_t', 'np.float32')
]

}}

from cython cimport floating, integral
from cython.parallel cimport parallel, prange
from libcpp.map cimport map as cpp_map, pair as cpp_pair
from libc.stdlib cimport free

cimport numpy as cnp

cnp.import_array()

from ...utils._typedefs cimport ITYPE_t, DTYPE_t
from ...utils._typedefs import ITYPE, DTYPE
import numpy as np
from scipy.sparse import issparse
from sklearn.utils.fixes import threadpool_limits

cpdef enum WeightingStrategy:
    uniform = 0
    # TODO: Implement the following options, most likely in
    # `weighted_histogram_mode`
    distance = 1
    callable = 2

{{for name_suffix, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}}
from ._argkmin cimport ArgKmin{{name_suffix}}
from ._datasets_pair cimport DatasetsPair{{name_suffix}}

cdef class ArgKminClassMode{{name_suffix}}(ArgKmin{{name_suffix}}):
    """
    {{name_suffix}}-bit implementation of ArgKminClassMode.
    """
    cdef:
        const ITYPE_t[:] class_membership
        const ITYPE_t[:] unique_labels
        DTYPE_t[:, :] class_scores
        cpp_map[ITYPE_t, ITYPE_t] labels_to_index
        WeightingStrategy weight_type

    @classmethod
    def compute(
        cls,
        X,
        Y,
        ITYPE_t k,
        weights,
        class_membership,
        unique_labels,
        str metric="euclidean",
        chunk_size=None,
        dict metric_kwargs=None,
        str strategy=None,
    ):
        """Compute the argkmin reduction with class_membership.

        This classmethod is responsible for introspecting the argument
        values to dispatch to the most appropriate implementation of
        :class:`ArgKminClassMode{{name_suffix}}`.

        This allows decoupling the API entirely from the implementation details
        whilst maintaining RAII: all temporarily allocated data structures necessary
        for the concrete implementation are therefore freed when this classmethod
        returns.

        Instances must _not_ be created directly outside of this class method.
        """
        # Use a generic implementation that handles most scipy
        # metrics by computing the distances between 2 vectors at a time.
        pda = ArgKminClassMode{{name_suffix}}(
            datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric, metric_kwargs),
            k=k,
            chunk_size=chunk_size,
            strategy=strategy,
            weights=weights,
            class_membership=class_membership,
            unique_labels=unique_labels,
        )

        # Limit the number of threads in second level of nested parallelism for BLAS
        # to avoid threads over-subscription (in GEMM for instance).
        with threadpool_limits(limits=1, user_api="blas"):
            if pda.execute_in_parallel_on_Y:
                pda._parallel_on_Y()
            else:
                pda._parallel_on_X()

        return pda._finalize_results()

    def __init__(
        self,
        DatasetsPair{{name_suffix}} datasets_pair,
        const ITYPE_t[:] class_membership,
        const ITYPE_t[:] unique_labels,
        chunk_size=None,
        strategy=None,
        ITYPE_t k=1,
        weights=None,
    ):
        super().__init__(
            datasets_pair=datasets_pair,
            chunk_size=chunk_size,
            strategy=strategy,
            k=k,
        )

        if weights == "uniform":
            self.weight_type = WeightingStrategy.uniform
        elif weights == "distance":
            self.weight_type = WeightingStrategy.distance
        else:
            self.weight_type = WeightingStrategy.callable
        self.class_membership = class_membership

        self.unique_labels = unique_labels

        cdef ITYPE_t idx, neighbor_class_idx
        # Map from set of unique labels to their indices in `class_scores`
        # Buffer used in building a histogram for one-pass weighted mode
        self.class_scores = np.zeros(
            (self.n_samples_X, unique_labels.shape[0]), dtype=DTYPE,
        )

    def _finalize_results(self):
        probabilities = np.asarray(self.class_scores)
        probabilities /= probabilities.sum(axis=1, keepdims=True)
        return probabilities

    cdef inline void weighted_histogram_mode(
        self,
        ITYPE_t sample_index,
        ITYPE_t* indices,
        DTYPE_t* distances,
    ) noexcept nogil:
        cdef:
            ITYPE_t neighbor_idx, neighbor_class_idx, label_index, multi_output_index
            DTYPE_t score_incr = 1
            # TODO: Implement other WeightingStrategy values
            bint use_distance_weighting = (
                self.weight_type == WeightingStrategy.distance
            )

        # Iterate through the sample's k nearest neighbours
        for neighbor_rank in range(self.k):
            # Absolute index of the neighbor_rank-th nearest neighbor
            # in range [0, n_samples_Y)
            # TODO: inspect whether it is worth permuting this condition
            # and the for-loop above for improved branching.
            if use_distance_weighting:
                score_incr = 1 / distances[neighbor_rank]
            neighbor_idx = indices[neighbor_rank]
            neighbor_class_idx = self.class_membership[neighbor_idx]
            self.class_scores[sample_index][neighbor_class_idx] += score_incr
        return

    cdef void _parallel_on_X_prange_iter_finalize(
        self,
        ITYPE_t thread_num,
        ITYPE_t X_start,
        ITYPE_t X_end,
    ) noexcept nogil:
        cdef:
            ITYPE_t idx, sample_index
        for idx in range(X_end - X_start):
            # One-pass top-one weighted mode
            # Compute the absolute index in [0, n_samples_X)
            sample_index = X_start + idx
            self.weighted_histogram_mode(
                sample_index,
                &self.heaps_indices_chunks[thread_num][idx * self.k],
                &self.heaps_r_distances_chunks[thread_num][idx * self.k],
            )
        return

    cdef void _parallel_on_Y_finalize(
        self,
    ) noexcept nogil:
        cdef:
            ITYPE_t sample_index, thread_num

        with nogil, parallel(num_threads=self.chunks_n_threads):
            # Deallocating temporary datastructures
            for thread_num in prange(self.chunks_n_threads, schedule='static'):
                free(self.heaps_r_distances_chunks[thread_num])
                free(self.heaps_indices_chunks[thread_num])

            for sample_index in prange(self.n_samples_X, schedule='static'):
                self.weighted_histogram_mode(
                    sample_index,
                    &self.argkmin_indices[sample_index][0],
                    &self.argkmin_distances[sample_index][0],
                )
        return

{{endfor}}
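For readers who want to check the arithmetic without compiling the Cython, here is a minimal pure-Python sketch of what `weighted_histogram_mode` followed by `_finalize_results` appears to compute: each neighbour adds a score to its class column, and each row is then divided by its sum. The function name `class_probabilities` and the toy data are hypothetical, and the sketch assumes strictly positive distances when `weights="distance"`.

```python
def class_probabilities(neighbor_indices, neighbor_distances, class_membership,
                        n_classes, weights="uniform"):
    """Accumulate per-class scores over each query's k neighbours, then normalise.

    neighbor_indices[i][r]   : training-set index of the r-th neighbour of query i
    neighbor_distances[i][r] : distance from query i to that neighbour
    class_membership[j]      : class index in [0, n_classes) of training sample j
    """
    probabilities = []
    for indices, distances in zip(neighbor_indices, neighbor_distances):
        scores = [0.0] * n_classes
        for neighbor_idx, distance in zip(indices, distances):
            # Uniform voting adds 1 per neighbour; distance weighting adds 1/d.
            # Assumption: distance > 0 (no exact duplicates of the query).
            score_incr = 1.0 / distance if weights == "distance" else 1.0
            scores[class_membership[neighbor_idx]] += score_incr
        total = sum(scores)
        probabilities.append([score / total for score in scores])
    return probabilities


# Toy data (hypothetical): five training samples, two classes, k = 3.
class_membership = [0, 0, 1, 1, 1]
neighbor_indices = [[0, 2, 3], [1, 0, 4]]
neighbor_distances = [[0.5, 1.0, 2.0], [0.25, 0.5, 1.0]]

uniform = class_probabilities(
    neighbor_indices, neighbor_distances, class_membership, 2
)
weighted = class_probabilities(
    neighbor_indices, neighbor_distances, class_membership, 2, weights="distance"
)
```

For the first query, uniform voting gives one vote for class 0 and two for class 1, so the row is [1/3, 2/3]; with distance weighting the scores are 2 and 1.5, so the row becomes [4/7, 3/7].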