3
3
from numpy .testing import assert_allclose , assert_array_equal
4
4
from pytest import approx
5
5
6
+ from sklearn ._config import config_context
7
+ from sklearn .utils ._array_api import (
8
+ _convert_to_numpy ,
9
+ get_namespace ,
10
+ yield_namespace_device_dtype_combinations ,
11
+ )
12
+ from sklearn .utils ._array_api import device as array_device
13
+ from sklearn .utils .estimator_checks import _array_api_for_tests
6
14
from sklearn .utils .fixes import np_version , parse_version
7
15
from sklearn .utils .stats import _averaged_weighted_percentile , _weighted_percentile
8
16
@@ -39,6 +47,7 @@ def test_averaged_and_weighted_percentile():
39
47
40
48
41
49
def test_weighted_percentile ():
50
+ """Check `weighted_percentile` on artificial data with obvious median."""
42
51
y = np .empty (102 , dtype = np .float64 )
43
52
y [:50 ] = 0
44
53
y [- 51 :] = 2
@@ -51,15 +60,16 @@ def test_weighted_percentile():
51
60
52
61
53
62
def test_weighted_percentile_equal():
    """Check `_weighted_percentile` when every sample carries a weight of 1.

    The data are constant (all zeros), so the 50th percentile computed with
    uniform weights has to be zero as well.
    """
    n_samples = 102
    values = np.zeros(n_samples, dtype=np.float64)
    unit_weights = np.ones(n_samples, dtype=np.float64)

    median = _weighted_percentile(values, unit_weights, 50)

    assert approx(median) == 0
