Commit 419458d

thomasjpfan, betatim, and jjerphan authored and committed

ENH Adds support for missing values in Random Forest (scikit-learn#26391)

Co-authored-by: Tim Head <betatim@gmail.com>
Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>

1 parent 91e4bbd · commit 419458d
File tree: 4 files changed (+153 −6 lines)

doc/whats_new/v1.4.rst (+6 lines)

@@ -92,6 +92,12 @@ Changelog
 :mod:`sklearn.ensemble`
 .......................
 
+- |MajorFeature| :class:`ensemble.RandomForestClassifier` and
+  :class:`ensemble.RandomForestRegressor` support missing values when
+  the criterion is `gini`, `entropy`, or `log_loss`,
+  for classification or `squared_error`, `friedman_mse`, or `poisson`
+  for regression. :pr:`26391` by `Thomas Fan`_.
+
 - |Feature| :class:`ensemble.RandomForestClassifier`,
   :class:`ensemble.RandomForestRegressor`, :class:`ensemble.ExtraTreesClassifier`
   and :class:`ensemble.ExtraTreesRegressor` now support monotonic constraints,
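As a quick illustration of the behavior this changelog entry describes, here is a minimal sketch (assuming a scikit-learn build that includes this commit, i.e. 1.4 or later); the toy data is made up for illustration:

import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Missing values are encoded as np.nan directly in the input matrix.
X = np.array([[0.0, 1.0], [np.nan, 2.0], [3.0, np.nan], [4.0, 5.0]])
y = np.array([0, 0, 1, 1])

# With this commit, fit and predict accept NaN as long as the criterion
# is one of the supported ones (the default "gini" is).
clf = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)
print(clf.predict(np.array([[np.nan, 3.0]])))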

sklearn/ensemble/_forest.py (+53 −4 lines)

@@ -70,6 +70,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
 from ..tree._tree import DOUBLE, DTYPE
 from ..utils import check_random_state, compute_sample_weight
 from ..utils._param_validation import Interval, RealNotInt, StrOptions
+from ..utils._tags import _safe_tags
 from ..utils.multiclass import check_classification_targets, type_of_target
 from ..utils.parallel import Parallel, delayed
 from ..utils.validation import (
@@ -159,6 +160,7 @@ def _parallel_build_trees(
     verbose=0,
     class_weight=None,
     n_samples_bootstrap=None,
+    missing_values_in_feature_mask=None,
 ):
     """
     Private function used to fit a single tree in parallel."""
@@ -185,9 +187,21 @@ def _parallel_build_trees(
         elif class_weight == "balanced_subsample":
             curr_sample_weight *= compute_sample_weight("balanced", y, indices=indices)
 
-        tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False)
+        tree._fit(
+            X,
+            y,
+            sample_weight=curr_sample_weight,
+            check_input=False,
+            missing_values_in_feature_mask=missing_values_in_feature_mask,
+        )
     else:
-        tree.fit(X, y, sample_weight=sample_weight, check_input=False)
+        tree._fit(
+            X,
+            y,
+            sample_weight=sample_weight,
+            check_input=False,
+            missing_values_in_feature_mask=missing_values_in_feature_mask,
+        )
 
     return tree
@@ -345,9 +359,26 @@ def fit(self, X, y, sample_weight=None):
         # Validate or convert input data
         if issparse(y):
             raise ValueError("sparse multilabel-indicator for y is not supported.")
+
         X, y = self._validate_data(
-            X, y, multi_output=True, accept_sparse="csc", dtype=DTYPE
+            X,
+            y,
+            multi_output=True,
+            accept_sparse="csc",
+            dtype=DTYPE,
+            force_all_finite=False,
+        )
+        # _compute_missing_values_in_feature_mask checks if X has missing values and
+        # will raise an error if the underlying tree base estimator can't handle missing
+        # values. Only the criterion is required to determine if the tree supports
+        # missing values.
+        estimator = type(self.estimator)(criterion=self.criterion)
+        missing_values_in_feature_mask = (
+            estimator._compute_missing_values_in_feature_mask(
+                X, estimator_name=self.__class__.__name__
+            )
         )
+
         if sample_weight is not None:
             sample_weight = _check_sample_weight(sample_weight, X)
@@ -469,6 +500,7 @@ def fit(self, X, y, sample_weight=None):
                     verbose=self.verbose,
                     class_weight=self.class_weight,
                     n_samples_bootstrap=n_samples_bootstrap,
+                    missing_values_in_feature_mask=missing_values_in_feature_mask,
                 )
                 for i, t in enumerate(trees)
             )
@@ -596,7 +628,18 @@ def _validate_X_predict(self, X):
         """
         Validate X whenever one tries to predict, apply, predict_proba."""
         check_is_fitted(self)
-        X = self._validate_data(X, dtype=DTYPE, accept_sparse="csr", reset=False)
+        if self.estimators_[0]._support_missing_values(X):
+            force_all_finite = "allow-nan"
+        else:
+            force_all_finite = True
+
+        X = self._validate_data(
+            X,
+            dtype=DTYPE,
+            accept_sparse="csr",
+            reset=False,
+            force_all_finite=force_all_finite,
+        )
         if issparse(X) and (X.indices.dtype != np.intc or X.indptr.dtype != np.intc):
             raise ValueError("No support for np.int64 index based sparse matrices")
         return X
@@ -636,6 +679,12 @@ def feature_importances_(self):
         all_importances = np.mean(all_importances, axis=0, dtype=np.float64)
         return all_importances / np.sum(all_importances)
 
+    def _more_tags(self):
+        # Only the criterion is required to determine if the tree supports
+        # missing values
+        estimator = type(self.estimator)(criterion=self.criterion)
+        return {"allow_nan": _safe_tags(estimator, key="allow_nan")}
+
 
 def _accumulate_prediction(predict, X, out, lock):
     """

sklearn/ensemble/tests/test_forest.py (+88 lines)

@@ -1809,3 +1809,91 @@ def test_round_samples_to_one_when_samples_too_low(class_weight):
         n_estimators=10, max_samples=1e-4, class_weight=class_weight, random_state=0
     )
     forest.fit(X, y)
+
+
+@pytest.mark.parametrize(
+    "make_data, Forest",
+    [
+        (datasets.make_regression, RandomForestRegressor),
+        (datasets.make_classification, RandomForestClassifier),
+    ],
+)
+def test_missing_values_is_resilient(make_data, Forest):
+    """Check that forest can deal with missing values and have decent performance."""
+
+    rng = np.random.RandomState(0)
+    n_samples, n_features = 1000, 10
+    X, y = make_data(n_samples=n_samples, n_features=n_features, random_state=rng)
+
+    # Create dataset with missing values
+    X_missing = X.copy()
+    X_missing[rng.choice([False, True], size=X.shape, p=[0.95, 0.05])] = np.nan
+    X_missing_train, X_missing_test, y_train, y_test = train_test_split(
+        X_missing, y, random_state=0
+    )
+
+    # Train forest with missing values
+    forest_with_missing = Forest(random_state=rng, n_estimators=50)
+    forest_with_missing.fit(X_missing_train, y_train)
+    score_with_missing = forest_with_missing.score(X_missing_test, y_test)
+
+    # Train forest without missing values
+    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+    forest = Forest(random_state=rng, n_estimators=50)
+    forest.fit(X_train, y_train)
+    score_without_missing = forest.score(X_test, y_test)
+
+    # Score is still 80 percent of the forest's score that had no missing values
+    assert score_with_missing >= 0.80 * score_without_missing
+
+
+@pytest.mark.parametrize("Forest", [RandomForestClassifier, RandomForestRegressor])
+def test_missing_value_is_predictive(Forest):
+    """Check that the forest learns when missing values are only present for
+    a predictive feature."""
+    rng = np.random.RandomState(0)
+    n_samples = 300
+
+    X_non_predictive = rng.standard_normal(size=(n_samples, 10))
+    y = rng.randint(0, high=2, size=n_samples)
+
+    # Create a predictive feature using `y` and with some noise
+    X_random_mask = rng.choice([False, True], size=n_samples, p=[0.95, 0.05])
+    y_mask = y.astype(bool)
+    y_mask[X_random_mask] = ~y_mask[X_random_mask]
+
+    predictive_feature = rng.standard_normal(size=n_samples)
+    predictive_feature[y_mask] = np.nan
+
+    X_predictive = X_non_predictive.copy()
+    X_predictive[:, 5] = predictive_feature
+
+    (
+        X_predictive_train,
+        X_predictive_test,
+        X_non_predictive_train,
+        X_non_predictive_test,
+        y_train,
+        y_test,
+    ) = train_test_split(X_predictive, X_non_predictive, y, random_state=0)
+    forest_predictive = Forest(random_state=0).fit(X_predictive_train, y_train)
+    forest_non_predictive = Forest(random_state=0).fit(X_non_predictive_train, y_train)
+
+    predictive_test_score = forest_predictive.score(X_predictive_test, y_test)
+
+    assert predictive_test_score >= 0.75
+    assert predictive_test_score >= forest_non_predictive.score(
+        X_non_predictive_test, y_test
+    )
+
+
+def test_non_supported_criterion_raises_error_with_missing_values():
+    """Raise error for unsupported criterion when there are missing values."""
+    X = np.array([[0, 1, 2], [np.nan, 0, 2.0]])
+    y = [0.5, 1.0]
+
+    forest = RandomForestRegressor(criterion="absolute_error")
+
+    msg = "RandomForestRegressor does not accept missing values"
+    with pytest.raises(ValueError, match=msg):
+        forest.fit(X, y)
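Beyond these tests, the `_validate_X_predict` change above also means a forest fitted on complete data accepts NaN at predict time whenever its trees support missing values; a small sketch of that (assuming a build with this commit):

import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

X, y = make_regression(n_samples=200, n_features=5, random_state=0)
reg = RandomForestRegressor(n_estimators=20, random_state=0).fit(X, y)

# Validation uses force_all_finite="allow-nan" here because the default
# "squared_error" criterion supports missing values.
X_test = X[:3].copy()
X_test[0, 0] = np.nan
print(reg.predict(X_test))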

sklearn/tree/_classes.py (+6 −2 lines)

@@ -189,7 +189,7 @@ def _support_missing_values(self, X):
             and self.monotonic_cst is None
         )
 
-    def _compute_missing_values_in_feature_mask(self, X):
+    def _compute_missing_values_in_feature_mask(self, X, estimator_name=None):
         """Return boolean mask denoting if there are missing values for each feature.
 
         This method also ensures that X is finite.
@@ -199,13 +199,17 @@ def _compute_missing_values_in_feature_mask(self, X):
         X : array-like of shape (n_samples, n_features), dtype=DOUBLE
             Input data.
 
+        estimator_name : str or None, default=None
+            Name to use when raising an error. Defaults to the class name.
+
         Returns
         -------
         missing_values_in_feature_mask : ndarray of shape (n_features,), or None
             Missing value mask. If missing values are not supported or there
             are no missing values, return None.
         """
-        common_kwargs = dict(estimator_name=self.__class__.__name__, input_name="X")
+        estimator_name = estimator_name or self.__class__.__name__
+        common_kwargs = dict(estimator_name=estimator_name, input_name="X")
 
         if not self._support_missing_values(X):
             assert_all_finite(X, **common_kwargs)
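Conceptually, the mask this method returns is a per-feature indicator of NaNs, or None when there are none. A minimal sketch of that contract (this is not the private scikit-learn implementation, which additionally validates finiteness and criterion support):

import numpy as np

def missing_values_in_feature_mask(X):
    # True for each feature (column) containing at least one NaN; None
    # when X has no missing values, letting the tree skip the
    # missing-value code path entirely.
    mask = np.isnan(X).any(axis=0)
    return mask if mask.any() else None

X = np.array([[0.0, 1.0, 2.0], [np.nan, 0.0, 2.0]])
print(missing_values_in_feature_mask(X))  # [ True False False]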

0 commit comments