4
4
from re import finditer , search
5
5
from textwrap import dedent
6
6
7
+ import numpy as np
7
8
from numpy .random import RandomState
8
9
import pytest
9
10
@@ -48,48 +49,6 @@ def test_graphviz_toy():
48
49
49
50
assert contents1 == contents2
50
51
51
- # Test with feature_names
52
- contents1 = export_graphviz (
53
- clf , feature_names = ["feature0" , "feature1" ], out_file = None
54
- )
55
- contents2 = (
56
- "digraph Tree {\n "
57
- 'node [shape=box, fontname="helvetica"] ;\n '
58
- 'edge [fontname="helvetica"] ;\n '
59
- '0 [label="feature0 <= 0.0\\ ngini = 0.5\\ nsamples = 6\\ n'
60
- 'value = [3, 3]"] ;\n '
61
- '1 [label="gini = 0.0\\ nsamples = 3\\ nvalue = [3, 0]"] ;\n '
62
- "0 -> 1 [labeldistance=2.5, labelangle=45, "
63
- 'headlabel="True"] ;\n '
64
- '2 [label="gini = 0.0\\ nsamples = 3\\ nvalue = [0, 3]"] ;\n '
65
- "0 -> 2 [labeldistance=2.5, labelangle=-45, "
66
- 'headlabel="False"] ;\n '
67
- "}"
68
- )
69
-
70
- assert contents1 == contents2
71
-
72
- # Test with class_names
73
- contents1 = export_graphviz (clf , class_names = ["yes" , "no" ], out_file = None )
74
- contents2 = (
75
- "digraph Tree {\n "
76
- 'node [shape=box, fontname="helvetica"] ;\n '
77
- 'edge [fontname="helvetica"] ;\n '
78
- '0 [label="x[0] <= 0.0\\ ngini = 0.5\\ nsamples = 6\\ n'
79
- 'value = [3, 3]\\ nclass = yes"] ;\n '
80
- '1 [label="gini = 0.0\\ nsamples = 3\\ nvalue = [3, 0]\\ n'
81
- 'class = yes"] ;\n '
82
- "0 -> 1 [labeldistance=2.5, labelangle=45, "
83
- 'headlabel="True"] ;\n '
84
- '2 [label="gini = 0.0\\ nsamples = 3\\ nvalue = [0, 3]\\ n'
85
- 'class = no"] ;\n '
86
- "0 -> 2 [labeldistance=2.5, labelangle=-45, "
87
- 'headlabel="False"] ;\n '
88
- "}"
89
- )
90
-
91
- assert contents1 == contents2
92
-
93
52
# Test plot_options
94
53
contents1 = export_graphviz (
95
54
clf ,
@@ -249,6 +208,60 @@ def test_graphviz_toy():
249
208
)
250
209
251
210
211
+ @pytest .mark .parametrize ("constructor" , [list , np .array ])
212
+ def test_graphviz_feature_class_names_array_support (constructor ):
213
+ # Check that export_graphviz treats feature names
214
+ # and class names correctly and supports arrays
215
+ clf = DecisionTreeClassifier (
216
+ max_depth = 3 , min_samples_split = 2 , criterion = "gini" , random_state = 2
217
+ )
218
+ clf .fit (X , y )
219
+
220
+ # Test with feature_names
221
+ contents1 = export_graphviz (
222
+ clf , feature_names = constructor (["feature0" , "feature1" ]), out_file = None
223
+ )
224
+ contents2 = (
225
+ "digraph Tree {\n "
226
+ 'node [shape=box, fontname="helvetica"] ;\n '
227
+ 'edge [fontname="helvetica"] ;\n '
228
+ '0 [label="feature0 <= 0.0\\ ngini = 0.5\\ nsamples = 6\\ n'
229
+ 'value = [3, 3]"] ;\n '
230
+ '1 [label="gini = 0.0\\ nsamples = 3\\ nvalue = [3, 0]"] ;\n '
231
+ "0 -> 1 [labeldistance=2.5, labelangle=45, "
232
+ 'headlabel="True"] ;\n '
233
+ '2 [label="gini = 0.0\\ nsamples = 3\\ nvalue = [0, 3]"] ;\n '
234
+ "0 -> 2 [labeldistance=2.5, labelangle=-45, "
235
+ 'headlabel="False"] ;\n '
236
+ "}"
237
+ )
238
+
239
+ assert contents1 == contents2
240
+
241
+ # Test with class_names
242
+ contents1 = export_graphviz (
243
+ clf , class_names = constructor (["yes" , "no" ]), out_file = None
244
+ )
245
+ contents2 = (
246
+ "digraph Tree {\n "
247
+ 'node [shape=box, fontname="helvetica"] ;\n '
248
+ 'edge [fontname="helvetica"] ;\n '
249
+ '0 [label="x[0] <= 0.0\\ ngini = 0.5\\ nsamples = 6\\ n'
250
+ 'value = [3, 3]\\ nclass = yes"] ;\n '
251
+ '1 [label="gini = 0.0\\ nsamples = 3\\ nvalue = [3, 0]\\ n'
252
+ 'class = yes"] ;\n '
253
+ "0 -> 1 [labeldistance=2.5, labelangle=45, "
254
+ 'headlabel="True"] ;\n '
255
+ '2 [label="gini = 0.0\\ nsamples = 3\\ nvalue = [0, 3]\\ n'
256
+ 'class = no"] ;\n '
257
+ "0 -> 2 [labeldistance=2.5, labelangle=-45, "
258
+ 'headlabel="False"] ;\n '
259
+ "}"
260
+ )
261
+
262
+ assert contents1 == contents2
263
+
264
+
252
265
def test_graphviz_errors ():
253
266
# Check for errors of export_graphviz
254
267
clf = DecisionTreeClassifier (max_depth = 3 , min_samples_split = 2 )
@@ -352,7 +365,7 @@ def test_export_text_errors():
352
365
with pytest .raises (ValueError , match = err_msg ):
353
366
export_text (clf , feature_names = ["a" ])
354
367
err_msg = (
355
- "When `class_names` is a list , it should contain as"
368
+ "When `class_names` is an array , it should contain as"
356
369
" many items as `decision_tree.classes_`. Got 1 while"
357
370
" the tree was fitted with 2 classes."
358
371
)
@@ -377,22 +390,6 @@ def test_export_text():
377
390
# testing that the rest of the tree is truncated
378
391
assert export_text (clf , max_depth = 10 ) == expected_report
379
392
380
- expected_report = dedent ("""
381
- |--- b <= 0.00
382
- | |--- class: -1
383
- |--- b > 0.00
384
- | |--- class: 1
385
- """ ).lstrip ()
386
- assert export_text (clf , feature_names = ["a" , "b" ]) == expected_report
387
-
388
- expected_report = dedent ("""
389
- |--- feature_1 <= 0.00
390
- | |--- class: cat
391
- |--- feature_1 > 0.00
392
- | |--- class: dog
393
- """ ).lstrip ()
394
- assert export_text (clf , class_names = ["cat" , "dog" ]) == expected_report
395
-
396
393
expected_report = dedent ("""
397
394
|--- feature_1 <= 0.00
398
395
| |--- weights: [3.00, 0.00] class: -1
@@ -453,6 +450,30 @@ def test_export_text():
453
450
)
454
451
455
452
453
+ @pytest .mark .parametrize ("constructor" , [list , np .array ])
454
+ def test_export_text_feature_class_names_array_support (constructor ):
455
+ # Check that export_text treats feature names
456
+ # and class names correctly and supports arrays
457
+ clf = DecisionTreeClassifier (max_depth = 2 , random_state = 0 )
458
+ clf .fit (X , y )
459
+
460
+ expected_report = dedent ("""
461
+ |--- b <= 0.00
462
+ | |--- class: -1
463
+ |--- b > 0.00
464
+ | |--- class: 1
465
+ """ ).lstrip ()
466
+ assert export_text (clf , feature_names = constructor (["a" , "b" ])) == expected_report
467
+
468
+ expected_report = dedent ("""
469
+ |--- feature_1 <= 0.00
470
+ | |--- class: cat
471
+ |--- feature_1 > 0.00
472
+ | |--- class: dog
473
+ """ ).lstrip ()
474
+ assert export_text (clf , class_names = constructor (["cat" , "dog" ])) == expected_report
475
+
476
+
456
477
def test_plot_tree_entropy (pyplot ):
457
478
# mostly smoke tests
458
479
# Check correctness of export_graphviz for criterion = entropy
0 commit comments