commit 241de66 (parent: 6fcf61c)
Author: Gaetan

    add sample weight tests

 sklearn/ensemble/tests/test_forest.py | +160 -1

 1 file changed, 160 insertions(+), 1 deletion(-)
diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py
@@ -22,7 +22,8 @@
 
 import sklearn
 from sklearn import clone, datasets
-from sklearn.datasets import make_classification, make_hastie_10_2
+from sklearn.base import is_classifier
+from sklearn.datasets import make_classification, make_hastie_10_2, make_regression
 from sklearn.decomposition import TruncatedSVD
 from sklearn.dummy import DummyRegressor
 from sklearn.ensemble import (
@@ -46,6 +47,7 @@
 from sklearn.model_selection import GridSearchCV, cross_val_score, train_test_split
 from sklearn.svm import LinearSVC
 from sklearn.tree._classes import SPARSE_SPLITTERS
+from sklearn.utils import shuffle
 from sklearn.utils._testing import (
     _convert_container,
     assert_allclose,
@@ -55,6 +57,10 @@
     ignore_warnings,
     skip_if_no_parallel,
 )
+from sklearn.utils.estimator_checks import (
+    _enforce_estimator_tags_X,
+    _enforce_estimator_tags_y,
+)
 from sklearn.utils.fixes import COO_CONTAINERS, CSC_CONTAINERS, CSR_CONTAINERS
 from sklearn.utils.multiclass import type_of_target
 from sklearn.utils.parallel import Parallel
@@ -1973,6 +1979,159 @@ def test_importance_reg_match_onehot_classi(global_random_seed):
     )
 
 
+@pytest.mark.parametrize("est_name", FOREST_CLASSIFIERS_REGRESSORS)
+def test_feature_importance_with_sample_weights(est_name, global_random_seed):
+    # From https://github.com/snath-xoc/sample-weight-audit-nondet/blob/main/src/sample_weight_audit/data.py#L53
+
+    # Strategy: sample 2 datasets, each with n_features // 2 features:
+    # - the first one has int(0.5 * n_samples) samples, but mostly zero or one
+    #   weights.
+    # - the second one has the remaining samples, but with higher weights.
+    #
+    # The features of the two datasets are horizontally stacked with random
+    # feature values sampled independently from the other dataset. Then the two
+    # datasets are vertically stacked and the result is shuffled.
+    #
+    # The sum of the weights of the second dataset is much larger (roughly 10
+    # times in expectation) than that of the first dataset, so weight-aware
+    # estimators should mostly ignore the features of the first dataset when
+    # learning their prediction function.
+    n_samples = 250
+    n_features = 4
+    n_classes = 2
+    max_sample_weight = 5
+
+    rng = check_random_state(global_random_seed)
+    n_samples_sw = int(0.5 * n_samples)  # small weights
+    n_samples_lw = n_samples - n_samples_sw  # large weights
+    n_features_sw = n_features // 2
+    n_features_lw = n_features - n_features_sw
+
+    # Construct the sample weights: mostly zeros and some ones for the first
+    # dataset, and random integers larger than one for the second dataset.
+    sample_weight_sw = np.where(rng.random(n_samples_sw) < 0.2, 1, 0)
+    sample_weight_lw = rng.randint(2, max_sample_weight, size=n_samples_lw)
+    total_weight_sum = np.sum(sample_weight_sw) + np.sum(sample_weight_lw)
+    assert np.sum(sample_weight_sw) < 0.3 * total_weight_sum
+
+    est = FOREST_CLASSIFIERS_REGRESSORS[est_name](
+        n_estimators=50,
+        bootstrap=True,
+        oob_score=True,
+        random_state=rng,
+    )
+    if not is_classifier(est):
+        X_sw, y_sw = make_regression(
+            n_samples=n_samples_sw,
+            n_features=n_features_sw,
+            random_state=rng,
+        )
+        X_lw, y_lw = make_regression(
+            n_samples=n_samples_lw,
+            n_features=n_features_lw,
+            random_state=rng,  # rng has been mutated above, so this draws different data
+        )
+    else:
+        X_sw, y_sw = make_classification(
+            n_samples=n_samples_sw,
+            n_features=n_features_sw,
+            n_informative=n_features_sw,
+            n_redundant=0,
+            n_repeated=0,
+            n_classes=n_classes,
+            random_state=rng,
+        )
+        X_lw, y_lw = make_classification(
+            n_samples=n_samples_lw,
+            n_features=n_features_lw,
+            n_informative=n_features_lw,
+            n_redundant=0,
+            n_repeated=0,
+            n_classes=n_classes,
+            random_state=rng,  # rng has been mutated above, so this draws different data
+        )
+
+    # Horizontally pad the features with feature values marginally sampled
+    # from the other dataset.
+    pad_sw_idx = rng.choice(n_samples_lw, size=n_samples_sw, replace=True)
+    X_sw_padded = np.hstack([X_sw, np.take(X_lw, pad_sw_idx, axis=0)])
+
+    pad_lw_idx = rng.choice(n_samples_sw, size=n_samples_lw, replace=True)
+    X_lw_padded = np.hstack([np.take(X_sw, pad_lw_idx, axis=0), X_lw])
+
+    # Vertically stack the two datasets and shuffle them.
+    X = np.concatenate([X_sw_padded, X_lw_padded], axis=0)
+    y = np.concatenate([y_sw, y_lw])
+
+    X = _enforce_estimator_tags_X(est, X)
+    y = _enforce_estimator_tags_y(est, y)
+    sample_weight = np.concatenate([sample_weight_sw, sample_weight_lw])
+    X, y, sample_weight = shuffle(X, y, sample_weight, random_state=rng)
+
+    est.fit(X, y, sample_weight)
+
+    ufi_feature_importance = est.ufi_feature_importances_
+    mdi_oob_feature_importance = est.mdi_oob_feature_importances_
+    assert (
+        ufi_feature_importance[:n_features_sw].sum()
+        < ufi_feature_importance[n_features_sw:].sum()
+    )
+    assert (
+        mdi_oob_feature_importance[:n_features_sw].sum()
+        < mdi_oob_feature_importance[n_features_sw:].sum()
+    )
+
+
+@pytest.mark.parametrize("est_name", FOREST_CLASSIFIERS_REGRESSORS)
+def test_feature_importance_sample_weight_equals_repeated(est_name, global_random_seed):
+    # Check that setting a sample's weight to zero / to an integer k is
+    # equivalent to removing it / repeating it k times.
+    params = dict(
+        n_estimators=100,
+        bootstrap=True,
+        oob_score=True,
+        max_features=1.0,
+        random_state=global_random_seed,
+    )
+
+    est_weighted = FOREST_CLASSIFIERS_REGRESSORS[est_name](**params)
+    est_repeated = FOREST_CLASSIFIERS_REGRESSORS[est_name](**params)
+
+    rng = check_random_state(global_random_seed)
+    n_samples = 100
+    n_features = 2
+    X, y = make_classification(
+        n_samples=n_samples,
+        n_features=n_features,
+        n_informative=n_features,
+        n_redundant=0,
+        random_state=rng,
+    )
+    # Use random integer weights (here 0 or 1).
+    sw = rng.randint(0, 2, size=n_samples)
+
+    X_weighted = X
+    y_weighted = y
+    # Repeat each sample according to its weight.
+    X_repeated = X_weighted.repeat(repeats=sw, axis=0)
+    y_repeated = y_weighted.repeat(repeats=sw)
+
+    X_weighted, y_weighted, sw = shuffle(X_weighted, y_weighted, sw, random_state=0)
+
+    est_repeated.fit(X_repeated, y=y_repeated, sample_weight=None)
+    est_weighted.fit(X_weighted, y=y_weighted, sample_weight=sw)
+
+    assert_allclose(
+        est_repeated.feature_importances_, est_weighted.feature_importances_, atol=1e-1
+    )
+    assert_allclose(
+        est_repeated.ufi_feature_importances_,
+        est_weighted.ufi_feature_importances_,
+        atol=1e-1,
+    )
+    assert_allclose(
+        est_repeated.mdi_oob_feature_importances_,
+        est_weighted.mdi_oob_feature_importances_,
+        atol=1e-1,
+    )
+
+
 @pytest.mark.parametrize("name", FOREST_CLASSIFIERS_REGRESSORS)
 def test_max_samples_bootstrap(name):
     # Check invalid `max_samples` values
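Note: the assertions above rely on two experimental attributes that forests on this branch expose after fitting with oob_score=True: ufi_feature_importances_ and mdi_oob_feature_importances_. Neither is part of a released scikit-learn. A minimal sketch of how they would be exercised, assuming a build of this branch:

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier

    rng = np.random.RandomState(0)
    X, y = make_classification(n_samples=250, n_features=4, random_state=rng)
    # Integer weights: samples with weight 0 should not influence the fit.
    sample_weight = rng.randint(0, 5, size=X.shape[0])

    est = RandomForestClassifier(
        n_estimators=50, bootstrap=True, oob_score=True, random_state=0
    )
    est.fit(X, y, sample_weight=sample_weight)

    print(est.feature_importances_)  # standard MDI importances
    # Branch-only, experimental OOB-based variants exercised by these tests:
    print(est.ufi_feature_importances_)
    print(est.mdi_oob_feature_importances_)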

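The second test leans on the equivalence between integer sample weights and row repetition. A minimal numpy illustration of the same repeat-based construction it uses:

    import numpy as np

    X = np.array([[1.0], [2.0], [3.0]])
    y = np.array([0, 1, 1])
    sw = np.array([0, 2, 1])  # weight 0 drops a row, weight k repeats it k times

    # Same construction as in the test: repeat each row sw[i] times.
    X_repeated = X.repeat(repeats=sw, axis=0)
    y_repeated = y.repeat(repeats=sw)

    print(X_repeated.ravel())  # [2. 2. 3.]
    print(y_repeated)          # [1 1 1]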
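To run just these tests from a checkout of the branch, pytest can be invoked programmatically; the -k substring below is an assumption that matches both new test names (it may also select other sample-weight tests in the module):

    import pytest

    # Select the tests added by this commit by substring match on their names.
    pytest.main([
        "sklearn/ensemble/tests/test_forest.py",
        "-k", "sample_weight",
        "-q",
    ])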