diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index e62df15037b7d..c51472b6c993e 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -271,6 +271,7 @@ def _check_function_param_validation( "sklearn.preprocessing.scale", "sklearn.random_projection.johnson_lindenstrauss_min_dim", "sklearn.svm.l1_min_c", + "sklearn.tree.export_graphviz", "sklearn.tree.export_text", "sklearn.tree.plot_tree", "sklearn.utils.gen_batches", diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py index 1d8e590f73ea6..e8dbe51138223 100644 --- a/sklearn/tree/_export.py +++ b/sklearn/tree/_export.py @@ -17,7 +17,7 @@ import numpy as np from ..utils.validation import check_is_fitted, check_array -from ..utils._param_validation import Interval, validate_params, StrOptions +from ..utils._param_validation import Interval, validate_params, StrOptions, HasMethods from ..base import is_classifier @@ -441,20 +441,6 @@ def __init__( else: self.characters = ["#", "[", "]", "<=", "\\n", '"', '"'] - # validate - if isinstance(precision, Integral): - if precision < 0: - raise ValueError( - "'precision' should be greater or equal to 0." - " Got {} instead.".format(precision) - ) - else: - raise ValueError( - "'precision' should be an integer. Got {} instead.".format( - type(precision) - ) - ) - # The depth of each node for plotting with 'leaf' option self.ranks = {"leaves": []} # The colors to render each node with @@ -739,6 +725,26 @@ def recurse(self, node, tree, ax, max_x, max_y, depth=0): ax.annotate("\n (...) \n", xy_parent, xy, **kwargs) +@validate_params( + { + "decision_tree": "no_validation", + "out_file": [str, None, HasMethods("write")], + "max_depth": [Interval(Integral, 0, None, closed="left"), None], + "feature_names": ["array-like", None], + "class_names": ["array-like", "boolean", None], + "label": [StrOptions({"all", "root", "none"})], + "filled": ["boolean"], + "leaves_parallel": ["boolean"], + "impurity": ["boolean"], + "node_ids": ["boolean"], + "proportion": ["boolean"], + "rotate": ["boolean"], + "rounded": ["boolean"], + "special_characters": ["boolean"], + "precision": [Interval(Integral, 0, None, closed="left"), None], + "fontname": [str], + } +) def export_graphviz( decision_tree, out_file=None, @@ -774,8 +780,8 @@ def export_graphviz( Parameters ---------- - decision_tree : decision tree classifier - The decision tree to be exported to GraphViz. + decision_tree : object + The decision tree estimator to be exported to GraphViz. out_file : object or str, default=None Handle or name of the output file. If ``None``, the result is diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py index a37c236b23def..1dc0fd7b9d8f4 100644 --- a/sklearn/tree/tests/test_export.py +++ b/sklearn/tree/tests/test_export.py @@ -293,13 +293,6 @@ def test_graphviz_errors(): with pytest.raises(IndexError): export_graphviz(clf, out, class_names=[]) - # Check precision error - out = StringIO() - with pytest.raises(ValueError, match="should be greater or equal"): - export_graphviz(clf, out, precision=-1) - with pytest.raises(ValueError, match="should be an integer"): - export_graphviz(clf, out, precision="1") - def test_friedman_mse_in_graphviz(): clf = DecisionTreeRegressor(criterion="friedman_mse", random_state=0)