Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 479f098

Browse filesBrowse files
committed
ENH: Forbid (0, 1) weights in weighted quantile (#9211)
Various other changes pertaining to api signatures.
1 parent 01968b1 commit 479f098
Copy full SHA for 479f098

File tree

Expand file treeCollapse file tree

4 files changed

+55
-85
lines changed
Filter options
Expand file treeCollapse file tree

4 files changed

+55
-85
lines changed

‎numpy/lib/_function_base_impl.py

Copy file name to clipboardExpand all lines: numpy/lib/_function_base_impl.py
+33-37Lines changed: 33 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2010,7 +2010,7 @@ def disp(mesg, device=None, linefeed=True):
20102010
"(deprecated in NumPy 2.0)",
20112011
DeprecationWarning,
20122012
stacklevel=2
2013-
)
2013+
)
20142014

20152015
if device is None:
20162016
device = sys.stdout
@@ -3929,22 +3929,22 @@ def _median(a, axis=None, out=None, overwrite_input=False):
39293929
return rout
39303930

39313931

3932-
def _percentile_dispatcher(a, q, axis=None, weights=None, out=None,
3932+
def _percentile_dispatcher(a, q, axis=None, out=None,
39333933
overwrite_input=None, method=None, keepdims=None, *,
3934-
interpolation=None):
3934+
weights=None, interpolation=None):
39353935
return (a, q, out)
39363936

39373937

39383938
@array_function_dispatch(_percentile_dispatcher)
39393939
def percentile(a,
39403940
q,
39413941
axis=None,
3942-
weights=None,
39433942
out=None,
39443943
overwrite_input=False,
39453944
method="linear",
39463945
keepdims=False,
39473946
*,
3947+
weights=None,
39483948
interpolation=None):
39493949
"""
39503950
Compute the q-th percentile of the data along the specified axis.
@@ -4266,22 +4266,22 @@ def percentile(a,
42664266
a, q, axis, weights, out, overwrite_input, method, keepdims)
42674267

42684268

4269-
def _quantile_dispatcher(a, q, axis=None, weights=None, out=None,
4269+
def _quantile_dispatcher(a, q, axis=None, out=None,
42704270
overwrite_input=None, method=None, keepdims=None, *,
4271-
interpolation=None):
4271+
weights=None, interpolation=None):
42724272
return (a, q, out)
42734273

42744274

42754275
@array_function_dispatch(_quantile_dispatcher)
42764276
def quantile(a,
42774277
q,
42784278
axis=None,
4279-
weights=None,
42804279
out=None,
42814280
overwrite_input=False,
42824281
method="linear",
42834282
keepdims=False,
42844283
*,
4284+
weights=None,
42854285
interpolation=None):
42864286
"""
42874287
Compute the q-th quantile of the data along the specified axis.
@@ -4599,11 +4599,10 @@ def _validate_and_ureduce_weights(a, axis, wgts):
45994599
46004600
Weights cannot:
46014601
* be negative
4602+
* be (0, 1)
46024603
* sum to 0
46034604
However, they can be
46044605
* 0, as long as they do not sum to 0
4605-
* less than 1. In this case, all weights are re-normalized by
4606-
the lowest non-zero weight prior to computation.
46074606
46084607
Weights will be broadcasted to the shape of a, then reduced as done
46094608
via _ureduce().
@@ -4630,6 +4629,9 @@ def _validate_and_ureduce_weights(a, axis, wgts):
46304629
if (wgts < 0).any():
46314630
raise ValueError("Negative weight not allowed.")
46324631

4632+
if ((0 < wgts) & (wgts < 1)).any():
4633+
raise ValueError("Partial weight (0, 1) not allowed.")
4634+
46334635
# dims to reshape to, before broadcast
46344636
if axis is None:
46354637
dims = tuple(range(a.ndim)) # all axes
@@ -4666,22 +4668,6 @@ def _validate_and_ureduce_weights(a, axis, wgts):
46664668
# Obtain a weights array of the same shape as ureduced a
46674669
wgts = _ureduce(wgts, func=lambda x, **kwargs: x, axis=dims)
46684670

