Commit 6392148

thomasjpfan authored
ENH Adds missing value support for trees (#23595)

Co-authored-by: Tim Head <betatim@gmail.com>
Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>

1 parent ae35ce8 commit 6392148

File tree

12 files changed: +908 −136 lines

‎doc/modules/tree.rst

59 additions & 0 deletions
@@ -572,6 +572,65 @@ Mean Absolute Error:

 Note that it fits much slower than the MSE criterion.

+.. _tree_missing_value_support:
+
+Missing Values Support
+======================
+
+:class:`~tree.DecisionTreeClassifier` and :class:`~tree.DecisionTreeRegressor`
+have built-in support for missing values when `splitter='best'` and criterion is
+`'gini'`, `'entropy'`, or `'log_loss'`, for classification or
+`'squared_error'`, `'friedman_mse'`, or `'poisson'` for regression.
+
+For each potential threshold on the non-missing data, the splitter will evaluate
+the split with all the missing values going to the left node or the right node.
+
+Decisions are made as follows:
+
+- By default when predicting, the samples with missing values are classified
+  with the class used in the split found during training::
+
+    >>> from sklearn.tree import DecisionTreeClassifier
+    >>> import numpy as np
+
+    >>> X = np.array([0, 1, 6, np.nan]).reshape(-1, 1)
+    >>> y = [0, 0, 1, 1]
+
+    >>> tree = DecisionTreeClassifier(random_state=0).fit(X, y)
+    >>> tree.predict(X)
+    array([0, 0, 1, 1])
+
+- If the criterion evaluation is the same for both nodes,
+  then the tie for missing values at predict time is broken by going to the
+  right node. The splitter also checks the split where all the missing
+  values go to one child and non-missing values go to the other::
+
+    >>> from sklearn.tree import DecisionTreeClassifier
+    >>> import numpy as np
+
+    >>> X = np.array([np.nan, -1, np.nan, 1]).reshape(-1, 1)
+    >>> y = [0, 0, 1, 1]
+
+    >>> tree = DecisionTreeClassifier(random_state=0).fit(X, y)
+
+    >>> X_test = np.array([np.nan]).reshape(-1, 1)
+    >>> tree.predict(X_test)
+    array([1])
+
+- If no missing values are seen during training for a given feature, then during
+  prediction missing values are mapped to the child with the most samples::
+
+    >>> from sklearn.tree import DecisionTreeClassifier
+    >>> import numpy as np
+
+    >>> X = np.array([0, 1, 2, 3]).reshape(-1, 1)
+    >>> y = [0, 1, 1, 1]
+
+    >>> tree = DecisionTreeClassifier(random_state=0).fit(X, y)
+
+    >>> X_test = np.array([np.nan]).reshape(-1, 1)
+    >>> tree.predict(X_test)
+    array([1])

 .. _minimal_cost_complexity_pruning:
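The hunk above only sketches the threshold search in prose. As a rough illustration of what `splitter='best'` does per feature, here is a minimal NumPy re-implementation (an editor's sketch, not scikit-learn's actual Cython splitter): every candidate threshold on the observed values is scored twice, once with the missing samples appended to the left child and once to the right, using the weighted Gini impurity.

```python
import numpy as np

def gini(y):
    """Gini impurity of a label array; empty children contribute 0."""
    if len(y) == 0:
        return 0.0
    _, counts = np.unique(y, return_counts=True)
    p = counts / len(y)
    return 1.0 - np.sum(p**2)

def best_split_with_missing(x, y):
    """Try each threshold on the observed values of one feature, sending
    the missing samples to the left child and then to the right child."""
    missing = np.isnan(x)
    y_miss = y[missing]
    order = np.argsort(x[~missing])
    x_obs, y_obs = x[~missing][order], y[~missing][order]
    n = len(y)
    best = (np.inf, None, None)  # (weighted impurity, threshold, missing_go_to_left)
    for i in range(1, len(x_obs)):
        if x_obs[i] == x_obs[i - 1]:
            continue
        threshold = (x_obs[i] + x_obs[i - 1]) / 2
        left, right = y_obs[:i], y_obs[i:]
        for go_left in (True, False):
            y_left = np.concatenate([left, y_miss]) if go_left else left
            y_right = right if go_left else np.concatenate([right, y_miss])
            score = (len(y_left) * gini(y_left) + len(y_right) * gini(y_right)) / n
            if score < best[0]:
                best = (score, threshold, go_left)
    return best

x = np.array([0.0, 1.0, 6.0, np.nan])
y = np.array([0, 0, 1, 1])
print(best_split_with_missing(x, y))  # impurity 0 at threshold 3.5, missing sent right
```

This reproduces the first doctest above: the NaN sample lands on the class-1 side of the best split, so it is predicted as 1.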

‎doc/whats_new/v1.3.rst

6 additions & 0 deletions
@@ -486,6 +486,12 @@ Changelog
 :mod:`sklearn.tree`
 ...................

+- |MajorFeature| :class:`tree.DecisionTreeRegressor` and
+  :class:`tree.DecisionTreeClassifier` support missing values when
+  `splitter='best'` and criterion is `gini`, `entropy`, or `log_loss`,
+  for classification or `squared_error`, `friedman_mse`, or `poisson`
+  for regression. :pr:`23595` by `Thomas Fan`_.
+
 - |Enhancement| Adds a `class_names` parameter to
   :func:`tree.export_text`. This allows specifying the parameter `class_names`
   for each target class in ascending numerical order.

‎sklearn/tree/_classes.py

85 additions & 7 deletions
@@ -34,6 +34,8 @@
 from ..utils import Bunch
 from ..utils import check_random_state
 from ..utils.validation import _check_sample_weight
+from ..utils.validation import assert_all_finite
+from ..utils.validation import _assert_all_finite_element_wise
 from ..utils import compute_sample_weight
 from ..utils.multiclass import check_classification_targets
 from ..utils.validation import check_is_fitted
@@ -48,6 +50,7 @@
 from ._tree import _build_pruned_tree_ccp
 from ._tree import ccp_pruning_path
 from . import _tree, _splitter, _criterion
+from ._utils import _any_isnan_axis0

 __all__ = [
     "DecisionTreeClassifier",
@@ -174,19 +177,67 @@ def get_n_leaves(self):
         check_is_fitted(self)
         return self.tree_.n_leaves

-    def fit(self, X, y, sample_weight=None, check_input=True):
+    def _support_missing_values(self, X):
+        return not issparse(X) and self._get_tags()["allow_nan"]
+
+    def _compute_feature_has_missing(self, X):
+        """Return boolean mask denoting if there are missing values for each feature.
+
+        This method also ensures that X is finite.
+
+        Parameters
+        ----------
+        X : array-like of shape (n_samples, n_features), dtype=DOUBLE
+            Input data.
+
+        Returns
+        -------
+        feature_has_missing : ndarray of shape (n_features,), or None
+            Missing value mask. If missing values are not supported or there
+            are no missing values, return None.
+        """
+        common_kwargs = dict(estimator_name=self.__class__.__name__, input_name="X")
+
+        if not self._support_missing_values(X):
+            assert_all_finite(X, **common_kwargs)
+            return None
+
+        with np.errstate(over="ignore"):
+            overall_sum = np.sum(X)
+
+        if not np.isfinite(overall_sum):
+            # Raise a ValueError in case of the presence of an infinite element.
+            _assert_all_finite_element_wise(X, xp=np, allow_nan=True, **common_kwargs)
+
+        # If the sum is not nan, then there are no missing values
+        if not np.isnan(overall_sum):
+            return None
+
+        feature_has_missing = _any_isnan_axis0(X)
+        return feature_has_missing
+
+    def _fit(
+        self, X, y, sample_weight=None, check_input=True, feature_has_missing=None
+    ):
         self._validate_params()
         random_state = check_random_state(self.random_state)

         if check_input:
             # Need to validate separately here.
             # We can't pass multi_output=True because that would allow y to be
             # csr.
-            check_X_params = dict(dtype=DTYPE, accept_sparse="csc")
+
+            # _compute_feature_has_missing will check for finite values and
+            # compute the missing mask if the tree supports missing values
+            check_X_params = dict(
+                dtype=DTYPE, accept_sparse="csc", force_all_finite=False
+            )
             check_y_params = dict(ensure_2d=False, dtype=None)
             X, y = self._validate_data(
                 X, y, validate_separately=(check_X_params, check_y_params)
             )
+
+            feature_has_missing = self._compute_feature_has_missing(X)
         if issparse(X):
             X.sort_indices()
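A note on the `_compute_feature_has_missing` hunk above: the `np.sum` under `errstate(over="ignore")` is a shortcut that avoids an element-wise scan in the common all-finite case, since any `inf` or `nan` in the data propagates into the aggregate. Here is a self-contained sketch of the same idea; the helper name is hypothetical, standing in for the private `assert_all_finite`/`_any_isnan_axis0` machinery.

```python
import numpy as np

def nan_mask_per_feature(X):
    """Sketch: one aggregate sum decides whether any per-column scan is needed."""
    with np.errstate(over="ignore"):  # float overflow to inf is acceptable here
        total = np.sum(X)
    if not np.isfinite(total):
        # Something non-finite is present; reject inf but tolerate NaN.
        if np.isinf(X).any():
            raise ValueError("Input X contains infinity.")
    if not np.isnan(total):
        # A finite sum means every element is finite: no missing values.
        return None
    # Only now pay for the element-wise scan, column by column.
    return np.isnan(X).any(axis=0)

print(nan_mask_per_feature(np.array([[0.0, np.nan], [1.0, 2.0]])))  # [False  True]
print(nan_mask_per_feature(np.ones((2, 2))))                        # None
```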

@@ -381,7 +432,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
             self.min_impurity_decrease,
         )

-        builder.build(self.tree_, X, y, sample_weight)
+        builder.build(self.tree_, X, y, sample_weight, feature_has_missing)

         if self.n_outputs_ == 1 and is_classifier(self):
             self.n_classes_ = self.n_classes_[0]
@@ -394,7 +445,17 @@ def fit(self, X, y, sample_weight=None, check_input=True):
     def _validate_X_predict(self, X, check_input):
         """Validate the training data on predict (probabilities)."""
         if check_input:
-            X = self._validate_data(X, dtype=DTYPE, accept_sparse="csr", reset=False)
+            if self._support_missing_values(X):
+                force_all_finite = "allow-nan"
+            else:
+                force_all_finite = True
+            X = self._validate_data(
+                X,
+                dtype=DTYPE,
+                accept_sparse="csr",
+                reset=False,
+                force_all_finite=force_all_finite,
+            )
             if issparse(X) and (
                 X.indices.dtype != np.intc or X.indptr.dtype != np.intc
             ):
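The `_validate_X_predict` hunk above switches `force_all_finite` to `"allow-nan"` only when the fitted configuration supports missing values. The same knob is exposed on the public `check_array` helper, which makes the two validation modes easy to compare in isolation (as of the scikit-learn version this commit targets):

```python
import numpy as np
from sklearn.utils import check_array

X = np.array([[1.0], [np.nan]])

# "allow-nan" accepts NaN but still rejects +/-inf.
check_array(X, force_all_finite="allow-nan")

# The strict default rejects any non-finite value.
try:
    check_array(X, force_all_finite=True)
except ValueError as exc:
    print(exc)  # explains that the input contains NaN
```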
@@ -886,7 +947,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
             Fitted estimator.
         """

-        super().fit(
+        super()._fit(
             X,
             y,
             sample_weight=sample_weight,
@@ -971,7 +1032,14 @@ def predict_log_proba(self, X):
         return proba

     def _more_tags(self):
-        return {"multilabel": True}
+        # XXX: nan is only supported for dense arrays, but we set this for the
+        # common tests to pass, specifically: check_estimators_nan_inf
+        allow_nan = self.splitter == "best" and self.criterion in {
+            "gini",
+            "log_loss",
+            "entropy",
+        }
+        return {"multilabel": True, "allow_nan": allow_nan}


 class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree):
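The classifier's `_more_tags` override above is what `_support_missing_values` consults, so the `allow_nan` tag flips off as soon as an unsupported `splitter` or `criterion` is configured. A quick way to observe this, using the private `_get_tags()` API of this scikit-learn era, so treat it as illustrative only:

```python
from sklearn.tree import DecisionTreeClassifier

# Default best splitter + gini criterion: missing values are supported.
print(DecisionTreeClassifier()._get_tags()["allow_nan"])                   # True

# The random splitter opts out, so NaN inputs are rejected at fit time.
print(DecisionTreeClassifier(splitter="random")._get_tags()["allow_nan"])  # False
```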
@@ -1239,7 +1307,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
             Fitted estimator.
         """

-        super().fit(
+        super()._fit(
             X,
             y,
             sample_weight=sample_weight,
@@ -1274,6 +1342,16 @@ def _compute_partial_dependence_recursion(self, grid, target_features):
         )
         return averaged_predictions

+    def _more_tags(self):
+        # XXX: nan is only supported for dense arrays, but we set this for the
+        # common tests to pass, specifically: check_estimators_nan_inf
+        allow_nan = self.splitter == "best" and self.criterion in {
+            "squared_error",
+            "friedman_mse",
+            "poisson",
+        }
+        return {"allow_nan": allow_nan}
+

 class ExtraTreeClassifier(DecisionTreeClassifier):
     """An extremely randomized tree classifier.

‎sklearn/tree/_criterion.pxd

13 additions & 6 deletions
@@ -28,6 +28,8 @@ cdef class Criterion:
     cdef SIZE_t start                # samples[start:pos] are the samples in the left node
     cdef SIZE_t pos                  # samples[pos:end] are the samples in the right node
     cdef SIZE_t end
+    cdef SIZE_t n_missing            # Number of missing values for the feature being evaluated
+    cdef bint missing_go_to_left     # Whether missing values go to the left node

     cdef SIZE_t n_outputs            # Number of outputs
     cdef SIZE_t n_samples            # Number of samples
@@ -36,6 +38,7 @@ cdef class Criterion:
     cdef double weighted_n_node_samples  # Weighted number of samples in the node
     cdef double weighted_n_left          # Weighted number of samples in the left node
     cdef double weighted_n_right         # Weighted number of samples in the right node
+    cdef double weighted_n_missing       # Weighted number of samples that are missing

     # The criterion object is maintained such that left and right collected
     # statistics correspond to samples[start:pos] and samples[pos:end].
@@ -50,6 +53,8 @@ cdef class Criterion:
         SIZE_t start,
         SIZE_t end
     ) except -1 nogil
+    cdef void init_sum_missing(self)
+    cdef void init_missing(self, SIZE_t n_missing) noexcept nogil
     cdef int reset(self) except -1 nogil
     cdef int reverse_reset(self) except -1 nogil
     cdef int update(self, SIZE_t new_pos) except -1 nogil
@@ -77,15 +82,17 @@ cdef class ClassificationCriterion(Criterion):
     cdef SIZE_t[::1] n_classes
     cdef SIZE_t max_n_classes

-    cdef double[:, ::1] sum_total   # The sum of the weighted count of each label.
-    cdef double[:, ::1] sum_left    # Same as above, but for the left side of the split
-    cdef double[:, ::1] sum_right   # Same as above, but for the right side of the split
+    cdef double[:, ::1] sum_total    # The sum of the weighted count of each label.
+    cdef double[:, ::1] sum_left     # Same as above, but for the left side of the split
+    cdef double[:, ::1] sum_right    # Same as above, but for the right side of the split
+    cdef double[:, ::1] sum_missing  # Same as above, but for missing values in X

 cdef class RegressionCriterion(Criterion):
     """Abstract regression criterion."""

     cdef double sq_sum_total

-    cdef double[::1] sum_total   # The sum of w*y.
-    cdef double[::1] sum_left    # Same as above, but for the left side of the split
-    cdef double[::1] sum_right   # Same as above, but for the right side of the split
+    cdef double[::1] sum_total    # The sum of w*y.
+    cdef double[::1] sum_left     # Same as above, but for the left side of the split
+    cdef double[::1] sum_right    # Same as above, but for the right side of the split
+    cdef double[::1] sum_missing  # Same as above, but for missing values in X
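The `sum_missing` and `weighted_n_missing` fields declared above let a criterion keep the missing samples' statistics as one pre-aggregated block that can be credited to either child when `missing_go_to_left` flips. A small pure-Python sketch of that bookkeeping for a regression criterion (an editor's illustration with plain floats in place of the Cython memoryviews, not the library's internals):

```python
import numpy as np

class RegressionStatsSketch:
    """Criterion-style running sums with a separate block for missing samples."""

    def __init__(self, y, w, missing):
        self.sum_missing = float(np.sum(w[missing] * y[missing]))  # sum of w*y over missing rows
        self.weighted_n_missing = float(np.sum(w[missing]))
        self.sum_total = float(np.sum(w * y))
        self.weighted_n_total = float(np.sum(w))

    def children_sums(self, sum_left_obs, weighted_n_left_obs, missing_go_to_left):
        """Fold the missing block into one child; the other child is the remainder."""
        sum_left, n_left = sum_left_obs, weighted_n_left_obs
        if missing_go_to_left:
            sum_left += self.sum_missing
            n_left += self.weighted_n_missing
        sum_right = self.sum_total - sum_left
        n_right = self.weighted_n_total - n_left
        return (sum_left, n_left), (sum_right, n_right)

y = np.array([0.1, 0.2, 1.0, 1.1])
w = np.ones(4)
missing = np.array([False, False, False, True])
stats = RegressionStatsSketch(y, w, missing)
# Observed left child = the first two samples; try both destinations for NaN.
print(stats.children_sums(0.3, 2.0, missing_go_to_left=True))   # ~((1.4, 3.0), (1.0, 1.0))
print(stats.children_sums(0.3, 2.0, missing_go_to_left=False))  # ~((0.3, 2.0), (2.1, 2.0))
```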
