@@ -34,6 +34,8 @@
 from ..utils import Bunch
 from ..utils import check_random_state
 from ..utils.validation import _check_sample_weight
+from ..utils.validation import assert_all_finite
+from ..utils.validation import _assert_all_finite_element_wise
 from ..utils import compute_sample_weight
 from ..utils.multiclass import check_classification_targets
 from ..utils.validation import check_is_fitted
@@ -48,6 +50,7 @@
 from ._tree import _build_pruned_tree_ccp
 from ._tree import ccp_pruning_path
 from . import _tree, _splitter, _criterion
+from ._utils import _any_isnan_axis0
 
 __all__ = [
     "DecisionTreeClassifier",
@@ -174,19 +177,67 @@ def get_n_leaves(self):
         check_is_fitted(self)
         return self.tree_.n_leaves
 
-    def fit(self, X, y, sample_weight=None, check_input=True):
+    def _support_missing_values(self, X):
+        return not issparse(X) and self._get_tags()["allow_nan"]
+
+    def _compute_feature_has_missing(self, X):
+        """Return boolean mask denoting if there are missing values for each feature.
+
+        This method also ensures that X is finite.
+
+        Parameters
+        ----------
+        X : array-like of shape (n_samples, n_features), dtype=DOUBLE
+            Input data.
+
+        Returns
+        -------
+        feature_has_missing : ndarray of shape (n_features,), or None
+            Missing value mask. If missing values are not supported or there
+            are no missing values, return None.
+        """
+        common_kwargs = dict(estimator_name=self.__class__.__name__, input_name="X")
+
+        if not self._support_missing_values(X):
+            assert_all_finite(X, **common_kwargs)
+            return None
+
+        with np.errstate(over="ignore"):
+            overall_sum = np.sum(X)
+
+        if not np.isfinite(overall_sum):
+            # Raise a ValueError in case of the presence of an infinite element.
+            _assert_all_finite_element_wise(X, xp=np, allow_nan=True, **common_kwargs)
+
+        # If the sum is not nan, then there are no missing values
+        if not np.isnan(overall_sum):
+            return None
+
+        feature_has_missing = _any_isnan_axis0(X)
+        return feature_has_missing
+
+    def _fit(
+        self, X, y, sample_weight=None, check_input=True, feature_has_missing=None
+    ):
         self._validate_params()
         random_state = check_random_state(self.random_state)
 
         if check_input:
             # Need to validate separately here.
             # We can't pass multi_output=True because that would allow y to be
             # csr.
-            check_X_params = dict(dtype=DTYPE, accept_sparse="csc")
+
+            # _compute_feature_has_missing will check for finite values and
+            # compute the missing mask if the tree supports missing values
+            check_X_params = dict(
+                dtype=DTYPE, accept_sparse="csc", force_all_finite=False
+            )
             check_y_params = dict(ensure_2d=False, dtype=None)
             X, y = self._validate_data(
                 X, y, validate_separately=(check_X_params, check_y_params)
             )
+
+            feature_has_missing = self._compute_feature_has_missing(X)
             if issparse(X):
                 X.sort_indices()
 
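The validation in `_compute_feature_has_missing` leans on a single vectorized `np.sum` as a cheap first pass: NaN propagates through the sum and inf keeps it non-finite, so a finite sum proves the whole array is finite without an element-wise scan. A plain-NumPy sketch of that control flow (illustrative only; the real method delegates the element-wise pass to `_assert_all_finite_element_wise` so error messages stay consistent with the rest of scikit-learn):

```python
import numpy as np

def compute_feature_has_missing_sketch(X):
    """Illustrative stand-in for the method above, not the real implementation."""
    # One vectorized sum is a cheap first pass: NaN and inf both propagate,
    # so a finite sum proves the entire array is finite.
    with np.errstate(over="ignore"):
        overall_sum = np.sum(X)

    if not np.isfinite(overall_sum):
        # inf is always rejected; NaN is tolerated because it encodes a
        # missing value here.
        if np.isinf(X).any():
            raise ValueError("Input X contains infinity.")

    if not np.isnan(overall_sum):
        # Covers both the all-finite case and a sum of large finite values
        # that overflows to inf: either way there is no NaN to handle.
        return None

    return np.isnan(X).any(axis=0)  # per-feature missing-value mask

X = np.array([[1.0, np.nan], [2.0, 3.0]])
print(compute_feature_has_missing_sketch(X))  # [False  True]
```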
@@ -381,7 +432,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
             self.min_impurity_decrease,
         )
 
-        builder.build(self.tree_, X, y, sample_weight)
+        builder.build(self.tree_, X, y, sample_weight, feature_has_missing)
 
         if self.n_outputs_ == 1 and is_classifier(self):
             self.n_classes_ = self.n_classes_[0]
@@ -394,7 +445,17 @@ def fit(self, X, y, sample_weight=None, check_input=True):
394
445
def _validate_X_predict (self , X , check_input ):
395
446
"""Validate the training data on predict (probabilities)."""
396
447
if check_input :
397
- X = self ._validate_data (X , dtype = DTYPE , accept_sparse = "csr" , reset = False )
448
+ if self ._support_missing_values (X ):
449
+ force_all_finite = "allow-nan"
450
+ else :
451
+ force_all_finite = True
452
+ X = self ._validate_data (
453
+ X ,
454
+ dtype = DTYPE ,
455
+ accept_sparse = "csr" ,
456
+ reset = False ,
457
+ force_all_finite = force_all_finite ,
458
+ )
398
459
if issparse (X ) and (
399
460
X .indices .dtype != np .intc or X .indptr .dtype != np .intc
400
461
):
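With the `fit`-side validation passing `force_all_finite=False` and the `predict`-side validation switching to `"allow-nan"`, an end-to-end call could look like the sketch below. This assumes a build containing this branch; on releases without missing-value support the `fit` call raises a `ValueError` on the NaN:

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

X = np.array([[0.0], [1.0], [np.nan], [3.0]])
y = np.array([0, 0, 1, 1])

# splitter="best" with a supported criterion is what enables allow_nan below.
clf = DecisionTreeClassifier(splitter="best", criterion="gini", random_state=0)
clf.fit(X, y)                   # NaN accepted: validated with force_all_finite=False
print(clf.predict([[np.nan]]))  # NaN also accepted at predict time ("allow-nan")
```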
@@ -886,7 +947,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
             Fitted estimator.
         """
 
-        super().fit(
+        super()._fit(
             X,
             y,
             sample_weight=sample_weight,
@@ -971,7 +1032,14 @@ def predict_log_proba(self, X):
         return proba
 
     def _more_tags(self):
-        return {"multilabel": True}
+        # XXX: NaN is only supported for dense arrays, but we set this for the
+        # common tests to pass, specifically: check_estimators_nan_inf
+        allow_nan = self.splitter == "best" and self.criterion in {
+            "gini",
+            "log_loss",
+            "entropy",
+        }
+        return {"multilabel": True, "allow_nan": allow_nan}
 
 
 class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree):
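`_support_missing_values` reads this `allow_nan` tag, so the gate can be observed through scikit-learn's private `_get_tags()` API, the same call the diff itself uses. A quick check, assuming this branch:

```python
from sklearn.tree import DecisionTreeClassifier

# The default criterion is "gini", which is in the supported set above.
print(DecisionTreeClassifier(splitter="best")._get_tags()["allow_nan"])    # True
# "random" splitting has no missing-value path, so the tag stays False.
print(DecisionTreeClassifier(splitter="random")._get_tags()["allow_nan"])  # False
```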
@@ -1239,7 +1307,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
             Fitted estimator.
         """
 
-        super().fit(
+        super()._fit(
             X,
             y,
             sample_weight=sample_weight,
@@ -1274,6 +1342,16 @@ def _compute_partial_dependence_recursion(self, grid, target_features):
         )
         return averaged_predictions
 
+    def _more_tags(self):
+        # XXX: NaN is only supported for dense arrays, but we set this for the
+        # common tests to pass, specifically: check_estimators_nan_inf
+        allow_nan = self.splitter == "best" and self.criterion in {
+            "squared_error",
+            "friedman_mse",
+            "poisson",
+        }
+        return {"allow_nan": allow_nan}
+
 
 class ExtraTreeClassifier(DecisionTreeClassifier):
     """An extremely randomized tree classifier.