Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 6be774b

Browse filesBrowse files
authored
FIX tree.export_text and tree.export_graphviz accept feature and class names as array-like (#26289)
1 parent 5782cd7 commit 6be774b
Copy full SHA for 6be774b

File tree

Expand file treeCollapse file tree

3 files changed

+114
-72
lines changed
Filter options
Expand file treeCollapse file tree

3 files changed

+114
-72
lines changed

‎doc/whats_new/v1.3.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.3.rst
+4Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,10 @@ Changelog
523523
for each target class in ascending numerical order.
524524
:pr:`25387` by :user:`William M <Akbeeh>` and :user:`crispinlogan <crispinlogan>`.
525525

526+
- |Fix| :func:`tree.export_graphviz` and :func:`tree.export_text` now accept
527+
`feature_names` and `class_names` as array-like rather than only lists.
528+
:pr:`26289` by :user:`Yao Xiao <Charlie-XIAO>`
529+
526530
:mod:`sklearn.utils`
527531
....................
528532

‎sklearn/tree/_export.py

Copy file name to clipboardExpand all lines: sklearn/tree/_export.py
+30-13Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import numpy as np
1818

19-
from ..utils.validation import check_is_fitted
19+
from ..utils.validation import check_is_fitted, check_array
2020
from ..utils._param_validation import Interval, validate_params, StrOptions
2121

2222
from ..base import is_classifier
@@ -788,11 +788,11 @@ def export_graphviz(
788788
The maximum depth of the representation. If None, the tree is fully
789789
generated.
790790
791-
feature_names : list of str, default=None
792-
Names of each of the features.
791+
feature_names : array-like of shape (n_features,), default=None
792+
An array containing the feature names.
793793
If None, generic names will be used ("x[0]", "x[1]", ...).
794794
795-
class_names : list of str or bool, default=None
795+
class_names : array-like of shape (n_classes,) or bool, default=None
796796
Names of each of the target classes in ascending numerical order.
797797
Only relevant for classification and not supported for multi-output.
798798
If ``True``, shows a symbolic representation of the class name.
@@ -857,6 +857,14 @@ def export_graphviz(
857857
>>> tree.export_graphviz(clf)
858858
'digraph Tree {...
859859
"""
860+
if feature_names is not None:
861+
feature_names = check_array(
862+
feature_names, ensure_2d=False, dtype=None, ensure_min_samples=0
863+
)
864+
if class_names is not None and not isinstance(class_names, bool):
865+
class_names = check_array(
866+
class_names, ensure_2d=False, dtype=None, ensure_min_samples=0
867+
)
860868

861869
check_is_fitted(decision_tree)
862870
own_file = False
@@ -924,8 +932,8 @@ def compute_depth_(
924932
@validate_params(
925933
{
926934
"decision_tree": [DecisionTreeClassifier, DecisionTreeRegressor],
927-
"feature_names": [list, None],
928-
"class_names": [list, None],
935+
"feature_names": ["array-like", None],
936+
"class_names": ["array-like", None],
929937
"max_depth": [Interval(Integral, 0, None, closed="left"), None],
930938
"spacing": [Interval(Integral, 1, None, closed="left"), None],
931939
"decimals": [Interval(Integral, 0, None, closed="left"), None],
@@ -953,17 +961,17 @@ def export_text(
953961
It can be an instance of
954962
DecisionTreeClassifier or DecisionTreeRegressor.
955963
956-
feature_names : list of str, default=None
957-
A list of length n_features containing the feature names.
964+
feature_names : array-like of shape (n_features,), default=None
965+
An array containing the feature names.
958966
If None generic names will be used ("feature_0", "feature_1", ...).
959967
960-
class_names : list or None, default=None
968+
class_names : array-like of shape (n_classes,), default=None
961969
Names of each of the target classes in ascending numerical order.
962970
Only relevant for classification and not supported for multi-output.
963971
964972
- if `None`, the class names are delegated to `decision_tree.classes_`;
965-
- if a list, then `class_names` will be used as class names instead
966-
of `decision_tree.classes_`. The length of `class_names` must match
973+
- otherwise, `class_names` will be used as class names instead of
974+
`decision_tree.classes_`. The length of `class_names` must match
967975
the length of `decision_tree.classes_`.
968976
969977
.. versionadded:: 1.3
@@ -1008,14 +1016,23 @@ def export_text(
10081016
| |--- petal width (cm) > 1.75
10091017
| | |--- class: 2
10101018
"""
1019+
if feature_names is not None:
1020+
feature_names = check_array(
1021+
feature_names, ensure_2d=False, dtype=None, ensure_min_samples=0
1022+
)
1023+
if class_names is not None:
1024+
class_names = check_array(
1025+
class_names, ensure_2d=False, dtype=None, ensure_min_samples=0
1026+
)
1027+
10111028
check_is_fitted(decision_tree)
10121029
tree_ = decision_tree.tree_
10131030
if is_classifier(decision_tree):
10141031
if class_names is None:
10151032
class_names = decision_tree.classes_
10161033
elif len(class_names) != len(decision_tree.classes_):
10171034
raise ValueError(
1018-
"When `class_names` is a list, it should contain as"
1035+
"When `class_names` is an array, it should contain as"
10191036
" many items as `decision_tree.classes_`. Got"
10201037
f" {len(class_names)} while the tree was fitted with"
10211038
f" {len(decision_tree.classes_)} classes."
@@ -1037,7 +1054,7 @@ def export_text(
10371054
else:
10381055
value_fmt = "{}{} value: {}\n"
10391056

1040-
if feature_names:
1057+
if feature_names is not None:
10411058
feature_names_ = [
10421059
feature_names[i] if i != _tree.TREE_UNDEFINED else None
10431060
for i in tree_.feature

‎sklearn/tree/tests/test_export.py

Copy file name to clipboardExpand all lines: sklearn/tree/tests/test_export.py
+80-59Lines changed: 80 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from re import finditer, search
55
from textwrap import dedent
66

7+
import numpy as np
78
from numpy.random import RandomState
89
import pytest
910

@@ -48,48 +49,6 @@ def test_graphviz_toy():
4849

4950
assert contents1 == contents2
5051

51-
# Test with feature_names
52-
contents1 = export_graphviz(
53-
clf, feature_names=["feature0", "feature1"], out_file=None
54-
)
55-
contents2 = (
56-
"digraph Tree {\n"
57-
'node [shape=box, fontname="helvetica"] ;\n'
58-
'edge [fontname="helvetica"] ;\n'
59-
'0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
60-
'value = [3, 3]"] ;\n'
61-
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n'
62-
"0 -> 1 [labeldistance=2.5, labelangle=45, "
63-
'headlabel="True"] ;\n'
64-
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n'
65-
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
66-
'headlabel="False"] ;\n'
67-
"}"
68-
)
69-
70-
assert contents1 == contents2
71-
72-
# Test with class_names
73-
contents1 = export_graphviz(clf, class_names=["yes", "no"], out_file=None)
74-
contents2 = (
75-
"digraph Tree {\n"
76-
'node [shape=box, fontname="helvetica"] ;\n'
77-
'edge [fontname="helvetica"] ;\n'
78-
'0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
79-
'value = [3, 3]\\nclass = yes"] ;\n'
80-
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n'
81-
'class = yes"] ;\n'
82-
"0 -> 1 [labeldistance=2.5, labelangle=45, "
83-
'headlabel="True"] ;\n'
84-
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]\\n'
85-
'class = no"] ;\n'
86-
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
87-
'headlabel="False"] ;\n'
88-
"}"
89-
)
90-
91-
assert contents1 == contents2
92-
9352
# Test plot_options
9453
contents1 = export_graphviz(
9554
clf,
@@ -249,6 +208,60 @@ def test_graphviz_toy():
249208
)
250209

251210

211+
@pytest.mark.parametrize("constructor", [list, np.array])
212+
def test_graphviz_feature_class_names_array_support(constructor):
213+
# Check that export_graphviz treats feature names
214+
# and class names correctly and supports arrays
215+
clf = DecisionTreeClassifier(
216+
max_depth=3, min_samples_split=2, criterion="gini", random_state=2
217+
)
218+
clf.fit(X, y)
219+
220+
# Test with feature_names
221+
contents1 = export_graphviz(
222+
clf, feature_names=constructor(["feature0", "feature1"]), out_file=None
223+
)
224+
contents2 = (
225+
"digraph Tree {\n"
226+
'node [shape=box, fontname="helvetica"] ;\n'
227+
'edge [fontname="helvetica"] ;\n'
228+
'0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
229+
'value = [3, 3]"] ;\n'
230+
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n'
231+
"0 -> 1 [labeldistance=2.5, labelangle=45, "
232+
'headlabel="True"] ;\n'
233+
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n'
234+
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
235+
'headlabel="False"] ;\n'
236+
"}"
237+
)
238+
239+
assert contents1 == contents2
240+
241+
# Test with class_names
242+
contents1 = export_graphviz(
243+
clf, class_names=constructor(["yes", "no"]), out_file=None
244+
)
245+
contents2 = (
246+
"digraph Tree {\n"
247+
'node [shape=box, fontname="helvetica"] ;\n'
248+
'edge [fontname="helvetica"] ;\n'
249+
'0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
250+
'value = [3, 3]\\nclass = yes"] ;\n'
251+
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n'
252+
'class = yes"] ;\n'
253+
"0 -> 1 [labeldistance=2.5, labelangle=45, "
254+
'headlabel="True"] ;\n'
255+
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]\\n'
256+
'class = no"] ;\n'
257+
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
258+
'headlabel="False"] ;\n'
259+
"}"
260+
)
261+
262+
assert contents1 == contents2
263+
264+
252265
def test_graphviz_errors():
253266
# Check for errors of export_graphviz
254267
clf = DecisionTreeClassifier(max_depth=3, min_samples_split=2)
@@ -352,7 +365,7 @@ def test_export_text_errors():
352365
with pytest.raises(ValueError, match=err_msg):
353366
export_text(clf, feature_names=["a"])
354367
err_msg = (
355-
"When `class_names` is a list, it should contain as"
368+
"When `class_names` is an array, it should contain as"
356369
" many items as `decision_tree.classes_`. Got 1 while"
357370
" the tree was fitted with 2 classes."
358371
)
@@ -377,22 +390,6 @@ def test_export_text():
377390
# testing that the rest of the tree is truncated
378391
assert export_text(clf, max_depth=10) == expected_report
379392

380-
expected_report = dedent("""
381-
|--- b <= 0.00
382-
| |--- class: -1
383-
|--- b > 0.00
384-
| |--- class: 1
385-
""").lstrip()
386-
assert export_text(clf, feature_names=["a", "b"]) == expected_report
387-
388-
expected_report = dedent("""
389-
|--- feature_1 <= 0.00
390-
| |--- class: cat
391-
|--- feature_1 > 0.00
392-
| |--- class: dog
393-
""").lstrip()
394-
assert export_text(clf, class_names=["cat", "dog"]) == expected_report
395-
396393
expected_report = dedent("""
397394
|--- feature_1 <= 0.00
398395
| |--- weights: [3.00, 0.00] class: -1
@@ -453,6 +450,30 @@ def test_export_text():
453450
)
454451

455452

453+
@pytest.mark.parametrize("constructor", [list, np.array])
454+
def test_export_text_feature_class_names_array_support(constructor):
455+
# Check that export_text treats feature names
456+
# and class names correctly and supports arrays
457+
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
458+
clf.fit(X, y)
459+
460+
expected_report = dedent("""
461+
|--- b <= 0.00
462+
| |--- class: -1
463+
|--- b > 0.00
464+
| |--- class: 1
465+
""").lstrip()
466+
assert export_text(clf, feature_names=constructor(["a", "b"])) == expected_report
467+
468+
expected_report = dedent("""
469+
|--- feature_1 <= 0.00
470+
| |--- class: cat
471+
|--- feature_1 > 0.00
472+
| |--- class: dog
473+
""").lstrip()
474+
assert export_text(clf, class_names=constructor(["cat", "dog"])) == expected_report
475+
476+
456477
def test_plot_tree_entropy(pyplot):
457478
# mostly smoke tests
458479
# Check correctness of export_graphviz for criterion = entropy

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.