diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 5965cd90c1b89..5a3e1a6d6d82d 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -514,6 +514,10 @@ Changelog for each target class in ascending numerical order. :pr:`25387` by :user:`William M ` and :user:`crispinlogan `. +- |Fix| :func:`tree.export_graphviz` and :func:`tree.export_text` now accepts + `feature_names` and `class_names` as array-like rather than lists. + :pr:`26289` by :user:`Yao Xiao ` + :mod:`sklearn.utils` .................... diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py index 6b84bed891c18..1d8e590f73ea6 100644 --- a/sklearn/tree/_export.py +++ b/sklearn/tree/_export.py @@ -16,7 +16,7 @@ import numpy as np -from ..utils.validation import check_is_fitted +from ..utils.validation import check_is_fitted, check_array from ..utils._param_validation import Interval, validate_params, StrOptions from ..base import is_classifier @@ -788,11 +788,11 @@ def export_graphviz( The maximum depth of the representation. If None, the tree is fully generated. - feature_names : list of str, default=None - Names of each of the features. + feature_names : array-like of shape (n_features,), default=None + An array containing the feature names. If None, generic names will be used ("x[0]", "x[1]", ...). - class_names : list of str or bool, default=None + class_names : array-like of shape (n_classes,) or bool, default=None Names of each of the target classes in ascending numerical order. Only relevant for classification and not supported for multi-output. If ``True``, shows a symbolic representation of the class name. @@ -857,6 +857,14 @@ def export_graphviz( >>> tree.export_graphviz(clf) 'digraph Tree {... """ + if feature_names is not None: + feature_names = check_array( + feature_names, ensure_2d=False, dtype=None, ensure_min_samples=0 + ) + if class_names is not None and not isinstance(class_names, bool): + class_names = check_array( + class_names, ensure_2d=False, dtype=None, ensure_min_samples=0 + ) check_is_fitted(decision_tree) own_file = False @@ -924,8 +932,8 @@ def compute_depth_( @validate_params( { "decision_tree": [DecisionTreeClassifier, DecisionTreeRegressor], - "feature_names": [list, None], - "class_names": [list, None], + "feature_names": ["array-like", None], + "class_names": ["array-like", None], "max_depth": [Interval(Integral, 0, None, closed="left"), None], "spacing": [Interval(Integral, 1, None, closed="left"), None], "decimals": [Interval(Integral, 0, None, closed="left"), None], @@ -953,17 +961,17 @@ def export_text( It can be an instance of DecisionTreeClassifier or DecisionTreeRegressor. - feature_names : list of str, default=None - A list of length n_features containing the feature names. + feature_names : array-like of shape (n_features,), default=None + An array containing the feature names. If None generic names will be used ("feature_0", "feature_1", ...). - class_names : list or None, default=None + class_names : array-like of shape (n_classes,), default=None Names of each of the target classes in ascending numerical order. Only relevant for classification and not supported for multi-output. - if `None`, the class names are delegated to `decision_tree.classes_`; - - if a list, then `class_names` will be used as class names instead - of `decision_tree.classes_`. The length of `class_names` must match + - otherwise, `class_names` will be used as class names instead of + `decision_tree.classes_`. The length of `class_names` must match the length of `decision_tree.classes_`. .. versionadded:: 1.3 @@ -1008,6 +1016,15 @@ def export_text( | |--- petal width (cm) > 1.75 | | |--- class: 2 """ + if feature_names is not None: + feature_names = check_array( + feature_names, ensure_2d=False, dtype=None, ensure_min_samples=0 + ) + if class_names is not None: + class_names = check_array( + class_names, ensure_2d=False, dtype=None, ensure_min_samples=0 + ) + check_is_fitted(decision_tree) tree_ = decision_tree.tree_ if is_classifier(decision_tree): @@ -1015,7 +1032,7 @@ def export_text( class_names = decision_tree.classes_ elif len(class_names) != len(decision_tree.classes_): raise ValueError( - "When `class_names` is a list, it should contain as" + "When `class_names` is an array, it should contain as" " many items as `decision_tree.classes_`. Got" f" {len(class_names)} while the tree was fitted with" f" {len(decision_tree.classes_)} classes." @@ -1037,7 +1054,7 @@ def export_text( else: value_fmt = "{}{} value: {}\n" - if feature_names: + if feature_names is not None: feature_names_ = [ feature_names[i] if i != _tree.TREE_UNDEFINED else None for i in tree_.feature diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py index 8cdf28b8f7130..a37c236b23def 100644 --- a/sklearn/tree/tests/test_export.py +++ b/sklearn/tree/tests/test_export.py @@ -4,6 +4,7 @@ from re import finditer, search from textwrap import dedent +import numpy as np from numpy.random import RandomState import pytest @@ -48,48 +49,6 @@ def test_graphviz_toy(): assert contents1 == contents2 - # Test with feature_names - contents1 = export_graphviz( - clf, feature_names=["feature0", "feature1"], out_file=None - ) - contents2 = ( - "digraph Tree {\n" - 'node [shape=box, fontname="helvetica"] ;\n' - 'edge [fontname="helvetica"] ;\n' - '0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n' - 'value = [3, 3]"] ;\n' - '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n' - "0 -> 1 [labeldistance=2.5, labelangle=45, " - 'headlabel="True"] ;\n' - '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n' - "0 -> 2 [labeldistance=2.5, labelangle=-45, " - 'headlabel="False"] ;\n' - "}" - ) - - assert contents1 == contents2 - - # Test with class_names - contents1 = export_graphviz(clf, class_names=["yes", "no"], out_file=None) - contents2 = ( - "digraph Tree {\n" - 'node [shape=box, fontname="helvetica"] ;\n' - 'edge [fontname="helvetica"] ;\n' - '0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' - 'value = [3, 3]\\nclass = yes"] ;\n' - '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n' - 'class = yes"] ;\n' - "0 -> 1 [labeldistance=2.5, labelangle=45, " - 'headlabel="True"] ;\n' - '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]\\n' - 'class = no"] ;\n' - "0 -> 2 [labeldistance=2.5, labelangle=-45, " - 'headlabel="False"] ;\n' - "}" - ) - - assert contents1 == contents2 - # Test plot_options contents1 = export_graphviz( clf, @@ -249,6 +208,60 @@ def test_graphviz_toy(): ) +@pytest.mark.parametrize("constructor", [list, np.array]) +def test_graphviz_feature_class_names_array_support(constructor): + # Check that export_graphviz treats feature names + # and class names correctly and supports arrays + clf = DecisionTreeClassifier( + max_depth=3, min_samples_split=2, criterion="gini", random_state=2 + ) + clf.fit(X, y) + + # Test with feature_names + contents1 = export_graphviz( + clf, feature_names=constructor(["feature0", "feature1"]), out_file=None + ) + contents2 = ( + "digraph Tree {\n" + 'node [shape=box, fontname="helvetica"] ;\n' + 'edge [fontname="helvetica"] ;\n' + '0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n' + 'value = [3, 3]"] ;\n' + '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n' + "0 -> 1 [labeldistance=2.5, labelangle=45, " + 'headlabel="True"] ;\n' + '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n' + "0 -> 2 [labeldistance=2.5, labelangle=-45, " + 'headlabel="False"] ;\n' + "}" + ) + + assert contents1 == contents2 + + # Test with class_names + contents1 = export_graphviz( + clf, class_names=constructor(["yes", "no"]), out_file=None + ) + contents2 = ( + "digraph Tree {\n" + 'node [shape=box, fontname="helvetica"] ;\n' + 'edge [fontname="helvetica"] ;\n' + '0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' + 'value = [3, 3]\\nclass = yes"] ;\n' + '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n' + 'class = yes"] ;\n' + "0 -> 1 [labeldistance=2.5, labelangle=45, " + 'headlabel="True"] ;\n' + '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]\\n' + 'class = no"] ;\n' + "0 -> 2 [labeldistance=2.5, labelangle=-45, " + 'headlabel="False"] ;\n' + "}" + ) + + assert contents1 == contents2 + + def test_graphviz_errors(): # Check for errors of export_graphviz clf = DecisionTreeClassifier(max_depth=3, min_samples_split=2) @@ -352,7 +365,7 @@ def test_export_text_errors(): with pytest.raises(ValueError, match=err_msg): export_text(clf, feature_names=["a"]) err_msg = ( - "When `class_names` is a list, it should contain as" + "When `class_names` is an array, it should contain as" " many items as `decision_tree.classes_`. Got 1 while" " the tree was fitted with 2 classes." ) @@ -377,22 +390,6 @@ def test_export_text(): # testing that the rest of the tree is truncated assert export_text(clf, max_depth=10) == expected_report - expected_report = dedent(""" - |--- b <= 0.00 - | |--- class: -1 - |--- b > 0.00 - | |--- class: 1 - """).lstrip() - assert export_text(clf, feature_names=["a", "b"]) == expected_report - - expected_report = dedent(""" - |--- feature_1 <= 0.00 - | |--- class: cat - |--- feature_1 > 0.00 - | |--- class: dog - """).lstrip() - assert export_text(clf, class_names=["cat", "dog"]) == expected_report - expected_report = dedent(""" |--- feature_1 <= 0.00 | |--- weights: [3.00, 0.00] class: -1 @@ -453,6 +450,30 @@ def test_export_text(): ) +@pytest.mark.parametrize("constructor", [list, np.array]) +def test_export_text_feature_class_names_array_support(constructor): + # Check that export_graphviz treats feature names + # and class names correctly and supports arrays + clf = DecisionTreeClassifier(max_depth=2, random_state=0) + clf.fit(X, y) + + expected_report = dedent(""" + |--- b <= 0.00 + | |--- class: -1 + |--- b > 0.00 + | |--- class: 1 + """).lstrip() + assert export_text(clf, feature_names=constructor(["a", "b"])) == expected_report + + expected_report = dedent(""" + |--- feature_1 <= 0.00 + | |--- class: cat + |--- feature_1 > 0.00 + | |--- class: dog + """).lstrip() + assert export_text(clf, class_names=constructor(["cat", "dog"])) == expected_report + + def test_plot_tree_entropy(pyplot): # mostly smoke tests # Check correctness of export_graphviz for criterion = entropy