60
69
61
70
62
71
def test_weighted_percentile_zero_weight ():
72
+ """Check `weighted_percentile` with all weights equal to 0."""
63
73
y = np .empty (102 , dtype = np .float64 )
64
74
y .fill (1.0 )
65
75
sw = np .ones (102 , dtype = np .float64 )
@@ -69,6 +79,11 @@ def test_weighted_percentile_zero_weight():
69
79
70
80
71
81
def test_weighted_percentile_zero_weight_zero_percentile ():
82
+ """Check `weighted_percentile(percentile_rank=0)` behaves correctly.
83
+
84
+ Ensures that (leading) zero-weight observations are ignored when `percentile_rank=0`.
85
+ See #20528 for details.
86
+ """
72
87
y = np .array ([0 , 1 , 2 , 3 , 4 , 5 ])
73
88
sw = np .array ([0 , 0 , 1 , 1 , 1 , 0 ])
74
89
value = _weighted_percentile (y , sw , 0 )
@@ -82,18 +97,18 @@ def test_weighted_percentile_zero_weight_zero_percentile():
82
97
83
98
84
99
def test_weighted_median_equal_weights ():
85
- # Checks that `_weighted_percentile` and `np.median` (both at probability level=0.5
86
- # and with `sample_weights` being all 1s) return the same percentiles if the number
87
- # of the samples in the data is odd. In this special case, `_weighted_percentile`
88
- # always falls on a precise value (not on the next lower value) and is thus equal to
89
- # `np.median`.
90
- # As discussed in #17370, a similar check with an even number of samples does not
91
- # consistently hold, since then the lower of two percentiles might be selected,
92
- # while the median might lie in between.
100
+ """Checks `_weighted_percentile(percentile_rank=50)` is the same as `np.median`.
101
+
102
+ `sample_weights` are all 1s and the number of samples is odd.
103
+ When number of samples is odd, `_weighted_percentile` always falls on a single
104
+ observation (not between 2 values, in which case the lower value would be taken)
105
+ and is thus equal to `np.median`.
106
+ For an even number of samples, this check will not always hold (note that
107
+ for some other percentile methods it will always hold). See #17370 for details.
108
+ """
93
109
rng = np .random .RandomState (0 )
94
110
x = rng .randint (10 , size = 11 )
95
111
weights = np .ones (x .shape )
96
-
97
112
median = np .median (x )
98
113
w_median = _weighted_percentile (x , weights )
99
114
assert median == approx (w_median )
@@ -106,10 +121,8 @@ def test_weighted_median_integer_weights():
106
121
x = rng .randint (20 , size = 10 )
107
122
weights = rng .choice (5 , size = 10 )
108
123
x_manual = np .repeat (x , weights )
109
-
110
124
median = np .median (x_manual )
111
125
w_median = _weighted_percentile (x , weights )
112
-
113
126
assert median == approx (w_median )
114
127
115
128
@@ -125,8 +138,7 @@ def test_weighted_percentile_2d():
125
138
w_median = _weighted_percentile (x_2d , w1 )
126
139
p_axis_0 = [_weighted_percentile (x_2d [:, i ], w1 ) for i in range (x_2d .shape [1 ])]
127
140
assert_allclose (w_median , p_axis_0 )
128
-
129
- # Check when array and sample_weight boht 2D
141
+ # Check when array and sample_weight both 2D
130
142
w2 = rng .choice (5 , size = 10 )
131
143
w_2d = np .vstack ((w1 , w2 )).T
132
144
@@ -137,6 +149,91 @@ def test_weighted_percentile_2d():
137
149
assert_allclose (w_median , p_axis_0 )
138
150
139
151
152
@pytest.mark.parametrize(
    "array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
)
@pytest.mark.parametrize(
    "data, weights, percentile",
    [
        # NumPy scalars input (handled as 0D arrays on array API)
        (np.float32(42), np.int32(1), 50),
        # Random 1D array, constant weights
        (lambda rng: rng.rand(50), np.ones(50).astype(np.int32), 50),
        # Random 2D array and random 1D weights
        (lambda rng: rng.rand(50, 3), lambda rng: rng.rand(50).astype(np.float32), 75),
        # Random 2D array and random 2D weights
        (
            lambda rng: rng.rand(20, 3),
            lambda rng: rng.rand(20, 3).astype(np.float32),
            25,
        ),
        # zero-weights and `percentile_rank=0` (#20528) (`sample_weight` dtype: int64)
        (np.array([0, 1, 2, 3, 4, 5]), np.array([0, 0, 1, 1, 1, 0]), 0),
        # np.nan's in data and some zero-weights (`sample_weight` dtype: int64)
        (np.array([np.nan, np.nan, 0, 3, 4, 5]), np.array([0, 1, 1, 1, 1, 0]), 0),
        # `sample_weight` dtype: int32
        (
            np.array([0, 1, 2, 3, 4, 5]),
            np.array([0, 1, 1, 1, 1, 0], dtype=np.int32),
            25,
        ),
    ],
)
def test_weighted_percentile_array_api_consistency(
    global_random_seed, array_namespace, device, dtype_name, data, weights, percentile
):
    """Check `_weighted_percentile` agrees between NumPy and array API inputs.

    The percentile is first computed on NumPy inputs, then on the same data
    converted to the parametrized namespace/device with array API dispatch
    enabled. Device, namespace, dtype, shape and values of both results are
    compared.

    `data` and `weights` are either ready-made NumPy objects or callables that
    build an array from a `RandomState`.
    """
    # array-api-strict fails when indexing with a tuple of arrays on its
    # non-default devices, so that combination is an expected failure. See
    # https://github.com/data-apis/array-api-strict/issues/134
    if array_namespace == "array_api_strict":
        try:
            import array_api_strict
        except ImportError:
            # Nothing to compare the device against when the package is absent.
            array_api_strict = None
        on_buggy_device = (
            array_api_strict is not None
            and device == array_api_strict.Device("device1")
        )
        if on_buggy_device:
            pytest.xfail(
                "array_api_strict has bug when indexing with tuple of arrays "
                "on non-'CPU_DEVICE' devices."
            )

    xp = _array_api_for_tests(array_namespace, device)

    # The percentile=0 edge case (#20528) cannot pass on a namespace/device
    # where xp.nextafter is broken, i.e. returns its input unchanged. This is
    # the case for torch with MPS device:
    # https://github.com/pytorch/pytorch/issues/150027
    zero = xp.zeros(1, device=device)
    one = xp.ones(1, device=device)
    if percentile == 0:
        nextafter_is_noop = xp.all(xp.nextafter(zero, one) == zero)
        if nextafter_is_noop:
            pytest.xfail(f"xp.nextafter is broken on {device}")

    rng = np.random.RandomState(global_random_seed)

    def _materialize(spec):
        # Parameters are either concrete values or factories taking the RNG.
        return spec(rng) if callable(spec) else spec

    # `data` is drawn before `weights` so the random stream is consumed in a
    # fixed order; `data` is then cast to the dtype under test.
    X_np = _materialize(data).astype(dtype_name)
    weights_np = _materialize(weights)

    # Reference result computed on NumPy inputs.
    expected = _weighted_percentile(X_np, weights_np, percentile)

    # Same inputs as arrays of the namespace/device under test.
    X_xp = xp.asarray(X_np, device=device)
    weights_xp = xp.asarray(weights_np, device=device)

    with config_context(array_api_dispatch=True):
        result_xp = _weighted_percentile(X_xp, weights_xp, percentile)
        # The result must stay on the input's device and in its namespace.
        assert array_device(result_xp) == array_device(X_xp)
        assert get_namespace(result_xp)[0] == get_namespace(X_xp)[0]
        actual = _convert_to_numpy(result_xp, xp=xp)

    assert actual.dtype == expected.dtype
    assert actual.shape == expected.shape
    assert_allclose(expected, actual)

    # The output dtype follows `array`, not `sample_weight`. Equality with the
    # NumPy result's dtype has already been asserted above.
    expected_dtype = np.float32 if dtype_name == "float32" else np.float64
    assert actual.dtype == expected_dtype
235
+
236
+
140
237
@pytest .mark .parametrize ("sample_weight_ndim" , [1 , 2 ])
141
238
def test_weighted_percentile_nan_filtered (sample_weight_ndim ):
142
239
"""Test that calling _weighted_percentile on an array with nan values returns
0 commit comments