Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 775587b

Browse files
authored
FEA Support missing-values in ExtraTrees* (#28268)
1 parent 4cc331f commit 775587b
Copy full SHA for 775587b

File tree

3 files changed

+42
-5
lines changed
Filter options

3 files changed

+42
-5
lines changed

‎doc/whats_new/v1.6.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.6.rst
+5Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,11 @@ Changelog
148148
:pr:`28622` by :user:`Adam Li <adam2392>` and
149149
:user:`Sérgio Pereira <sergiormpereira>`.
150150

151+
- |Feature| :class:`ensemble.ExtraTreesClassifier` and :class:`ensemble.ExtraTreesRegressor` now support
152+
missing values in the data matrix `X`. Missing values are handled by randomly moving all of
153+
the samples with missing values to the left or right child node as the tree is traversed.
154+
:pr:`28268` by :user:`Adam Li <adam2392>`.
155+
151156
:mod:`sklearn.impute`
152157
.....................
153158

‎sklearn/ensemble/tests/test_forest.py

Copy file name to clipboardExpand all lines: sklearn/ensemble/tests/test_forest.py
+17-5Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1767,6 +1767,8 @@ def test_estimators_samples(ForestClass, bootstrap, seed):
17671767
[
17681768
(datasets.make_regression, RandomForestRegressor),
17691769
(datasets.make_classification, RandomForestClassifier),
1770+
(datasets.make_regression, ExtraTreesRegressor),
1771+
(datasets.make_classification, ExtraTreesClassifier),
17701772
],
17711773
)
17721774
def test_missing_values_is_resilient(make_data, Forest):
@@ -1800,12 +1802,21 @@ def test_missing_values_is_resilient(make_data, Forest):
18001802
assert score_with_missing >= 0.80 * score_without_missing
18011803

18021804

1803-
@pytest.mark.parametrize("Forest", [RandomForestClassifier, RandomForestRegressor])
1805+
@pytest.mark.parametrize(
1806+
"Forest",
1807+
[
1808+
RandomForestClassifier,
1809+
RandomForestRegressor,
1810+
ExtraTreesRegressor,
1811+
ExtraTreesClassifier,
1812+
],
1813+
)
18041814
def test_missing_value_is_predictive(Forest):
18051815
"""Check that the forest learns when missing values are only present for
18061816
a predictive feature."""
18071817
rng = np.random.RandomState(0)
18081818
n_samples = 300
1819+
expected_score = 0.75
18091820

18101821
X_non_predictive = rng.standard_normal(size=(n_samples, 10))
18111822
y = rng.randint(0, high=2, size=n_samples)
@@ -1835,19 +1846,20 @@ def test_missing_value_is_predictive(Forest):
18351846

18361847
predictive_test_score = forest_predictive.score(X_predictive_test, y_test)
18371848

1838-
assert predictive_test_score >= 0.75
1849+
assert predictive_test_score >= expected_score
18391850
assert predictive_test_score >= forest_non_predictive.score(
18401851
X_non_predictive_test, y_test
18411852
)
18421853

18431854

1844-
def test_non_supported_criterion_raises_error_with_missing_values():
1855+
@pytest.mark.parametrize("Forest", FOREST_REGRESSORS.values())
1856+
def test_non_supported_criterion_raises_error_with_missing_values(Forest):
18451857
"""Raise error for unsupported criterion when there are missing values."""
18461858
X = np.array([[0, 1, 2], [np.nan, 0, 2.0]])
18471859
y = [0.5, 1.0]
18481860

1849-
forest = RandomForestRegressor(criterion="absolute_error")
1861+
forest = Forest(criterion="absolute_error")
18501862

1851-
msg = "RandomForestRegressor does not accept missing values"
1863+
msg = ".*does not accept missing values"
18521864
with pytest.raises(ValueError, match=msg):
18531865
forest.fit(X, y)

‎sklearn/tree/_classes.py

Copy file name to clipboardExpand all lines: sklearn/tree/_classes.py
+20Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1686,6 +1686,16 @@ def __init__(
16861686
monotonic_cst=monotonic_cst,
16871687
)
16881688

1689+
def _more_tags(self):
1690+
# XXX: nan is only supported for dense arrays, but we set this for the
1691+
# common test to pass, specifically: check_estimators_nan_inf
1692+
allow_nan = self.splitter == "random" and self.criterion in {
1693+
"gini",
1694+
"log_loss",
1695+
"entropy",
1696+
}
1697+
return {"multilabel": True, "allow_nan": allow_nan}
1698+
16891699

16901700
class ExtraTreeRegressor(DecisionTreeRegressor):
16911701
"""An extremely randomized tree regressor.
@@ -1929,3 +1939,13 @@ def __init__(
19291939
ccp_alpha=ccp_alpha,
19301940
monotonic_cst=monotonic_cst,
19311941
)
1942+
1943+
def _more_tags(self):
1944+
# XXX: nan is only supported for dense arrays, but we set this for the
1945+
# common test to pass, specifically: check_estimators_nan_inf
1946+
allow_nan = self.splitter == "random" and self.criterion in {
1947+
"squared_error",
1948+
"friedman_mse",
1949+
"poisson",
1950+
}
1951+
return {"allow_nan": allow_nan}

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.