Commit bb261bf

EmilyXinyi, ogrisel, and lucyleeow authored

Add array API support for _weighted_percentile (#29431)

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Lucy Liu <jliu176@gmail.com>

1 parent a6efcaf · commit bb261bf

File tree

2 files changed: 170 additions, 58 deletions

sklearn/utils/stats.py: 57 additions, 42 deletions
@@ -1,12 +1,13 @@
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
-import numpy as np
+from ..utils._array_api import (
+    _find_matching_floating_dtype,
+    get_namespace_and_device,
+)
 
-from .extmath import stable_cumsum
 
-
-def _weighted_percentile(array, sample_weight, percentile_rank=50):
+def _weighted_percentile(array, sample_weight, percentile_rank=50, xp=None):
     """Compute the weighted percentile with method 'inverted_cdf'.
 
     When the percentile lies between two data points of `array`, the function returns
@@ -37,72 +38,86 @@ def _weighted_percentile(array, sample_weight, percentile_rank=50):
         The probability level of the percentile to compute, in percent. Must be between
         0 and 100.
 
+    xp : array_namespace, default=None
+        The standard-compatible namespace for `array`. Default: infer.
+
     Returns
     -------
-    percentile : int if `array` 1D, ndarray if `array` 2D
+    percentile : scalar or 0D array if `array` 1D (or 0D), array if `array` 2D
         Weighted percentile at the requested probability level.
     """
+    xp, _, device = get_namespace_and_device(array)
+    # `sample_weight` should follow `array` for dtypes
+    floating_dtype = _find_matching_floating_dtype(array, xp=xp)
+    array = xp.asarray(array, dtype=floating_dtype, device=device)
+    sample_weight = xp.asarray(sample_weight, dtype=floating_dtype, device=device)
+
     n_dim = array.ndim
     if n_dim == 0:
-        return array[()]
+        return array
     if array.ndim == 1:
-        array = array.reshape((-1, 1))
+        array = xp.reshape(array, (-1, 1))
     # When sample_weight 1D, repeat for each array.shape[1]
     if array.shape != sample_weight.shape and array.shape[0] == sample_weight.shape[0]:
-        sample_weight = np.tile(sample_weight, (array.shape[1], 1)).T
-
+        sample_weight = xp.tile(sample_weight, (array.shape[1], 1)).T
     # Sort `array` and `sample_weight` along axis=0:
-    sorted_idx = np.argsort(array, axis=0)
-    sorted_weights = np.take_along_axis(sample_weight, sorted_idx, axis=0)
+    sorted_idx = xp.argsort(array, axis=0)
+    sorted_weights = xp.take_along_axis(sample_weight, sorted_idx, axis=0)
 
-    # Set NaN values in `sample_weight` to 0. We only perform this operation if NaN
-    # values are present at all to avoid temporary allocations of size `(n_samples,
-    # n_features)`. If NaN values were present, they would sort to the end (which we can
-    # observe from `sorted_idx`).
+    # Set NaN values in `sample_weight` to 0. Only perform this operation if NaN
+    # values present to avoid temporary allocations of size `(n_samples, n_features)`.
     n_features = array.shape[1]
-    largest_value_per_column = array[sorted_idx[-1, ...], np.arange(n_features)]
-    if np.isnan(largest_value_per_column).any():
-        sorted_nan_mask = np.take_along_axis(np.isnan(array), sorted_idx, axis=0)
+    largest_value_per_column = array[
+        sorted_idx[-1, ...], xp.arange(n_features, device=device)
+    ]
+    # NaN values get sorted to end (largest value)
+    if xp.any(xp.isnan(largest_value_per_column)):
+        sorted_nan_mask = xp.take_along_axis(xp.isnan(array), sorted_idx, axis=0)
         sorted_weights[sorted_nan_mask] = 0
 
     # Compute the weighted cumulative distribution function (CDF) based on
-    # sample_weight and scale percentile_rank along it:
-    weight_cdf = stable_cumsum(sorted_weights, axis=0)
-    adjusted_percentile_rank = percentile_rank / 100 * weight_cdf[-1]
-
-    # For percentile_rank=0, ignore leading observations with sample_weight=0; see
-    # PR #20528:
+    # `sample_weight` and scale `percentile_rank` along it.
+    #
+    # Note: we call `xp.cumulative_sum` on the transposed `sorted_weights` to
+    # ensure that the result is of shape `(n_features, n_samples)` so
+    # `xp.searchsorted` calls take contiguous inputs as a result (for
+    # performance reasons).
+    weight_cdf = xp.cumulative_sum(sorted_weights.T, axis=1)
+    adjusted_percentile_rank = percentile_rank / 100 * weight_cdf[..., -1]
+
+    # Ignore leading `sample_weight=0` observations when `percentile_rank=0` (#20528)
    mask = adjusted_percentile_rank == 0
-    adjusted_percentile_rank[mask] = np.nextafter(
+    adjusted_percentile_rank[mask] = xp.nextafter(
         adjusted_percentile_rank[mask], adjusted_percentile_rank[mask] + 1
     )
-
-    # Find index (i) of `adjusted_percentile` in `weight_cdf`,
-    # such that weight_cdf[i-1] < percentile <= weight_cdf[i]
-    percentile_idx = np.array(
+    # For each feature with index j, find sample index i of the scalar value
+    # `adjusted_percentile_rank[j]` in 1D array `weight_cdf[j]`, such that:
+    # weight_cdf[j, i-1] < adjusted_percentile_rank[j] <= weight_cdf[j, i].
+    percentile_indices = xp.asarray(
         [
-            np.searchsorted(weight_cdf[:, i], adjusted_percentile_rank[i])
-            for i in range(weight_cdf.shape[1])
-        ]
+            xp.searchsorted(
+                weight_cdf[feature_idx, ...], adjusted_percentile_rank[feature_idx]
+            )
+            for feature_idx in range(weight_cdf.shape[0])
+        ],
+        device=device,
     )
-
-    # In rare cases, percentile_idx equals to sorted_idx.shape[0]:
+    # In rare cases, `percentile_indices` equals to `sorted_idx.shape[0]`
     max_idx = sorted_idx.shape[0] - 1
-    percentile_idx = np.apply_along_axis(
-        lambda x: np.clip(x, 0, max_idx), axis=0, arr=percentile_idx
-    )
+    percentile_indices = xp.clip(percentile_indices, 0, max_idx)
+
+    col_indices = xp.arange(array.shape[1], device=device)
+    percentile_in_sorted = sorted_idx[percentile_indices, col_indices]
 
-    col_indices = np.arange(array.shape[1])
-    percentile_in_sorted = sorted_idx[percentile_idx, col_indices]
     result = array[percentile_in_sorted, col_indices]
 
     return result[0] if n_dim == 1 else result
 
 
 # TODO: refactor to do the symmetrisation inside _weighted_percentile to avoid
 # sorting the input array twice.
-def _averaged_weighted_percentile(array, sample_weight, percentile_rank=50):
+def _averaged_weighted_percentile(array, sample_weight, percentile_rank=50, xp=None):
     return (
-        _weighted_percentile(array, sample_weight, percentile_rank)
-        - _weighted_percentile(-array, sample_weight, 100 - percentile_rank)
+        _weighted_percentile(array, sample_weight, percentile_rank, xp=xp)
+        - _weighted_percentile(-array, sample_weight, 100 - percentile_rank, xp=xp)
     ) / 2
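
For context, a minimal sketch of how the dispatched code path above can be exercised end to end. This is not part of the commit: it assumes PyTorch and array-api-compat are installed, and it calls the private helper `_weighted_percentile` purely for illustration (the new test below checks the same round trip more thoroughly).

# Illustrative sketch, not part of this commit. Assumes PyTorch and
# array-api-compat are installed; `_weighted_percentile` is a private
# helper, used here only to demonstrate the dispatch behaviour.
import numpy as np
import torch

from sklearn import config_context
from sklearn.utils.stats import _weighted_percentile

X_np = np.asarray([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
w_np = np.asarray([1.0, 1.0, 1.0, 1.0], dtype=np.float32)

# Default NumPy path, unchanged in behaviour by this commit.
result_np = _weighted_percentile(X_np, w_np, 50)

# Array API path: inputs are torch tensors and stay torch tensors.
with config_context(array_api_dispatch=True):
    result_xp = _weighted_percentile(torch.asarray(X_np), torch.asarray(w_np), 50)

# Both paths should agree once the torch result is read back as a float.
assert np.isclose(result_np, float(result_xp))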

sklearn/utils/tests/test_stats.py: 113 additions, 16 deletions
@@ -3,6 +3,14 @@
 from numpy.testing import assert_allclose, assert_array_equal
 from pytest import approx
 
+from sklearn._config import config_context
+from sklearn.utils._array_api import (
+    _convert_to_numpy,
+    get_namespace,
+    yield_namespace_device_dtype_combinations,
+)
+from sklearn.utils._array_api import device as array_device
+from sklearn.utils.estimator_checks import _array_api_for_tests
 from sklearn.utils.fixes import np_version, parse_version
 from sklearn.utils.stats import _averaged_weighted_percentile, _weighted_percentile
 
@@ -39,6 +47,7 @@ def test_averaged_and_weighted_percentile():
 
 
 def test_weighted_percentile():
+    """Check `weighted_percentile` on artificial data with obvious median."""
     y = np.empty(102, dtype=np.float64)
     y[:50] = 0
     y[-51:] = 2
@@ -51,15 +60,16 @@
 
 
 def test_weighted_percentile_equal():
+    """Check `weighted_percentile` with all weights equal to 1."""
     y = np.empty(102, dtype=np.float64)
     y.fill(0.0)
     sw = np.ones(102, dtype=np.float64)
-    sw[-1] = 0.0
-    value = _weighted_percentile(y, sw, 50)
-    assert value == 0
+    score = _weighted_percentile(y, sw, 50)
+    assert approx(score) == 0
 
 
 def test_weighted_percentile_zero_weight():
+    """Check `weighted_percentile` with all weights equal to 0."""
     y = np.empty(102, dtype=np.float64)
     y.fill(1.0)
     sw = np.ones(102, dtype=np.float64)
@@ -69,6 +79,11 @@ def test_weighted_percentile_zero_weight():
 
 
 def test_weighted_percentile_zero_weight_zero_percentile():
+    """Check `weighted_percentile(percentile_rank=0)` behaves correctly.
+
+    Ensures that (leading)zero-weight observations ignored when `percentile_rank=0`.
+    See #20528 for details.
+    """
     y = np.array([0, 1, 2, 3, 4, 5])
     sw = np.array([0, 0, 1, 1, 1, 0])
     value = _weighted_percentile(y, sw, 0)
@@ -82,18 +97,18 @@ def test_weighted_percentile_zero_weight_zero_percentile():
 
 
 def test_weighted_median_equal_weights():
-    # Checks that `_weighted_percentile` and `np.median` (both at probability level=0.5
-    # and with `sample_weights` being all 1s) return the same percentiles if the number
-    # of the samples in the data is odd. In this special case, `_weighted_percentile`
-    # always falls on a precise value (not on the next lower value) and is thus equal to
-    # `np.median`.
-    # As discussed in #17370, a similar check with an even number of samples does not
-    # consistently hold, since then the lower of two percentiles might be selected,
-    # while the median might lie in between.
+    """Checks `_weighted_percentile(percentile_rank=50)` is the same as `np.median`.
+
+    `sample_weights` are all 1s and the number of samples is odd.
+    When number of samples is odd, `_weighted_percentile` always falls on a single
+    observation (not between 2 values, in which case the lower value would be taken)
+    and is thus equal to `np.median`.
+    For an even number of samples, this check will not always hold as (note that
+    for some other percentile methods it will always hold). See #17370 for details.
+    """
     rng = np.random.RandomState(0)
     x = rng.randint(10, size=11)
     weights = np.ones(x.shape)
-
     median = np.median(x)
     w_median = _weighted_percentile(x, weights)
     assert median == approx(w_median)
@@ -106,10 +121,8 @@ def test_weighted_median_integer_weights():
     x = rng.randint(20, size=10)
     weights = rng.choice(5, size=10)
     x_manual = np.repeat(x, weights)
-
     median = np.median(x_manual)
     w_median = _weighted_percentile(x, weights)
-
     assert median == approx(w_median)
 
 
@@ -125,8 +138,7 @@ def test_weighted_percentile_2d():
     w_median = _weighted_percentile(x_2d, w1)
     p_axis_0 = [_weighted_percentile(x_2d[:, i], w1) for i in range(x_2d.shape[1])]
     assert_allclose(w_median, p_axis_0)
-
-    # Check when array and sample_weight boht 2D
+    # Check when array and sample_weight both 2D
     w2 = rng.choice(5, size=10)
     w_2d = np.vstack((w1, w2)).T
 
@@ -137,6 +149,91 @@ def test_weighted_percentile_2d():
     assert_allclose(w_median, p_axis_0)
 
 
+@pytest.mark.parametrize(
+    "array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
+)
+@pytest.mark.parametrize(
+    "data, weights, percentile",
+    [
+        # NumPy scalars input (handled as 0D arrays on array API)
+        (np.float32(42), np.int32(1), 50),
+        # Random 1D array, constant weights
+        (lambda rng: rng.rand(50), np.ones(50).astype(np.int32), 50),
+        # Random 2D array and random 1D weights
+        (lambda rng: rng.rand(50, 3), lambda rng: rng.rand(50).astype(np.float32), 75),
+        # Random 2D array and random 2D weights
+        (
+            lambda rng: rng.rand(20, 3),
+            lambda rng: rng.rand(20, 3).astype(np.float32),
+            25,
+        ),
+        # zero-weights and `rank_percentile=0` (#20528) (`sample_weight` dtype: int64)
+        (np.array([0, 1, 2, 3, 4, 5]), np.array([0, 0, 1, 1, 1, 0]), 0),
+        # np.nan's in data and some zero-weights (`sample_weight` dtype: int64)
+        (np.array([np.nan, np.nan, 0, 3, 4, 5]), np.array([0, 1, 1, 1, 1, 0]), 0),
+        # `sample_weight` dtype: int32
+        (
+            np.array([0, 1, 2, 3, 4, 5]),
+            np.array([0, 1, 1, 1, 1, 0], dtype=np.int32),
+            25,
+        ),
+    ],
+)
+def test_weighted_percentile_array_api_consistency(
+    global_random_seed, array_namespace, device, dtype_name, data, weights, percentile
+):
+    """Check `_weighted_percentile` gives consistent results with array API."""
+    if array_namespace == "array_api_strict":
+        try:
+            import array_api_strict
+        except ImportError:
+            pass
+        else:
+            if device == array_api_strict.Device("device1"):
+                # See https://github.com/data-apis/array-api-strict/issues/134
+                pytest.xfail(
+                    "array_api_strict has bug when indexing with tuple of arrays "
+                    "on non-'CPU_DEVICE' devices."
+                )
+
+    xp = _array_api_for_tests(array_namespace, device)
+
+    # Skip test for percentile=0 edge case (#20528) on namespace/device where
+    # xp.nextafter is broken. This is the case for torch with MPS device:
+    # https://github.com/pytorch/pytorch/issues/150027
+    zero = xp.zeros(1, device=device)
+    one = xp.ones(1, device=device)
+    if percentile == 0 and xp.all(xp.nextafter(zero, one) == zero):
+        pytest.xfail(f"xp.nextafter is broken on {device}")
+
+    rng = np.random.RandomState(global_random_seed)
+    X_np = data(rng) if callable(data) else data
+    weights_np = weights(rng) if callable(weights) else weights
+    # Ensure `data` of correct dtype
+    X_np = X_np.astype(dtype_name)
+
+    result_np = _weighted_percentile(X_np, weights_np, percentile)
+    # Convert to Array API arrays
+    X_xp = xp.asarray(X_np, device=device)
+    weights_xp = xp.asarray(weights_np, device=device)
+
+    with config_context(array_api_dispatch=True):
+        result_xp = _weighted_percentile(X_xp, weights_xp, percentile)
+        assert array_device(result_xp) == array_device(X_xp)
+        assert get_namespace(result_xp)[0] == get_namespace(X_xp)[0]
+        result_xp_np = _convert_to_numpy(result_xp, xp=xp)
+
+    assert result_xp_np.dtype == result_np.dtype
+    assert result_xp_np.shape == result_np.shape
+    assert_allclose(result_np, result_xp_np)
+
+    # Check dtype correct (`sample_weight` should follow `array`)
+    if dtype_name == "float32":
+        assert result_xp_np.dtype == result_np.dtype == np.float32
+    else:
+        assert result_xp_np.dtype == np.float64
+
+
 @pytest.mark.parametrize("sample_weight_ndim", [1, 2])
 def test_weighted_percentile_nan_filtered(sample_weight_ndim):
     """Test that calling _weighted_percentile on an array with nan values returns
