Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit be892f5

Browse filesBrowse files
MAINT Parameters validation for sklearn.tree.export_graphviz (#26034)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent a2b8571 commit be892f5
Copy full SHA for be892f5

File tree

Expand file treeCollapse file tree

3 files changed

+24
-24
lines changed
Filter options
Expand file treeCollapse file tree

3 files changed

+24
-24
lines changed

‎sklearn/tests/test_public_functions.py

Copy file name to clipboardExpand all lines: sklearn/tests/test_public_functions.py
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ def _check_function_param_validation(
271271
"sklearn.preprocessing.scale",
272272
"sklearn.random_projection.johnson_lindenstrauss_min_dim",
273273
"sklearn.svm.l1_min_c",
274+
"sklearn.tree.export_graphviz",
274275
"sklearn.tree.export_text",
275276
"sklearn.tree.plot_tree",
276277
"sklearn.utils.gen_batches",

‎sklearn/tree/_export.py

Copy file name to clipboardExpand all lines: sklearn/tree/_export.py
+23-17Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import numpy as np
1818

1919
from ..utils.validation import check_is_fitted, check_array
20-
from ..utils._param_validation import Interval, validate_params, StrOptions
20+
from ..utils._param_validation import Interval, validate_params, StrOptions, HasMethods
2121

2222
from ..base import is_classifier
2323

@@ -441,20 +441,6 @@ def __init__(
441441
else:
442442
self.characters = ["#", "[", "]", "<=", "\\n", '"', '"']
443443

444-
# validate
445-
if isinstance(precision, Integral):
446-
if precision < 0:
447-
raise ValueError(
448-
"'precision' should be greater or equal to 0."
449-
" Got {} instead.".format(precision)
450-
)
451-
else:
452-
raise ValueError(
453-
"'precision' should be an integer. Got {} instead.".format(
454-
type(precision)
455-
)
456-
)
457-
458444
# The depth of each node for plotting with 'leaf' option
459445
self.ranks = {"leaves": []}
460446
# The colors to render each node with
@@ -739,6 +725,26 @@ def recurse(self, node, tree, ax, max_x, max_y, depth=0):
739725
ax.annotate("\n (...) \n", xy_parent, xy, **kwargs)
740726

741727

728+
@validate_params(
729+
{
730+
"decision_tree": "no_validation",
731+
"out_file": [str, None, HasMethods("write")],
732+
"max_depth": [Interval(Integral, 0, None, closed="left"), None],
733+
"feature_names": ["array-like", None],
734+
"class_names": ["array-like", "boolean", None],
735+
"label": [StrOptions({"all", "root", "none"})],
736+
"filled": ["boolean"],
737+
"leaves_parallel": ["boolean"],
738+
"impurity": ["boolean"],
739+
"node_ids": ["boolean"],
740+
"proportion": ["boolean"],
741+
"rotate": ["boolean"],
742+
"rounded": ["boolean"],
743+
"special_characters": ["boolean"],
744+
"precision": [Interval(Integral, 0, None, closed="left"), None],
745+
"fontname": [str],
746+
}
747+
)
742748
def export_graphviz(
743749
decision_tree,
744750
out_file=None,
@@ -774,8 +780,8 @@ def export_graphviz(
774780
775781
Parameters
776782
----------
777-
decision_tree : decision tree classifier
778-
The decision tree to be exported to GraphViz.
783+
decision_tree : object
784+
The decision tree estimator to be exported to GraphViz.
779785
780786
out_file : object or str, default=None
781787
Handle or name of the output file. If ``None``, the result is

‎sklearn/tree/tests/test_export.py

Copy file name to clipboardExpand all lines: sklearn/tree/tests/test_export.py
-7Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -293,13 +293,6 @@ def test_graphviz_errors():
293293
with pytest.raises(IndexError):
294294
export_graphviz(clf, out, class_names=[])
295295

296-
# Check precision error
297-
out = StringIO()
298-
with pytest.raises(ValueError, match="should be greater or equal"):
299-
export_graphviz(clf, out, precision=-1)
300-
with pytest.raises(ValueError, match="should be an integer"):
301-
export_graphviz(clf, out, precision="1")
302-
303296

304297
def test_friedman_mse_in_graphviz():
305298
clf = DecisionTreeRegressor(criterion="friedman_mse", random_state=0)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.