Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 2c97116

Browse filesBrowse files
VeghitItay
authored andcommitted
MAINT validate_params for plot_tree (#25882)
Co-authored-by: Itay <itayvegh@gmail.com>
1 parent 0d64914 commit 2c97116
Copy full SHA for 2c97116

File tree

2 files changed

+19
-15
lines changed
Filter options

2 files changed

+19
-15
lines changed

‎sklearn/tests/test_public_functions.py

Copy file name to clipboardExpand all lines: sklearn/tests/test_public_functions.py
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def _check_function_param_validation(
194194
"sklearn.random_projection.johnson_lindenstrauss_min_dim",
195195
"sklearn.svm.l1_min_c",
196196
"sklearn.tree.export_text",
197+
"sklearn.tree.plot_tree",
197198
"sklearn.utils.gen_batches",
198199
]
199200

‎sklearn/tree/_export.py

Copy file name to clipboardExpand all lines: sklearn/tree/_export.py
+18-15Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import numpy as np
1818

1919
from ..utils.validation import check_is_fitted
20-
from ..utils._param_validation import Interval, validate_params
20+
from ..utils._param_validation import Interval, validate_params, StrOptions
2121

2222
from ..base import is_classifier
2323

@@ -77,6 +77,23 @@ def __repr__(self):
7777
SENTINEL = Sentinel()
7878

7979

80+
@validate_params(
81+
{
82+
"decision_tree": [DecisionTreeClassifier, DecisionTreeRegressor],
83+
"max_depth": [Interval(Integral, 0, None, closed="left"), None],
84+
"feature_names": [list, None],
85+
"class_names": [list, None],
86+
"label": [StrOptions({"all", "root", "none"})],
87+
"filled": ["boolean"],
88+
"impurity": ["boolean"],
89+
"node_ids": ["boolean"],
90+
"proportion": ["boolean"],
91+
"rounded": ["boolean"],
92+
"precision": [Interval(Integral, 0, None, closed="left"), None],
93+
"ax": "no_validation", # delegate validation to matplotlib
94+
"fontsize": [Interval(Integral, 0, None, closed="left"), None],
95+
}
96+
)
8097
def plot_tree(
8198
decision_tree,
8299
*,
@@ -601,20 +618,6 @@ def __init__(
601618
)
602619
self.fontsize = fontsize
603620

604-
# validate
605-
if isinstance(precision, Integral):
606-
if precision < 0:
607-
raise ValueError(
608-
"'precision' should be greater or equal to 0."
609-
" Got {} instead.".format(precision)
610-
)
611-
else:
612-
raise ValueError(
613-
"'precision' should be an integer. Got {} instead.".format(
614-
type(precision)
615-
)
616-
)
617-
618621
# The depth of each node for plotting with 'leaf' option
619622
self.ranks = {"leaves": []}
620623
# The colors to render each node with

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.