Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit b309f9e

Browse filesBrowse files
author
Michael Recachinas
committed
Add interpolation to _weighted_percentile
1 parent ea115c2 commit b309f9e
Copy full SHA for b309f9e

File tree

Expand file treeCollapse file tree

2 files changed

+156
-8
lines changed
Filter options
Expand file treeCollapse file tree

2 files changed

+156
-8
lines changed

‎sklearn/utils/stats.py

Copy file name to clipboard
+67-8Lines changed: 67 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,77 @@
1-
import numpy as np
1+
"""Statistical utilities including weighted percentile"""
22

3+
import numpy as np
34
from sklearn.utils.extmath import stable_cumsum
45

56

67
def _weighted_percentile(array, sample_weight, percentile=50):
8+
"""Compute the weighted ``percentile`` of ``array``
9+
with ``sample_weight``.
10+
11+
This approach follows
12+
13+
N
14+
S_N = sum w_k
15+
k=1
16+
17+
p_n = 1 / S_N * (x_n - w_n / 2)
18+
19+
v = v_k + (v_{k + 1} - v_k) * (P - p_k) / (p_{k + 1} - p_k)
20+
21+
from
22+
https://en.wikipedia.org/wiki/Percentile#The_weighted_percentile_method.
23+
24+
25+
Parameters
26+
----------
27+
array : array-like, shape = (n_samples,)
28+
Array of data on which to calculate the weighted percentile
29+
30+
sample_weight : array-like, shape = (n_samples,)
31+
Array of corresponding sample weights with which to calculate
32+
the weighted percentile
33+
34+
percentile : int, optional (default: 50)
35+
Integer value of Pth percentile to compute
36+
37+
Returns
38+
-------
39+
v : float
40+
Linearly interpolated weighted percentile.
41+
42+
Examples
43+
--------
44+
>>> import numpy as np
45+
>>> from sklearn.utils.stats import _weighted_percentile
46+
>>> weight = np.array([1, 1])
47+
>>> data = np.array([0, 1])
48+
>>> _weighted_percentile(data, weight, percentile=0)
49+
0.0
50+
>>> _weighted_percentile(data, weight, percentile=50)
51+
0.5
52+
>>> _weighted_percentile(data, weight, percentile=90)
53+
1.0
754
"""
8-
Compute the weighted ``percentile`` of ``array`` with ``sample_weight``.
9-
"""
55+
if not isinstance(array, np.ndarray):
56+
array = np.array(array)
57+
58+
if not isinstance(sample_weight, np.ndarray):
59+
sample_weight = np.array(sample_weight)
60+
61+
if (sample_weight < 0).any():
62+
raise ValueError("sample_weight must contain positive or 0 weights")
63+
64+
if percentile < 0:
65+
raise ValueError("percentile must be positive or 0")
66+
1067
sorted_idx = np.argsort(array)
68+
sorted_array = array[sorted_idx]
69+
70+
# if there are no weights, return the min of ``array``
71+
if sample_weight.sum() == 0:
72+
return sorted_array[0]
1173

1274
# Find index of median prediction for each sample
1375
weight_cdf = stable_cumsum(sample_weight[sorted_idx])
14-
percentile_idx = np.searchsorted(
15-
weight_cdf, (percentile / 100.) * weight_cdf[-1])
16-
# in rare cases, percentile_idx equals to len(sorted_idx)
17-
percentile_idx = np.clip(percentile_idx, 0, len(sorted_idx)-1)
18-
return array[sorted_idx[percentile_idx]]
76+
p_n = 100. / weight_cdf[-1] * (weight_cdf - sample_weight / 2.)
77+
return np.interp(percentile, p_n, sorted_array)

‎sklearn/utils/tests/test_stats.py

Copy file name to clipboard
+89Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import numpy as np
2+
from sklearn.utils.testing import assert_equal, assert_raises
3+
from sklearn.utils.stats import _weighted_percentile
4+
5+
6+
def test_weighted_percentile_negative_weights_raises():
7+
weight = np.array([1, -1])
8+
data = np.array([0, 1])
9+
assert_raises(ValueError, _weighted_percentile, data, weight)
10+
11+
12+
def test_weighted_percentile_negative_percentile_raises():
13+
weight = np.array([1, -1])
14+
data = np.array([0, 1])
15+
percentile = -50
16+
assert_raises(ValueError, _weighted_percentile, data, weight,
17+
percentile=percentile)
18+
19+
20+
def test_weighted_percentile_median_interpolated_list():
21+
weight = [1, 1]
22+
data = [0, 1]
23+
percentile = 50
24+
expected = 0.5
25+
actual = _weighted_percentile(data, weight, percentile=percentile)
26+
assert_equal(expected, actual)
27+
28+
29+
def test_weighted_percentile_median_interpolated_tuple():
30+
weight = (1, 1)
31+
data = (0, 1)
32+
percentile = 50
33+
expected = 0.5
34+
actual = _weighted_percentile(data, weight, percentile=percentile)
35+
assert_equal(expected, actual)
36+
37+
38+
def test_weighted_percentile_median_interpolated():
39+
weight = np.array([1, 1])
40+
data = np.array([0, 1])
41+
percentile = 50
42+
expected = 0.5
43+
actual = _weighted_percentile(data, weight, percentile=percentile)
44+
assert_equal(expected, actual)
45+
46+
47+
def test_weighted_percentile_median_regular():
48+
weight = np.array([1, 1, 1])
49+
data = np.array([0, 1, 2])
50+
percentile = 50
51+
expected = 1.0
52+
actual = _weighted_percentile(data, weight, percentile=percentile)
53+
assert_equal(expected, actual)
54+
55+
56+
def test_weighted_percentile_0_regular():
57+
weight = np.array([1, 1, 1])
58+
data = np.array([0, 1, 2])
59+
percentile = 0
60+
expected = 0.0
61+
actual = _weighted_percentile(data, weight, percentile=percentile)
62+
assert_equal(expected, actual)
63+
64+
65+
def test_weighted_percentile_90_regular():
66+
weight = np.array([1, 1, 1])
67+
data = np.array([0, 1, 2])
68+
percentile = 90
69+
expected = 2.0
70+
actual = _weighted_percentile(data, weight, percentile=percentile)
71+
assert_equal(expected, actual)
72+
73+
74+
def test_weighted_percentile_70_interpolated():
75+
weight = np.array([1, 1, 1, 1])
76+
data = np.arange(0, 4, 1)
77+
percentile = 70
78+
expected = 2.3
79+
actual = _weighted_percentile(data, weight, percentile=percentile)
80+
assert_equal(expected, actual)
81+
82+
83+
def test_weighted_percentile_70_mixed_weights():
84+
weight = np.array([1, 0, 1, 1])
85+
data = np.arange(0, 4, 1)
86+
percentile = 50
87+
expected = 2.0
88+
actual = _weighted_percentile(data, weight, percentile=percentile)
89+
assert_equal(expected, actual)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.