Commit f10721e

add sample weight support

1 parent 5f1beed commit f10721e

4 files changed: +37 / -34 lines

sklearn/ensemble/_forest.py
16 additions & 12 deletions

@@ -527,10 +527,10 @@ def fit(self, X, y, sample_weight=None):
 
             if callable(self.oob_score):
                 self._set_oob_score_and_attributes(
-                    X, y, scoring_function=self.oob_score
+                    X, y, sample_weight, scoring_function=self.oob_score
                 )
             else:
-                self._set_oob_score_and_attributes(X, y)
+                self._set_oob_score_and_attributes(X, y, sample_weight)
 
         # Decapsulate classes_ attributes
         if hasattr(self, "classes_") and self.n_outputs_ == 1:
@@ -540,7 +540,7 @@ def fit(self, X, y, sample_weight=None):
         return self
 
     @abstractmethod
-    def _set_oob_score_and_attributes(self, X, y, scoring_function=None):
+    def _set_oob_score_and_attributes(self, X, y, sample_weight, scoring_function=None):
         """Compute and set the OOB score and attributes.
 
         Parameters
@@ -683,8 +683,11 @@ def feature_importances_(self):
         return all_importances / np.sum(all_importances)
 
     def _compute_unbiased_feature_importance_and_oob_predictions_per_tree(
-        self, tree, X, y, method, n_samples
+        self, tree, X, y, sample_weight, method
     ):
+        n_samples = X.shape[0]
+        if sample_weight is None:
+            sample_weight = np.ones((n_samples,), dtype=np.float64)
         n_samples_bootstrap = _get_n_samples_bootstrap(
             n_samples,
             self.max_samples,
@@ -705,6 +708,7 @@ def _compute_unbiased_feature_importance_and_oob_predictions_per_tree(
             tree.compute_unbiased_feature_importance_and_oob_predictions(
                 X_test=X_test,
                 y_test=y_test,
+                sample_weight=sample_weight,
                 method=method,
             )
         )
@@ -713,7 +717,7 @@ def _compute_unbiased_feature_importance_and_oob_predictions_per_tree(
         return (importances, oob_pred, n_oob_pred)
 
     def _compute_unbiased_feature_importance_and_oob_predictions(
-        self, X, y, method="ufi"
+        self, X, y, sample_weight, method="ufi"
     ):  # "mdi_oob"
         check_is_fitted(self)
         X = self._validate_X_predict(X)
@@ -728,7 +732,7 @@ def _compute_unbiased_feature_importance_and_oob_predictions(
         )(
             delayed(
                 self._compute_unbiased_feature_importance_and_oob_predictions_per_tree
-            )(tree, X, y, method, n_samples)
+            )(tree, X, y, sample_weight, method)
             for tree in self.estimators_
             if tree.tree_.node_count > 1
         )
@@ -884,7 +888,7 @@ def _get_oob_predictions(tree, X):
             y_pred = np.rollaxis(y_pred, axis=0, start=3)
         return y_pred
 
-    def _set_oob_score_and_attributes(self, X, y, scoring_function=None):
+    def _set_oob_score_and_attributes(self, X, y, sample_weight, scoring_function=None):
         """Compute and set the OOB score and attributes.
 
         Parameters
@@ -902,12 +906,12 @@ def _set_oob_score_and_attributes(self, X, y, scoring_function=None):
 
         ufi_feature_importances, self.oob_decision_function_ = (
             self._compute_unbiased_feature_importance_and_oob_predictions(
-                X, y, method="ufi"
+                X, y, sample_weight, method="ufi"
             )
         )
         mdi_oob_feature_importances, _ = (
             self._compute_unbiased_feature_importance_and_oob_predictions(
-                X, y, method="mdi_oob"
+                X, y, sample_weight, method="mdi_oob"
            )
        )
        if self.criterion == "gini":
@@ -1230,7 +1234,7 @@ def _get_oob_predictions(tree, X):
             y_pred = y_pred[:, np.newaxis, :]
         return y_pred
 
-    def _set_oob_score_and_attributes(self, X, y, scoring_function=None):
+    def _set_oob_score_and_attributes(self, X, y, sample_weight, scoring_function=None):
         """Compute and set the OOB score and attributes.
 
         Parameters
@@ -1247,12 +1251,12 @@ def _set_oob_score_and_attributes(self, X, y, scoring_function=None):
 
         ufi_feature_importances, self.oob_prediction_ = (
             self._compute_unbiased_feature_importance_and_oob_predictions(
-                X, y, method="ufi"
+                X, y, sample_weight, method="ufi"
             )
         )
         mdi_oob_feature_importances, _ = (
             self._compute_unbiased_feature_importance_and_oob_predictions(
-                X, y, method="mdi_oob"
+                X, y, sample_weight, method="mdi_oob"
            )
        )
        if self.criterion == "squared_error":
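
At the forest level the change is mechanical: sample_weight is accepted by _set_oob_score_and_attributes and both OOB feature-importance helpers, defaulted to a vector of ones when it is None, and forwarded to every tree. Below is a minimal NumPy sketch of that defaulting pattern; the function name and toy data are illustrative only, not part of this commit.

import numpy as np

def per_tree_sketch(X, y, sample_weight, method="ufi"):
    # Mirrors the new per-tree entry point: n_samples now comes from X, and a
    # missing sample_weight is replaced by unit weights so downstream code can
    # always accumulate weights instead of integer counts.
    n_samples = X.shape[0]
    if sample_weight is None:
        sample_weight = np.ones((n_samples,), dtype=np.float64)
    return sample_weight  # the real helper passes this on to the tree

X = np.random.rand(5, 3)
y = np.array([0, 1, 1, 0, 1])
print(per_tree_sketch(X, y, None))                                 # [1. 1. 1. 1. 1.]
print(per_tree_sketch(X, y, np.array([1.0, 2.0, 1.0, 1.0, 0.5])))  # weights passed through unchanged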

sklearn/tree/_classes.py
2 additions & 2 deletions

@@ -691,11 +691,11 @@ def feature_importances_(self):
         return self.tree_.compute_feature_importances()
 
     def compute_unbiased_feature_importance_and_oob_predictions(
-        self, X_test, y_test, method="ufi"
+        self, X_test, y_test, sample_weight, method="ufi"
     ):
         check_is_fitted(self)
         return self.tree_.compute_unbiased_feature_importance_and_oob_predictions(
-            X_test, y_test, self.criterion, method=method
+            X_test, y_test, sample_weight, self.criterion, method=method
         )
 
     def __sklearn_tags__(self):
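
The estimator-level method in _classes.py stays a thin wrapper: it only verifies the tree is fitted and forwards the new sample_weight argument, together with the split criterion, to the low-level tree object. A hedged sketch of that delegation pattern follows; the stub classes are illustrative stand-ins, not the actual sklearn classes.

import numpy as np

class LowLevelTreeStub:
    # Stand-in for the Cython Tree; returns dummy importances and predictions.
    def compute_unbiased_feature_importance_and_oob_predictions(
        self, X_test, y_test, sample_weight, criterion, method="ufi"
    ):
        return np.zeros(X_test.shape[1]), np.zeros(X_test.shape[0])

class EstimatorSketch:
    criterion = "gini"
    tree_ = LowLevelTreeStub()

    def compute_unbiased_feature_importance_and_oob_predictions(
        self, X_test, y_test, sample_weight, method="ufi"
    ):
        # No logic of its own: sample_weight is passed straight through,
        # alongside the estimator's criterion.
        return self.tree_.compute_unbiased_feature_importance_and_oob_predictions(
            X_test, y_test, sample_weight, self.criterion, method=method
        )

X = np.random.rand(4, 2)
imp, pred = EstimatorSketch().compute_unbiased_feature_importance_and_oob_predictions(
    X, np.zeros(4), np.ones(4)
)
print(imp.shape, pred.shape)  # (2,) (4,)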

sklearn/tree/_tree.pxd
2 additions & 2 deletions

@@ -78,8 +78,8 @@ cdef class Tree:
     cpdef compute_node_depths(self)
     cpdef compute_feature_importances(self, normalize=*)
 
-    cdef void _compute_oob_node_values_and_predictions(self, object X_test, intp_t[:, ::1] y_test, float64_t[:, :, ::1] oob_pred, int32_t[::1] has_oob_sample, float64_t[:, :, ::1] oob_node_values, str method)
-    cpdef compute_unbiased_feature_importance_and_oob_predictions(self, object X_test, object y_test, criterion, method=*)
+    cdef void _compute_oob_node_values_and_predictions(self, object X_test, intp_t[:, ::1] y_test, float64_t[::1] sample_weight, float64_t[:, :, ::1] oob_pred, int32_t[::1] has_oob_sample, float64_t[:, :, ::1] oob_node_values, str method)
+    cpdef compute_unbiased_feature_importance_and_oob_predictions(self, object X_test, object y_test, object sample_weight, criterion, method=*)
     cdef float64_t mdi_oob_impurity_decrease(self, float64_t[:, :, ::1] oob_node_values, int node_idx, int left_idx, int right_idx, Node node)
     cdef float64_t ufi_impurity_decrease(self, float64_t[:, :, ::1] oob_node_values, int node_idx, int left_idx, int right_idx, Node node, str criterion)
 

sklearn/tree/_tree.pyx
17 additions & 18 deletions

@@ -1275,7 +1275,7 @@ cdef class Tree:
 
         return np.asarray(importances)
 
-    cdef void _compute_oob_node_values_and_predictions(self, object X_test, intp_t[:, ::1] y_test, float64_t[:, :, ::1] oob_pred, int32_t[::1] has_oob_sample, float64_t[:, :, ::1] oob_node_values, str method):
+    cdef void _compute_oob_node_values_and_predictions(self, object X_test, intp_t[:, ::1] y_test, float64_t[::1] sample_weight, float64_t[:, :, ::1] oob_pred, int32_t[::1] has_oob_sample, float64_t[:, :, ::1] oob_node_values, str method):
         if issparse(X_test):
             raise(NotImplementedError("does not support sparse X yet"))
         if not isinstance(X_test, np.ndarray):
@@ -1290,7 +1290,7 @@ cdef class Tree:
         cdef intp_t n_outputs = self.n_outputs
         cdef intp_t max_n_classes = self.max_n_classes
         cdef int k, c, node_idx, sample_idx = 0
-        cdef int32_t[:, ::1] count_oob_values = np.zeros((node_count, n_outputs), dtype=np.int32)
+        cdef float64_t[:, ::1] total_oob_weight = np.zeros((node_count, n_outputs), dtype=np.float64)
         cdef int node_value_idx = -1
 
         cdef Node* node
@@ -1308,17 +1308,15 @@ cdef class Tree:
                 if n_classes[k] > 1:
                     for c in range(n_classes[k]):
                         if y_test[k, sample_idx] == c:
-                            oob_node_values[node_idx, c, k] += 1.0
-                            # TODO use sample weight instead of 1
-                    count_oob_values[node_idx, k] += 1
+                            oob_node_values[node_idx, c, k] += sample_weight[sample_idx]
                 else:
                     if method == "ufi":
                         node_value_idx = node_idx * self.value_stride + k * max_n_classes
-                        oob_node_values[node_idx, 0, k] += (y_test[k, sample_idx] - self.value[node_value_idx]) ** 2.0
+                        oob_node_values[node_idx, 0, k] += sample_weight[sample_idx] * (y_test[k, sample_idx] - self.value[node_value_idx]) ** 2.0
                     else:
-                        oob_node_values[node_idx, 0, k] += y_test[k, sample_idx]
-                    count_oob_values[node_idx, k] += 1
-                    # TODO use sample weight instead of 1
+                        oob_node_values[node_idx, 0, k] += sample_weight[sample_idx] * y_test[k, sample_idx]
+                total_oob_weight[node_idx, k] += sample_weight[sample_idx]
+
             # child nodes
             while node.left_child != _TREE_LEAF and node.right_child != _TREE_LEAF:
                 if X_ndarray[sample_idx, node.feature] <= node.threshold:
@@ -1331,26 +1329,26 @@ cdef class Tree:
                 if n_classes[k] > 1:
                     for c in range(n_classes[k]):
                         if y_test[k, sample_idx] == c:
-                            oob_node_values[node_idx, c, k] += 1.0
+                            oob_node_values[node_idx, c, k] += sample_weight[sample_idx]
                     # TODO use sample weight instead of 1
-                    count_oob_values[node_idx, k] += 1
+                    total_oob_weight[node_idx, k] += sample_weight[sample_idx]
                 else:
                     if method == "ufi":
                         node_value_idx = node_idx * self.value_stride + k * max_n_classes
-                        oob_node_values[node_idx, 0, k] += (y_test[k, sample_idx] - self.value[node_value_idx]) ** 2.0
+                        oob_node_values[node_idx, 0, k] += sample_weight[sample_idx] * (y_test[k, sample_idx] - self.value[node_value_idx]) ** 2.0
                     else:
-                        oob_node_values[node_idx, 0, k] += y_test[k, sample_idx]
-                        count_oob_values[node_idx, k] += 1
+                        oob_node_values[node_idx, 0, k] += sample_weight[sample_idx] * y_test[k, sample_idx]
+                        total_oob_weight[node_idx, k] += sample_weight[sample_idx]
                 # TODO use sample weight instead of 1
             # store the id of the leaf where each sample ends up
             y_leafs[sample_idx] = node_idx
 
         # convert the counts to proportions
         for node_idx in range(node_count):
             for k in range(n_outputs):
-                if count_oob_values[node_idx, k] > 0:
+                if total_oob_weight[node_idx, k] > 0.0:
                     for c in range(n_classes[k]):
-                        oob_node_values[node_idx, c, k] /= count_oob_values[node_idx, k]
+                        oob_node_values[node_idx, c, k] /= total_oob_weight[node_idx, k]
             # if leaf store the predictive proba
             if self.nodes[node_idx].left_child == _TREE_LEAF and self.nodes[node_idx].right_child == _TREE_LEAF:
                 for sample_idx in range(n_samples):
@@ -1360,7 +1358,7 @@ cdef class Tree:
                                 node_value_idx = node_idx * self.value_stride + k * max_n_classes + c
                                 oob_pred[sample_idx, c, k] = self.value[node_value_idx]
 
-    cpdef compute_unbiased_feature_importance_and_oob_predictions(self, object X_test, object y_test, criterion, method="ufi"):
+    cpdef compute_unbiased_feature_importance_and_oob_predictions(self, object X_test, object y_test, object sample_weight, criterion, method="ufi"):
         cdef intp_t n_samples = X_test.shape[0]
         cdef intp_t n_features = X_test.shape[1]
         cdef intp_t n_outputs = self.n_outputs
@@ -1378,7 +1376,8 @@ cdef class Tree:
         cdef int left_idx, right_idx = -1
 
         cdef intp_t[:, ::1] y_view = np.ascontiguousarray(y_test, dtype=np.intp)
-        self._compute_oob_node_values_and_predictions(X_test, y_view, oob_pred, has_oob_sample, oob_node_values, method)
+        cdef float64_t[::1] sample_weight_view = np.ascontiguousarray(sample_weight, dtype=np.float64)
+        self._compute_oob_node_values_and_predictions(X_test, y_view, sample_weight_view, oob_pred, has_oob_sample, oob_node_values, method)
 
         for node_idx in range(self.node_count):
             node = nodes[node_idx]
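
The substantive change sits in the Cython kernel: the integer count_oob_values array becomes a float64 total_oob_weight accumulator, every out-of-bag contribution is scaled by sample_weight[sample_idx], and node values are normalized by accumulated weight instead of sample count, turning them into weighted class proportions (or weighted means and squared errors for regression). Here is a small pure-Python sketch of that accumulation for a single-output classification tree; the array names follow the diff, while the helper function and toy data are illustrative.

import numpy as np

def weighted_node_values(node_of_sample, y, sample_weight, n_nodes, n_classes):
    # Accumulate weighted class mass per node, then normalize by the total
    # out-of-bag weight that reached the node -- the role total_oob_weight now
    # plays in _tree.pyx in place of count_oob_values.
    oob_node_values = np.zeros((n_nodes, n_classes))
    total_oob_weight = np.zeros(n_nodes)
    for i, node_idx in enumerate(node_of_sample):
        oob_node_values[node_idx, y[i]] += sample_weight[i]
        total_oob_weight[node_idx] += sample_weight[i]
    for node_idx in range(n_nodes):
        if total_oob_weight[node_idx] > 0.0:
            oob_node_values[node_idx] /= total_oob_weight[node_idx]
    return oob_node_values

node_of_sample = np.array([0, 0, 1, 1])  # node reached by each OOB sample
y = np.array([0, 1, 1, 1])
w = np.array([1.0, 3.0, 1.0, 1.0])
print(weighted_node_values(node_of_sample, y, w, n_nodes=2, n_classes=2))
# node 0 becomes [0.25, 0.75]; with unit weights it would be [0.5, 0.5]

With unit weights the weighted sums and totals reduce to the old counts, which is why defaulting sample_weight to ones at the forest level leaves previous results unchanged.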

0 commit comments