4669-
# Now check/renormalize weights if any is (0, 1)
4670-
def _normalize(v):
4671-
inds = v > 0
4672-
if (v[inds] < 1).any():
4673-
vec = v.copy()
4674-
vec[inds] = vec[inds] / vec[inds].min() # renormalization
4675-
return vec
4676-
else:
4677-
return v
4678-
4679-
# perform normalization along reduced axis
4680-
if len(dims) > 1:
4681-
wgts = np.apply_along_axis(_normalize, -1, wgts)
4682-
else:
4683-
wgts = np.apply_along_axis(_normalize, dims[0], wgts)
4684-
46854671
return wgts
46864672

46874673

@@ -4976,11 +4962,14 @@ def _get_weighted_quantile_values(arr1d, wgts1d):
49764962

49774963
# each weight occupies a range in weight space w/ left/right bounds
49784964
left_weight_bound = np.roll(wgts1d_cumsum, 1)
4979-
left_weight_bound[0] = 0 # left-most weight bound fixed at 0
4980-
right_weight_bound = wgts1d_cumsum - 1
4965+
# value i left weight index bound = sum(weights before i) + 1 - 1,
4966+
# the +1 due to neighboring values having an index distance of 1,
4967+
# the -1 due to 0-indexing in Python
4968+
left_weight_bound[0] = 0 # left-most weight bound defined to be 0
4969+
right_weight_bound = wgts1d_cumsum - 1 # -1 due to 0-indexing
49814970

49824971
# now construct a mapping from weight bounds to real indexes
4983-
# for example, arr1d=[1, 2] & wgts1d=[2, 3] ->
4972+
# arr1d=[7, 8] & wgts1d=[2, 3] == [7, 7, 8, 8, 8]
49844973
# -> real_indexes=[0, 0, 1, 1] & w_index_bounds=[0, 1, 2, 4]
49854974
indexes = np.arange(arr1d.size)
49864975
real_indexes = np.zeros(2 * indexes.size)
@@ -4993,29 +4982,35 @@ def _get_weighted_quantile_values(arr1d, wgts1d):
49934982
# first define previous_w_indexes/next_w_indexes as the indexes
49944983
# within w_index_bounds whose values sandwich weight_space_indexes.
49954984
# so if w_index_bounds=[0, 1, 2, 4] and weight_space_index=3.5,
4996-
# then previous_w_indexes = 2 and next_w_indexes = 3
4985+
# then previous_w_indexes = 2 and next_w_indexes = 3,
4986+
# meaning weight_space_indexed is sandwiched by w_index_bounds[2]
4987+
# and w_index_bounds[3].
49974988
previous_w_indexes = np.searchsorted(w_index_bounds,
49984989
weight_space_indexes,
49994990
side="right") - 1
50004991
# leverage _get_index() to deal with out-of-bound indices
50014992
previous_w_indexes, next_w_indexes =\
50024993
_get_indexes(w_index_bounds, previous_w_indexes,
50034994
len(w_index_bounds))
5004-
# now redefine previous_w_indexes/next_w_indexes as the weight
5005-
# space indexes that neighbor weight_space_indexes.
4995+
# following earlier example, we now know weight_space_indexed is
4996+
# sandwiched by w_index_bounds[2] and w_index_bounds[3], which are
4997+
# 2 and 4. We want the 2 and 4.
4998+
# so redefine previous_w_indexes/next_w_indexes as the
4999+
# w_index_bounds that neighbor weight_space_indexes.
50065000
previous_w_indexes = w_index_bounds[previous_w_indexes]
50075001
next_w_indexes = w_index_bounds[next_w_indexes]
50085002

5009-
# map all weight space indexes to real indexes, then compute gamma
5003+
# method-dependent gammas determine interpolation scheme between
5004+
# neighboring values, and are computed in weight space.
5005+
gamma =\
5006+
_get_gamma(weight_space_indexes, previous_w_indexes, method)
5007+
5008+
# map all weight space indexes to real indexes
50105009
previous_indexes =\
50115010
np.interp(previous_w_indexes, w_index_bounds, real_indexes)
50125011
next_indexes =\
50135012
np.interp(next_w_indexes, w_index_bounds, real_indexes)
50145013

