Commit 0c4252c

Revert "ENH Perform KNN imputation without O(n^2) memory cost"

Accidentally pushed to master; this reverts commit ae9eaf8.

1 parent: ae9eaf8

4 files changed: 61 additions & 94 deletions

sklearn/impute/_knn.py

51 additions & 61 deletions

@@ -6,7 +6,7 @@
 
 from ._base import _BaseImputer
 from ..utils.validation import FLOAT_DTYPES
-from ..metrics import pairwise_distances_chunked
+from ..metrics import pairwise_distances
 from ..metrics.pairwise import _NAN_METRICS
 from ..neighbors._base import _get_weights
 from ..neighbors._base import _check_weights
@@ -217,81 +217,71 @@ def transform(self, X):
 
         mask = _get_mask(X, self.missing_values)
         mask_fit_X = self._mask_fit_X
-        valid_mask = ~np.all(mask_fit_X, axis=0)
 
+        # Removes columns where the training data is all nan
         if not np.any(mask):
-            # No missing values in X
-            # Remove columns where the training data is all nan
+            valid_mask = ~np.all(mask_fit_X, axis=0)
             return X[:, valid_mask]
 
         row_missing_idx = np.flatnonzero(mask.any(axis=1))
 
-        non_missing_fix_X = np.logical_not(mask_fit_X)
+        # Pairwise distances between receivers and fitted samples
+        dist = pairwise_distances(X[row_missing_idx, :], self._fit_X,
+                                  metric=self.metric,
+                                  missing_values=self.missing_values,
+                                  force_all_finite=force_all_finite)
 
         # Maps from indices from X to indices in dist matrix
         dist_idx_map = np.zeros(X.shape[0], dtype=np.int)
         dist_idx_map[row_missing_idx] = np.arange(row_missing_idx.shape[0])
 
-        def process_chunk(dist_chunk, start):
-            row_missing_chunk = row_missing_idx[start:start + len(dist_chunk)]
+        non_missing_fix_X = np.logical_not(mask_fit_X)
 
-            # Find and impute missing by column
-            for col in range(X.shape[1]):
-                if not valid_mask[col]:
-                    # column was all missing during training
-                    continue
+        # Find and impute missing
+        valid_idx = []
+        for col in range(X.shape[1]):
 
-                col_mask = mask[row_missing_chunk, col]
-                if not np.any(col_mask):
-                    # column has no missing values
-                    continue
+            potential_donors_idx = np.flatnonzero(non_missing_fix_X[:, col])
+
+            # column was all missing during training
+            if len(potential_donors_idx) == 0:
+                continue
+
+            # column has no missing values
+            if not np.any(mask[:, col]):
+                valid_idx.append(col)
+                continue
 
-                potential_donors_idx, = np.nonzero(non_missing_fix_X[:, col])
+            valid_idx.append(col)
 
-                # receivers_idx are indices in X
-                receivers_idx = row_missing_chunk[np.flatnonzero(col_mask)]
+            receivers_idx = np.flatnonzero(mask[:, col])
 
-                # distances for samples that needed imputation for column
-                dist_subset = (dist_chunk[dist_idx_map[receivers_idx] - start]
+            # distances for samples that needed imputation for column
+            dist_subset = (dist[dist_idx_map[receivers_idx]]
+                           [:, potential_donors_idx])
+
+            # receivers with all nan distances impute with mean
+            all_nan_dist_mask = np.isnan(dist_subset).all(axis=1)
+            all_nan_receivers_idx = receivers_idx[all_nan_dist_mask]
+
+            if all_nan_receivers_idx.size:
+                col_mean = np.ma.array(self._fit_X[:, col],
+                                       mask=mask_fit_X[:, col]).mean()
+                X[all_nan_receivers_idx, col] = col_mean
+
+                if len(all_nan_receivers_idx) == len(receivers_idx):
+                    # all receivers imputed with mean
+                    continue
+
+                # receivers with at least one defined distance
+                receivers_idx = receivers_idx[~all_nan_dist_mask]
+                dist_subset = (dist[dist_idx_map[receivers_idx]]
                                [:, potential_donors_idx])
 
-                # receivers with all nan distances impute with mean
-                all_nan_dist_mask = np.isnan(dist_subset).all(axis=1)
-                all_nan_receivers_idx = receivers_idx[all_nan_dist_mask]
-
-                if all_nan_receivers_idx.size:
-                    col_mean = np.ma.array(self._fit_X[:, col],
-                                           mask=mask_fit_X[:, col]).mean()
-                    X[all_nan_receivers_idx, col] = col_mean
-
-                    if len(all_nan_receivers_idx) == len(receivers_idx):
-                        # all receivers imputed with mean
-                        continue
-
-                    # receivers with at least one defined distance
-                    receivers_idx = receivers_idx[~all_nan_dist_mask]
-                    dist_subset = (dist_chunk[dist_idx_map[receivers_idx]
-                                              - start]
-                                   [:, potential_donors_idx])
-
-                n_neighbors = min(self.n_neighbors, len(potential_donors_idx))
-                value = self._calc_impute(
-                    dist_subset,
-                    n_neighbors,
-                    self._fit_X[potential_donors_idx, col],
-                    mask_fit_X[potential_donors_idx, col])
-                X[receivers_idx, col] = value
-
-        # process in fixed-memory chunks
-        gen = pairwise_distances_chunked(
-            X[row_missing_idx, :],
-            self._fit_X,
-            metric=self.metric,
-            missing_values=self.missing_values,
-            force_all_finite=force_all_finite,
-            reduce_func=process_chunk)
-        for chunk in gen:
-            # process_chunk modifies X in place. No return value.
-            pass
-
-        return super()._concatenate_indicator(X[:, valid_mask], X_indicator)
+            n_neighbors = min(self.n_neighbors, len(potential_donors_idx))
+            value = self._calc_impute(dist_subset, n_neighbors,
                                      self._fit_X[potential_donors_idx, col],
                                      mask_fit_X[potential_donors_idx, col])
+            X[receivers_idx, col] = value
+
+        return super()._concatenate_indicator(X[:, valid_idx], X_indicator)
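For context, the reverted enhancement computed receiver-to-donor distances in fixed-size chunks via pairwise_distances_chunked, so only a vertical slice of the distance matrix lived in memory at any time; the restored code above calls pairwise_distances once and materializes the full n_receivers x n_fit_samples matrix, which is the O(n^2) memory cost named in the commit title. The following is a minimal sketch of that chunking pattern, not the reverted KNNImputer code; the shapes and the nearest_index reducer are made up for illustration.

# Minimal sketch of the chunked pattern (illustrative only, not the
# reverted implementation; shapes and the reducer are made up).
import numpy as np
from sklearn.metrics import pairwise_distances, pairwise_distances_chunked

rng = np.random.RandomState(0)
X_query = rng.random_sample((500, 8))    # rows that need neighbors
X_fit = rng.random_sample((2000, 8))     # fitted samples

# Full matrix: all 500 x 2000 distances are held in memory at once.
full = pairwise_distances(X_query, X_fit)

def nearest_index(D_chunk, start):
    # D_chunk is a vertical slice of the distance matrix starting at
    # row `start`; return the nearest fitted sample for each query row.
    return np.argmin(D_chunk, axis=1)

# Chunked: only one slice exists at a time; working_memory caps its size
# in MiB, and the reducer's per-chunk results are concatenated afterwards.
nearest = np.concatenate(list(
    pairwise_distances_chunked(X_query, X_fit,
                               reduce_func=nearest_index,
                               working_memory=1)))

assert np.array_equal(nearest, np.argmin(full, axis=1))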

sklearn/impute/tests/test_knn.py

8 additions & 16 deletions

@@ -1,7 +1,6 @@
 import numpy as np
 import pytest
 
-from sklearn import config_context
 from sklearn.impute import KNNImputer
 from sklearn.metrics.pairwise import nan_euclidean_distances
 from sklearn.metrics.pairwise import pairwise_distances
@@ -523,9 +522,8 @@ def custom_callable(x, y, missing_values=np.nan, squared=False):
     assert_allclose(imputer.fit_transform(X), X_imputed)
 
 
-@pytest.mark.parametrize("working_memory", [None, 0])
 @pytest.mark.parametrize("na", [-1, np.nan])
-def test_knn_imputer_with_simple_example(na, working_memory):
+def test_knn_imputer_with_simple_example(na):
 
     X = np.array([
         [0, na, 0, na],
@@ -555,9 +553,8 @@ def test_knn_imputer_with_simple_example(na, working_memory):
         [r7c0, 7, 7, 7]
     ])
 
-    with config_context(working_memory=working_memory):
-        imputer_comp = KNNImputer(missing_values=na)
-        assert_allclose(imputer_comp.fit_transform(X), X_imputed)
+    imputer_comp = KNNImputer(missing_values=na)
+    assert_allclose(imputer_comp.fit_transform(X), X_imputed)
 
 
 @pytest.mark.parametrize("na", [-1, np.nan])
@@ -601,10 +598,8 @@ def test_knn_imputer_drops_all_nan_features(na):
     assert_allclose(knn.transform(X2), X2_expected)
 
 
-@pytest.mark.parametrize("working_memory", [None, 0])
 @pytest.mark.parametrize("na", [-1, np.nan])
-def test_knn_imputer_distance_weighted_not_enough_neighbors(na,
-                                                            working_memory):
+def test_knn_imputer_distance_weighted_not_enough_neighbors(na):
     X = np.array([
         [3, na],
         [2, na],
@@ -631,14 +626,11 @@ def test_knn_imputer_distance_weighted_not_enough_neighbors(na,
         [X_50, 5]
     ])
 
-    with config_context(working_memory=working_memory):
-        knn_3 = KNNImputer(missing_values=na, n_neighbors=3,
-                           weights='distance')
-        assert_allclose(knn_3.fit_transform(X), X_expected)
+    knn_3 = KNNImputer(missing_values=na, n_neighbors=3, weights='distance')
+    assert_allclose(knn_3.fit_transform(X), X_expected)
 
-        knn_4 = KNNImputer(missing_values=na, n_neighbors=4,
-                           weights='distance')
-        assert_allclose(knn_4.fit_transform(X), X_expected)
+    knn_4 = KNNImputer(missing_values=na, n_neighbors=4, weights='distance')
+    assert_allclose(knn_4.fit_transform(X), X_expected)
 
 
 @pytest.mark.parametrize("na, allow_nan", [(-1, False), (np.nan, True)])
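The working_memory parametrization is dropped because the restored imputer no longer routes distances through pairwise_distances_chunked. For readers unfamiliar with the estimator under test, here is a small self-contained KNNImputer example; the data is illustrative and not taken from the test file.

import numpy as np
from sklearn.impute import KNNImputer

X = np.array([[1.0, 2.0, np.nan],
              [3.0, 4.0, 3.0],
              [np.nan, 6.0, 5.0],
              [8.0, 8.0, 7.0]])

# Each nan is filled with the mean of that feature over the two nearest
# rows, using nan-euclidean distance on the observed entries.
imputer = KNNImputer(n_neighbors=2, weights="uniform")
print(imputer.fit_transform(X))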

sklearn/metrics/pairwise.py

2 additions & 4 deletions

@@ -1408,8 +1408,6 @@ def _pairwise_callable(X, Y, metric, force_all_finite=True, **kwds):
 def _check_chunk_size(reduced, chunk_size):
     """Checks chunk is a sequence of expected size or a tuple of same
     """
-    if reduced is None:
-        return
     is_tuple = isinstance(reduced, tuple)
     if not is_tuple:
         reduced = (reduced,)
@@ -1470,8 +1468,8 @@ def pairwise_distances_chunked(X, Y=None, reduce_func=None,
         reducing it to needed values. ``reduce_func(D_chunk, start)``
         is called repeatedly, where ``D_chunk`` is a contiguous vertical
         slice of the pairwise distance matrix, starting at row ``start``.
-        It should return one of: None; an array, a list, or a sparse matrix
-        of length ``D_chunk.shape[0]``; or a tuple of such objects.
+        It should return an array, a list, or a sparse matrix of length
+        ``D_chunk.shape[0]``, or a tuple of such objects.
 
         If None, pairwise_distances_chunked returns a generator of vertical
         chunks of the distance matrix.
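With the early return for ``reduced is None`` removed, a reduce_func is again expected to return objects of length ``D_chunk.shape[0]``, or a tuple of such objects, never None. Below is a small illustrative reducer conforming to that contract; the name min_dist_and_index and the data are made up for this sketch.

import numpy as np
from sklearn.metrics import pairwise_distances_chunked

X = np.random.RandomState(42).random_sample((100, 5))

def min_dist_and_index(D_chunk, start):
    # Mask self-distances on the diagonal of this vertical slice, then
    # return a tuple of two arrays, each of length D_chunk.shape[0].
    rows = np.arange(D_chunk.shape[0])
    D_chunk[rows, start + rows] = np.inf
    return D_chunk.min(axis=1), D_chunk.argmin(axis=1)

for dmin, didx in pairwise_distances_chunked(X, reduce_func=min_dist_and_index):
    # dmin[i]: distance from row (start + i) to its nearest other row;
    # didx[i]: index of that nearest row.
    pass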

sklearn/metrics/tests/test_pairwise.py

0 additions & 13 deletions

@@ -483,19 +483,6 @@ def test_pairwise_distances_chunked_reduce():
     assert_allclose(np.vstack(S_chunks), S, atol=1e-7)
 
 
-def test_pairwise_distances_chunked_reduce_none():
-    # check that the reduce func is allowed to return None
-    rng = np.random.RandomState(0)
-    X = rng.random_sample((10, 4))
-    S_chunks = pairwise_distances_chunked(X, None,
-                                          reduce_func=lambda dist, start: None,
-                                          working_memory=2 ** -16)
-    assert isinstance(S_chunks, GeneratorType)
-    S_chunks = list(S_chunks)
-    assert len(S_chunks) > 1
-    assert all(chunk is None for chunk in S_chunks)
-
-
 @pytest.mark.parametrize('good_reduce', [
     lambda D, start: list(D),
     lambda D, start: np.array(D),
