Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 43cf7d4

Browse files
authored
BUG Fixes sample weights when there are missing values in DecisionTrees (#26376)
1 parent 4a5f954 commit 43cf7d4
Copy full SHA for 43cf7d4

File tree

3 files changed

+43
-6
lines changed
Filter options

3 files changed

+43
-6
lines changed

‎doc/whats_new/v1.3.rst

Copy file name to clipboard. Expand all lines: doc/whats_new/v1.3.rst
+1 −1 lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ Changelog
516516
:class:`tree.DecisionTreeClassifier` support missing values when
517517
`splitter='best'` and criterion is `gini`, `entropy`, or `log_loss`,
518518
for classification or `squared_error`, `friedman_mse`, or `poisson`
519-
for regression. :pr:`23595` by `Thomas Fan`_.
519+
for regression. :pr:`23595`, :pr:`26376` by `Thomas Fan`_.
520520

521521
- |Enhancement| Adds a `class_names` parameter to
522522
:func:`tree.export_text`. This allows specifying the parameter `class_names`

‎sklearn/tree/_criterion.pyx

Copy file name to clipboard. Expand all lines: sklearn/tree/_criterion.pyx
+6 −2 lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -838,7 +838,9 @@ cdef class RegressionCriterion(Criterion):
838838
self.sample_indices[-n_missing:]
839839
"""
840840
cdef SIZE_t i, p, k
841-
cdef DOUBLE_t w = 0.0
841+
cdef DOUBLE_t y_ik
842+
cdef DOUBLE_t w_y_ik
843+
cdef DOUBLE_t w = 1.0
842844

843845
self.n_missing = n_missing
844846
if n_missing == 0:
@@ -855,7 +857,9 @@ cdef class RegressionCriterion(Criterion):
855857
w = self.sample_weight[i]
856858

857859
for k in range(self.n_outputs):
858-
self.sum_missing[k] += w
860+
y_ik = self.y[i, k]
861+
w_y_ik = w * y_ik
862+
self.sum_missing[k] += w_y_ik
859863

860864
self.weighted_n_missing += w
861865

‎sklearn/tree/tests/test_tree.py

Copy file name to clipboard. Expand all lines: sklearn/tree/tests/test_tree.py
+36 −3 lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2549,7 +2549,8 @@ def test_missing_values_poisson():
25492549
(datasets.make_classification, DecisionTreeClassifier),
25502550
],
25512551
)
2552-
def test_missing_values_is_resilience(make_data, Tree):
2552+
@pytest.mark.parametrize("sample_weight_train", [None, "ones"])
2553+
def test_missing_values_is_resilience(make_data, Tree, sample_weight_train):
25532554
"""Check that trees can deal with missing values and have decent performance."""
25542555

25552556
rng = np.random.RandomState(0)
@@ -2563,15 +2564,18 @@ def test_missing_values_is_resilience(make_data, Tree):
25632564
X_missing, y, random_state=0
25642565
)
25652566

2567+
if sample_weight_train == "ones":
2568+
sample_weight_train = np.ones(X_missing_train.shape[0])
2569+
25662570
# Train tree with missing values
25672571
tree_with_missing = Tree(random_state=rng)
2568-
tree_with_missing.fit(X_missing_train, y_train)
2572+
tree_with_missing.fit(X_missing_train, y_train, sample_weight=sample_weight_train)
25692573
score_with_missing = tree_with_missing.score(X_missing_test, y_test)
25702574

25712575
# Train tree without missing values
25722576
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
25732577
tree = Tree(random_state=rng)
2574-
tree.fit(X_train, y_train)
2578+
tree.fit(X_train, y_train, sample_weight=sample_weight_train)
25752579
score_without_missing = tree.score(X_test, y_test)
25762580

25772581
# Score is still 90 percent of the tree's score that had no missing values
@@ -2601,3 +2605,32 @@ def test_missing_value_is_predictive():
26012605

26022606
assert tree.score(X_train, y_train) >= 0.85
26032607
assert tree.score(X_test, y_test) >= 0.85
2608+
2609+
2610+
@pytest.mark.parametrize(
2611+
"make_data, Tree",
2612+
[
2613+
(datasets.make_regression, DecisionTreeRegressor),
2614+
(datasets.make_classification, DecisionTreeClassifier),
2615+
],
2616+
)
2617+
def test_sample_weight_non_uniform(make_data, Tree):
2618+
"""Check sample weight is correctly handled with missing values."""
2619+
rng = np.random.RandomState(0)
2620+
n_samples, n_features = 1000, 10
2621+
X, y = make_data(n_samples=n_samples, n_features=n_features, random_state=rng)
2622+
2623+
# Create dataset with missing values
2624+
X[rng.choice([False, True], size=X.shape, p=[0.9, 0.1])] = np.nan
2625+
2626+
# Zero sample weight is the same as removing the sample
2627+
sample_weight = np.ones(X.shape[0])
2628+
sample_weight[::2] = 0.0
2629+
2630+
tree_with_sw = Tree(random_state=0)
2631+
tree_with_sw.fit(X, y, sample_weight=sample_weight)
2632+
2633+
tree_samples_removed = Tree(random_state=0)
2634+
tree_samples_removed.fit(X[1::2, :], y[1::2])
2635+
2636+
assert_allclose(tree_samples_removed.predict(X), tree_with_sw.predict(X))

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.