5015-
# method-dependent gammas determine interpolation scheme between
5016-
# neighboring values, and are computed in weight space.
5017-
gamma =\
5018-
_get_gamma(weight_space_indexes, previous_w_indexes, method)
50195014
previous = take(arr1d, previous_indexes.astype(int))
50205015
next = take(arr1d, next_indexes.astype(int))
50215016
return _lerp(previous, next, gamma, out=out)
@@ -5029,7 +5024,8 @@ def _get_weighted_quantile_values(arr1d, wgts1d):
50295024
result = get_weighted_quantile_values(arr, weights)
50305025

50315026
# now move data to DATA_AXIS to be consistent with no-weights case
5032-
result = np.moveaxis(result, -1, destination=0)
5027+
if axis != -1 and quantiles.ndim:
5028+
result = np.moveaxis(result, -1, destination=0)
50335029

50345030
else:
50355031
values_count = arr.shape[axis]

‎numpy/lib/_nanfunctions_impl.py

Copy file name to clipboardExpand all lines: numpy/lib/_nanfunctions_impl.py
+9-13Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,8 +1219,8 @@ def nanmedian(a, axis=None, out=None, overwrite_input=False, keepdims=np._NoValu
12191219

12201220

12211221
def _nanpercentile_dispatcher(
1222-
a, q, axis=None, weights=None, out=None, overwrite_input=None,
1223-
method=None, keepdims=None, *, interpolation=None):
1222+
a, q, axis=None, out=None, overwrite_input=None,
1223+
method=None, keepdims=None, *, weights=None, interpolation=None):
12241224
return (a, q, out)
12251225

12261226

