Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 54108d9

Browse filesBrowse files
MAINT Parameter validation for tree.export_text (#25867)
1 parent 9260f51 commit 54108d9
Copy full SHA for 54108d9

File tree

3 files changed

+15
-20
lines changed
Filter options

3 files changed

+15
-20
lines changed

‎sklearn/tests/test_public_functions.py

Copy file name to clipboardExpand all lines: sklearn/tests/test_public_functions.py
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def _check_function_param_validation(
191191
"sklearn.model_selection.train_test_split",
192192
"sklearn.random_projection.johnson_lindenstrauss_min_dim",
193193
"sklearn.svm.l1_min_c",
194+
"sklearn.tree.export_text",
194195
"sklearn.utils.gen_batches",
195196
]
196197

‎sklearn/tree/_export.py

Copy file name to clipboardExpand all lines: sklearn/tree/_export.py
+14-10Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@
1717
import numpy as np
1818

1919
from ..utils.validation import check_is_fitted
20+
from ..utils._param_validation import Interval, validate_params
21+
2022
from ..base import is_classifier
2123

2224
from . import _criterion
2325
from . import _tree
2426
from ._reingold_tilford import buchheim, Tree
25-
from . import DecisionTreeClassifier
27+
from . import DecisionTreeClassifier, DecisionTreeRegressor
2628

2729

2830
def _color_brew(n):
@@ -919,6 +921,17 @@ def compute_depth_(
919921
return max(depths)
920922

921923

924+
@validate_params(
925+
{
926+
"decision_tree": [DecisionTreeClassifier, DecisionTreeRegressor],
927+
"feature_names": [list, None],
928+
"class_names": [list, None],
929+
"max_depth": [Interval(Integral, 0, None, closed="left"), None],
930+
"spacing": [Interval(Integral, 1, None, closed="left"), None],
931+
"decimals": [Interval(Integral, 0, None, closed="left"), None],
932+
"show_weights": ["boolean"],
933+
}
934+
)
922935
def export_text(
923936
decision_tree,
924937
*,
@@ -1011,21 +1024,12 @@ def export_text(
10111024
left_child_fmt = "{} {} > {}\n"
10121025
truncation_fmt = "{} {}\n"
10131026

1014-
if max_depth < 0:
1015-
raise ValueError("max_depth bust be >= 0, given %d" % max_depth)
1016-
10171027
if feature_names is not None and len(feature_names) != tree_.n_features:
10181028
raise ValueError(
10191029
"feature_names must contain %d elements, got %d"
10201030
% (tree_.n_features, len(feature_names))
10211031
)
10221032

1023-
if spacing <= 0:
1024-
raise ValueError("spacing must be > 0, given %d" % spacing)
1025-
1026-
if decimals < 0:
1027-
raise ValueError("decimals must be >= 0, given %d" % decimals)
1028-
10291033
if isinstance(decision_tree, DecisionTreeClassifier):
10301034
value_fmt = "{}{} weights: {}\n"
10311035
if not show_weights:

‎sklearn/tree/tests/test_export.py

Copy file name to clipboardExpand all lines: sklearn/tree/tests/test_export.py
-10Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -350,10 +350,6 @@ def test_precision():
350350
def test_export_text_errors():
351351
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
352352
clf.fit(X, y)
353-
354-
err_msg = "max_depth bust be >= 0, given -1"
355-
with pytest.raises(ValueError, match=err_msg):
356-
export_text(clf, max_depth=-1)
357353
err_msg = "feature_names must contain 2 elements, got 1"
358354
with pytest.raises(ValueError, match=err_msg):
359355
export_text(clf, feature_names=["a"])
@@ -364,12 +360,6 @@ def test_export_text_errors():
364360
)
365361
with pytest.raises(ValueError, match=err_msg):
366362
export_text(clf, class_names=["a"])
367-
err_msg = "decimals must be >= 0, given -1"
368-
with pytest.raises(ValueError, match=err_msg):
369-
export_text(clf, decimals=-1)
370-
err_msg = "spacing must be > 0, given 0"
371-
with pytest.raises(ValueError, match=err_msg):
372-
export_text(clf, spacing=0)
373363

374364

375365
def test_export_text():

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.