Commit 66491e9

Author: Samuel Brice
Deep copy the Criterion instance within BaseDecisionTree.fit to prevent segfault caused by concurrent accesses.
1 parent 15c2c72 commit 66491e9

4 files changed: 45 additions, 0 deletions

‎doc/whats_new/v0.24.rst

10 additions & 0 deletions
@@ -54,6 +54,16 @@ Changelog
 - |Fix| Better contains the CSS provided by :func:`utils.estimator_html_repr`
   by giving CSS ids to the html representation. :pr:`19417` by `Thomas Fan`_.
 
+:mod:`sklearn.tree`
+.......................
+
+- |Fix| Fix a bug in `fit` of :class:`tree.BaseDecisionTree` that caused
+  segmentation faults under certain conditions. `fit` now deep copies the
+  `Criterion` object to prevent shared concurrent accesses.
+  :pr:`19580` by :user:`Samuel Brice <samdbrice>` and
+  :user:`Alex Adamson <aadamson>` and
+  :user:`Wil Yegelwel <wyegelwel>`.
+
 .. _changes_0_24_1:
 
 Version 0.24.1

‎doc/whats_new/v1.0.rst

10 additions & 0 deletions
@@ -207,6 +207,16 @@ Changelog
   for non-English characters. :pr:`18959` by :user:`Zero <Zeroto521>`
   and :user:`wstates <wstates>`.
 
+:mod:`sklearn.tree`
+.......................
+
+- |Fix| Fix a bug in `fit` of :class:`tree.BaseDecisionTree` that caused
+  segmentation faults under certain conditions. `fit` now deep copies the
+  `Criterion` object to prevent shared concurrent accesses.
+  :pr:`19580` by :user:`Samuel Brice <samdbrice>` and
+  :user:`Alex Adamson <aadamson>` and
+  :user:`Wil Yegelwel <wyegelwel>`.
+
 Code and Documentation Contributors
 -----------------------------------
 

‎sklearn/ensemble/tests/test_forest.py

20 additions & 0 deletions
@@ -1494,3 +1494,23 @@ def test_n_features_deprecation(Estimator):
 
     with pytest.warns(FutureWarning, match="n_features_ was deprecated"):
         est.n_features_
+
+
+@pytest.mark.parametrize('Forest', FOREST_REGRESSORS)
+def test_mse_criterion_object_segfault_smoke_test(Forest):
+    # This is a smoke test to ensure that passing a mutable criterion
+    # does not cause a segfault when fitting with concurrent threads.
+    # Non-regression test for:
+    # https://github.com/scikit-learn/scikit-learn/issues/12623
+    from sklearn.tree._classes import CRITERIA_REG
+
+    X = np.random.random((1000, 3))
+    y = np.random.random((1000, 1))
+
+    n_samples, n_outputs = y.shape
+    mse_criterion = CRITERIA_REG['mse'](n_outputs, n_samples)
+    est = FOREST_REGRESSORS[Forest](
+        n_estimators=2, n_jobs=2, criterion=mse_criterion
+    )
+
+    est.fit(X, y)

‎sklearn/tree/_classes.py

5 additions & 0 deletions
@@ -16,6 +16,7 @@
 
 import numbers
 import warnings
+import copy
 from abc import ABCMeta
 from abc import abstractmethod
 from math import ceil
@@ -349,6 +350,10 @@ def fit(self, X, y, sample_weight=None, check_input=True,
             else:
                 criterion = CRITERIA_REG[self.criterion](self.n_outputs_,
                                                          n_samples)
+        else:
+            # Make a deepcopy in case the criterion has mutable attributes that
+            # might be shared and modified concurrently during parallel fitting
+            criterion = copy.deepcopy(criterion)
 
         SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS
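
Note: the pattern behind the `_classes.py` change can be sketched in plain Python. This is an illustration under stated assumptions, not scikit-learn code: `ToyCriterion` and `toy_fit` are hypothetical names, and a pure-Python object cannot segfault the way the Cython `Criterion` can. The sketch only shows that deep copying a caller-supplied object gives each concurrent fit its own mutable state.

```python
import copy
from concurrent.futures import ThreadPoolExecutor


class ToyCriterion:
    """Hypothetical stand-in for a criterion that owns a mutable buffer."""

    def __init__(self, n_outputs):
        self.n_outputs = n_outputs
        self.sum_total = [0.0] * n_outputs  # scratch state mutated during fit

    def update(self, row):
        for k, value in enumerate(row):
            self.sum_total[k] += value


def toy_fit(criterion, rows):
    """Mimic the patched branch: copy a caller-supplied criterion first."""
    criterion = copy.deepcopy(criterion)  # private copy for this fit only
    for row in rows:
        criterion.update(row)
    return criterion


# One criterion instance handed to two fits that run in separate threads.
shared = ToyCriterion(n_outputs=1)
datasets = [[(1.0,), (2.0,)], [(10.0,), (20.0,)]]

with ThreadPoolExecutor(max_workers=2) as pool:
    fitted = list(pool.map(lambda rows: toy_fit(shared, rows), datasets))

totals = [c.sum_total for c in fitted]
print(totals)            # each fit accumulated only its own rows
print(shared.sum_total)  # the caller's object was never mutated
```

Without the `copy.deepcopy` line, both threads would accumulate into the same `sum_total` list, which is the kind of shared concurrent access the commit removes.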