@@ -1229,12 +1229,12 @@ def nanpercentile(
12291229
a,
12301230
q,
12311231
axis=None,
1232-
weights=None,
12331232
out=None,
12341233
overwrite_input=False,
12351234
method="linear",
12361235
keepdims=np._NoValue,
12371236
*,
1237+
weights=None,
12381238
interpolation=None,
12391239
):
12401240
"""
@@ -1385,9 +1385,9 @@ def nanpercentile(
13851385
a, q, axis, weights, out, overwrite_input, method, keepdims)
13861386

13871387

1388-
def _nanquantile_dispatcher(a, q, axis=None, weights=None, out=None,
1388+
def _nanquantile_dispatcher(a, q, axis=None, out=None,
13891389
overwrite_input=None, method=None, keepdims=None,
1390-
*, interpolation=None):
1390+
*, weights=None, interpolation=None):
13911391
return (a, q, out)
13921392

13931393

@@ -1396,12 +1396,12 @@ def nanquantile(
13961396
a,
13971397
q,
13981398
axis=None,
1399-
weights=None,
14001399
out=None,
14011400
overwrite_input=False,
14021401
method="linear",
14031402
keepdims=np._NoValue,
14041403
*,
1404+
weights=None,
14051405
interpolation=None,
14061406
):
14071407
"""
@@ -1570,12 +1570,7 @@ def _nanquantile_unchecked(
15701570
return np.nanmean(a, axis, out=out, keepdims=keepdims)
15711571

15721572
if weights is not None:
1573-
<<<<<<< HEAD
15741573
weights = fnb._validate_and_ureduce_weights(a, axis, weights)
1575-
weights[np.isnan(a)] = np.nan # for _nanquantile_1d
1576-
=======
1577-
weights = function_base._validate_and_ureduce_weights(a, axis, weights)
1578-
>>>>>>> ENH: Tests and documentation for weights arg to quantile/percentile in lib.function_base and nanquantile/nanpercentile in lib.nanfunctions (#9211).
15791574

15801575
return fnb._ureduce(a,
15811576
func=_nanquantile_ureduce_func,
@@ -1657,8 +1652,9 @@ def _nanquantile_1d(arr1d, q, wgt1d=None, overwrite_input=False,
16571652
# convert to scalar
16581653
return np.full(q.shape, np.nan, dtype=arr1d.dtype)[()]
16591654

1660-
return function_base._quantile_unchecked(
1661-
arr1d, q, overwrite_input=overwrite_input, method=method)
1655+
return fnb._quantile_unchecked(
1656+
arr1d, q, weights=wgt1d, overwrite_input=overwrite_input,
1657+
method=method)
16621658

16631659

16641660
def _nanvar_dispatcher(a, axis=None, dtype=None, out=None, ddof=None,

‎numpy/lib/tests/test_function_base.py

Copy file name to clipboardExpand all lines: numpy/lib/tests/test_function_base.py
+6-17Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3023,10 +3023,10 @@ def test_fraction(self):
30233023

30243024
def test_api(self):
30253025
d = np.ones(5)
3026-
np.percentile(d, 5, None, None, None, False)
3027-
np.percentile(d, 5, None, None, None, False, 'linear')
3026+
np.percentile(d, 5, None, None, False)
3027+
np.percentile(d, 5, None, None, False, 'linear')
30283028
o = np.ones((1,))
3029-
np.percentile(d, 5, None, None, o, False, 'linear')
3029+
np.percentile(d, 5, None, o, False, 'linear')
30303030

30313031
def test_complex(self):
30323032
arr_c = np.array([0.5+3.0j, 2.1+0.5j, 1.6+2.3j], dtype='G')
@@ -3861,15 +3861,8 @@ def test_various_weights(self, method):
38613861
assert_almost_equal(actual, expected)
38623862

38633863
# mix of numeric types
3864-
# due to renormalization triggered by weight < 1,
38653864
# this is expected to be the same as weights = [1, 2, 3]
3866-
weights = [decimal.Decimal(0.5), 1, 1.5]
3867-
actual = np.quantile(ar, q=q, axis=axis, weights=weights,
3868-
method=method)
3869-
assert_almost_equal(actual, expected)
3870-
3871-
# show that normalization means sum of weights is irrelavant
3872-
weights = [0.1, 0.2, 0.3]
3865+
weights = [decimal.Decimal(1.0), 2, 3.0]
38733866
actual = np.quantile(ar, q=q, axis=axis, weights=weights,
38743867
method=method)
38753868
assert_almost_equal(actual, expected)
@@ -3882,12 +3875,6 @@ def test_various_weights(self, method):
38823875
expected = np.quantile(ar_012, q=q, axis=axis, method=method)
38833876
assert_almost_equal(actual, expected)
38843877

3885-
# weight entries < 1
3886-
weights = [0.0, 0.001, 0.002]
3887-
actual = np.quantile(ar, q=q, axis=axis, weights=weights,
3888-
method=method)
3889-
assert_almost_equal(actual, expected)
3890-
38913878
def test_weights_flags(self):
38923879
"""Test that flags are raised on invalid weights."""
38933880
ar = np.arange(6).reshape(2, 3)
@@ -3902,6 +3889,8 @@ def test_weights_flags(self):
39023889
np.quantile(ar, q=q, axis=axis, weights=[1, np.nan])
39033890
with assert_raises_regex(ValueError, "Negative weight not allowed"):
39043891
np.quantile(ar, q=q, axis=axis, weights=[1, -1])
3892+
with assert_raises_regex(ValueError, "Partial weight"):
3893+
np.quantile(ar, q=q, axis=axis, weights=[1, 0.1])
39053894
with assert_raises_regex(ZeroDivisionError, "Weights sum to zero"):
39063895
np.quantile(ar, q=q, axis=axis, weights=[0, 0])
39073896

‎numpy/lib/tests/test_nanfunctions.py

Copy file name to clipboardExpand all lines: numpy/lib/tests/test_nanfunctions.py
+7-18Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import inspect
55

66
import numpy as np
7-
import numpy.lib.function_base as nfb
7+
from numpy.lib import _function_base_impl as fnb
88
from numpy._core.numeric import normalize_axis_tuple
99
from numpy.exceptions import AxisError, ComplexWarning
1010
from numpy.lib._nanfunctions_impl import _nan_mask, _replace_nan
@@ -1278,7 +1278,7 @@ def test_allnans(self, axis, dtype, array):
12781278
assert np.isnan(out).all()
12791279
assert out.dtype == array.dtype
12801280

1281-
@pytest.mark.parametrize("method", list(nfb._QuantileMethods.keys()))
1281+
@pytest.mark.parametrize("method", list(fnb._QuantileMethods.keys()))
12821282
def test_weights_all_ones(self, method):
12831283
"""Test that all weights == 1 gives same results as no weights."""
12841284
ar = np.arange(24).reshape(2, 3, 4).astype(float)
@@ -1317,7 +1317,7 @@ def test_weights_all_ones(self, method):
13171317
actual = np.nanquantile(ar, q=q, weights=weights, method=method)
13181318
assert_almost_equal(actual, expected)
13191319

1320-
@pytest.mark.parametrize("method", list(nfb._QuantileMethods.keys()))
1320+
@pytest.mark.parametrize("method", list(fnb._QuantileMethods.keys()))
13211321
def test_multiple_axes(self, method):
13221322
"""Test that weights work on multiple axes."""
13231323
ar = np.arange(12).reshape(3, 4).astype(float)
@@ -1331,7 +1331,7 @@ def test_multiple_axes(self, method):
13311331
method=method)
13321332
assert_almost_equal(actual, expected)
13331333

1334-
@pytest.mark.parametrize("method", list(nfb._QuantileMethods.keys()))
1334+
@pytest.mark.parametrize("method", list(fnb._QuantileMethods.keys()))
13351335
def test_various_weights(self, method):
13361336
"""Test various weights arg scenarios."""
13371337
ar = np.arange(12).reshape(3, 4).astype(float)
@@ -1357,15 +1357,8 @@ def test_various_weights(self, method):
13571357
assert_almost_equal(actual, expected)
13581358

13591359
# mix of numeric types
1360-
# due to renormalization triggered by weight < 1,
13611360
# this is expected to be the same as weights = [1, 2, 3]
1362-
weights = [decimal.Decimal(0.5), 1, 1.5]
1363-
actual = np.nanquantile(ar, q=q, axis=axis, weights=weights,
1364-
method=method)
1365-
assert_almost_equal(actual, expected)
1366-
1367-
# show that normalization means sum of weights is irrelavant
1368-
weights = [0.2, 0.4, 0.6]
1361+
weights = [decimal.Decimal(1.0), 2, 3.0]
13691362
actual = np.nanquantile(ar, q=q, axis=axis, weights=weights,
13701363
method=method)
13711364
assert_almost_equal(actual, expected)
@@ -1378,12 +1371,6 @@ def test_various_weights(self, method):
13781371
expected = np.nanquantile(ar_012, q=q, axis=axis, method=method)
13791372
assert_almost_equal(actual, expected)
13801373

1381-
# weight entries < 1
1382-
weights = [0.0, 0.001, 0.002]
1383-
actual = np.nanquantile(ar, q=q, axis=axis, weights=weights,
1384-
method=method)
1385-
assert_almost_equal(actual, expected)
1386-
13871374
def test_weights_flags(self):
13881375
"""Test that flags are raised on invalid weights."""
13891376
ar = np.arange(6).reshape(2, 3).astype(float)
@@ -1399,6 +1386,8 @@ def test_weights_flags(self):
13991386
np.quantile(ar, q=q, axis=axis, weights=[1, np.nan])
14001387
with assert_raises_regex(ValueError, "Negative weight not allowed"):
14011388
np.quantile(ar, q=q, axis=axis, weights=[1, -1])
1389+
with assert_raises_regex(ValueError, "Partial weight"):
1390+
np.quantile(ar, q=q, axis=axis, weights=[1, 0.1])
14021391
with assert_raises_regex(ZeroDivisionError, "Weights sum to zero"):
14031392
np.quantile(ar, q=q, axis=axis, weights=[0, 0])
14041393

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.