Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 48c45c5

Browse filesBrowse files
author
Samuel Brice
committed
Deep copy the Criterion instance within BaseDecisionTree.fit to prevent segfault caused by concurrent accesses.
1 parent 15c2c72 commit 48c45c5
Copy full SHA for 48c45c5

File tree

3 files changed

+33
-0
lines changed
Filter options

3 files changed

+33
-0
lines changed

‎doc/whats_new/v0.24.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v0.24.rst
+10Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@ Changelog
4848
:class:`~sklearn.semi_supervised.LabelPropagation`.
4949
:pr:`19271` by :user:`Zhaowei Wang <ThuWangzw>`.
5050

51+
:mod:`sklearn.tree`
52+
.......................
53+
54+
- |Fix| Fix a bug in `fit` of :class:`tree.BaseDecisionTree` that caused
55+
segmentation faults under certain conditions. `fit` now deep copies the
56+
`Criterion` object to prevent shared concurrent accesses.
57+
:pr:`19580` by :user:`Samuel Brice <samdbrice>` and
58+
:user:`Alex Adamson <aadamson>` and
59+
:user:`Wil Yegelwel <wyegelwel>`.
60+
5161
:mod:`sklearn.utils`
5262
....................
5363

‎sklearn/ensemble/tests/test_forest.py

Copy file name to clipboardExpand all lines: sklearn/ensemble/tests/test_forest.py
+18Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1494,3 +1494,21 @@ def test_n_features_deprecation(Estimator):
14941494

14951495
with pytest.warns(FutureWarning, match="n_features_ was deprecated"):
14961496
est.n_features_
1497+
1498+
1499+
@pytest.mark.parametrize('Forest', FOREST_REGRESSORS)
1500+
def test_mse_criterion_object_segfault_smoke_test(Forest):
1501+
# This is a smoke test to ensure that passing a mutable criterion
1502+
# does not cause a segfault when fitting with concurrent threads.
1503+
# Non-regression test for:
1504+
# https://github.com/scikit-learn/scikit-learn/issues/12623
1505+
from sklearn.tree._criterion import MSE
1506+
1507+
y = y_reg.reshape(-1, 1)
1508+
n_samples, n_outputs = y.shape
1509+
mse_criterion = MSE(n_outputs, n_samples)
1510+
est = FOREST_REGRESSORS[Forest](
1511+
n_estimators=2, n_jobs=2, criterion=mse_criterion
1512+
)
1513+
1514+
est.fit(X_reg, y)

‎sklearn/tree/_classes.py

Copy file name to clipboardExpand all lines: sklearn/tree/_classes.py
+5Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import numbers
1818
import warnings
19+
import copy
1920
from abc import ABCMeta
2021
from abc import abstractmethod
2122
from math import ceil
@@ -349,6 +350,10 @@ def fit(self, X, y, sample_weight=None, check_input=True,
349350
else:
350351
criterion = CRITERIA_REG[self.criterion](self.n_outputs_,
351352
n_samples)
353+
else:
354+
# Make a deepcopy in case the criterion has mutable attributes that
355+
# might be shared and modified concurrently during parallel fitting
356+
criterion = copy.deepcopy(criterion)
352357

353358
SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS
354359

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.