From 309a333687639d6597567a47fadda6c21fbef643 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Thu, 27 Apr 2023 21:30:57 +0800 Subject: [PATCH 01/10] export_text accepts feature and class names an numpy array --- sklearn/tree/_export.py | 16 ++++++++-------- sklearn/tree/tests/test_export.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py index 6b84bed891c18..88354a005f50e 100644 --- a/sklearn/tree/_export.py +++ b/sklearn/tree/_export.py @@ -924,8 +924,8 @@ def compute_depth_( @validate_params( { "decision_tree": [DecisionTreeClassifier, DecisionTreeRegressor], - "feature_names": [list, None], - "class_names": [list, None], + "feature_names": ["array-like", None], + "class_names": ["array-like", None], "max_depth": [Interval(Integral, 0, None, closed="left"), None], "spacing": [Interval(Integral, 1, None, closed="left"), None], "decimals": [Interval(Integral, 0, None, closed="left"), None], @@ -953,16 +953,16 @@ def export_text( It can be an instance of DecisionTreeClassifier or DecisionTreeRegressor. - feature_names : list of str, default=None - A list of length n_features containing the feature names. + feature_names : array-like of shape (n_features,), default=None + An array containing the feature names. If None generic names will be used ("feature_0", "feature_1", ...). - class_names : list or None, default=None + class_names : array-like or None, default=None Names of each of the target classes in ascending numerical order. Only relevant for classification and not supported for multi-output. - if `None`, the class names are delegated to `decision_tree.classes_`; - - if a list, then `class_names` will be used as class names instead + - if an array, then `class_names` will be used as class names instead of `decision_tree.classes_`. The length of `class_names` must match the length of `decision_tree.classes_`. @@ -1015,7 +1015,7 @@ def export_text( class_names = decision_tree.classes_ elif len(class_names) != len(decision_tree.classes_): raise ValueError( - "When `class_names` is a list, it should contain as" + "When `class_names` is an array, it should contain as" " many items as `decision_tree.classes_`. Got" f" {len(class_names)} while the tree was fitted with" f" {len(decision_tree.classes_)} classes." @@ -1037,7 +1037,7 @@ def export_text( else: value_fmt = "{}{} value: {}\n" - if feature_names: + if feature_names is not None: feature_names_ = [ feature_names[i] if i != _tree.TREE_UNDEFINED else None for i in tree_.feature diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py index 8cdf28b8f7130..191b3d9891df0 100644 --- a/sklearn/tree/tests/test_export.py +++ b/sklearn/tree/tests/test_export.py @@ -352,7 +352,7 @@ def test_export_text_errors(): with pytest.raises(ValueError, match=err_msg): export_text(clf, feature_names=["a"]) err_msg = ( - "When `class_names` is a list, it should contain as" + "When `class_names` is an array, it should contain as" " many items as `decision_tree.classes_`. Got 1 while" " the tree was fitted with 2 classes." ) From c853c827f103aab3205b55dea5e7bc420d9636be Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Thu, 27 Apr 2023 21:48:01 +0800 Subject: [PATCH 02/10] modified docstring for export_graphviz --- sklearn/tree/_export.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py index 88354a005f50e..9450bf00a9afc 100644 --- a/sklearn/tree/_export.py +++ b/sklearn/tree/_export.py @@ -788,11 +788,11 @@ def export_graphviz( The maximum depth of the representation. If None, the tree is fully generated. - feature_names : list of str, default=None - Names of each of the features. + feature_names : array-like of shape(n_features,), default=None + An array containing the feature names. If None, generic names will be used ("x[0]", "x[1]", ...). - class_names : list of str or bool, default=None + class_names : array-like or bool, default=None Names of each of the target classes in ascending numerical order. Only relevant for classification and not supported for multi-output. If ``True``, shows a symbolic representation of the class name. From cb58a4e46387500eaadc12400c6625f9f93c5ae4 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Thu, 27 Apr 2023 21:51:14 +0800 Subject: [PATCH 03/10] added changelog --- doc/whats_new/v1.3.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index bb245aa466152..6a50d0cb00d1a 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -491,6 +491,10 @@ Changelog for each target class in ascending numerical order. :pr:`25387` by :user:`William M ` and :user:`crispinlogan `. +- |Fix| :func:`tree.export_graphviz` and :func:`tree.export_text` now accepts + `feature_names` and `class_names` as array-like rather than lists. + :pr:`26289` by :user:`Yao Xiao ` + :mod:`sklearn.utils` .................... From 5e134bc52136358458af2e33bb9a4f92bcdcce8d Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Thu, 27 Apr 2023 22:44:22 +0800 Subject: [PATCH 04/10] regression tests --- sklearn/tree/tests/test_export.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py index 191b3d9891df0..65bfd6d2ff009 100644 --- a/sklearn/tree/tests/test_export.py +++ b/sklearn/tree/tests/test_export.py @@ -4,6 +4,7 @@ from re import finditer, search from textwrap import dedent +import numpy as np from numpy.random import RandomState import pytest @@ -21,6 +22,9 @@ w = [1, 1, 1, 0.5, 0.5, 0.5] y_degraded = [1, 1, 1, 1, 1, 1] +# constructors for feature names and class names +constructors = [list, np.array] + def test_graphviz_toy(): # Check correctness of export_graphviz @@ -49,9 +53,6 @@ def test_graphviz_toy(): assert contents1 == contents2 # Test with feature_names - contents1 = export_graphviz( - clf, feature_names=["feature0", "feature1"], out_file=None - ) contents2 = ( "digraph Tree {\n" 'node [shape=box, fontname="helvetica"] ;\n' @@ -67,10 +68,13 @@ def test_graphviz_toy(): "}" ) - assert contents1 == contents2 + for cons in constructors: + contents1 = export_graphviz( + clf, feature_names=cons(["feature0", "feature1"]), out_file=None + ) + assert contents1 == contents2 # Test with class_names - contents1 = export_graphviz(clf, class_names=["yes", "no"], out_file=None) contents2 = ( "digraph Tree {\n" 'node [shape=box, fontname="helvetica"] ;\n' @@ -88,7 +92,9 @@ def test_graphviz_toy(): "}" ) - assert contents1 == contents2 + for cons in constructors: + contents1 = export_graphviz(clf, class_names=cons(["yes", "no"]), out_file=None) + assert contents1 == contents2 # Test plot_options contents1 = export_graphviz( @@ -383,7 +389,8 @@ def test_export_text(): |--- b > 0.00 | |--- class: 1 """).lstrip() - assert export_text(clf, feature_names=["a", "b"]) == expected_report + for cons in constructors: + assert export_text(clf, feature_names=cons(["a", "b"])) == expected_report expected_report = dedent(""" |--- feature_1 <= 0.00 @@ -391,7 +398,8 @@ def test_export_text(): |--- feature_1 > 0.00 | |--- class: dog """).lstrip() - assert export_text(clf, class_names=["cat", "dog"]) == expected_report + for cons in constructors: + assert export_text(clf, class_names=cons(["cat", "dog"])) == expected_report expected_report = dedent(""" |--- feature_1 <= 0.00 From bb2f9e492b3d4fc43a6c80fc364ded082e8ad01f Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Thu, 27 Apr 2023 23:13:10 +0800 Subject: [PATCH 05/10] array-like --> list or ndarray --- sklearn/tree/_export.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py index 9450bf00a9afc..bb86864e784f8 100644 --- a/sklearn/tree/_export.py +++ b/sklearn/tree/_export.py @@ -924,8 +924,8 @@ def compute_depth_( @validate_params( { "decision_tree": [DecisionTreeClassifier, DecisionTreeRegressor], - "feature_names": ["array-like", None], - "class_names": ["array-like", None], + "feature_names": [list, np.ndarray, None], + "class_names": [list, np.ndarray, None], "max_depth": [Interval(Integral, 0, None, closed="left"), None], "spacing": [Interval(Integral, 1, None, closed="left"), None], "decimals": [Interval(Integral, 0, None, closed="left"), None], @@ -953,18 +953,18 @@ def export_text( It can be an instance of DecisionTreeClassifier or DecisionTreeRegressor. - feature_names : array-like of shape (n_features,), default=None - An array containing the feature names. + feature_names : list or ndarray of shape (n_features,), default=None + A list or an array containing the feature names. If None generic names will be used ("feature_0", "feature_1", ...). - class_names : array-like or None, default=None + class_names : list, ndarray, or None, default=None Names of each of the target classes in ascending numerical order. Only relevant for classification and not supported for multi-output. - if `None`, the class names are delegated to `decision_tree.classes_`; - - if an array, then `class_names` will be used as class names instead - of `decision_tree.classes_`. The length of `class_names` must match - the length of `decision_tree.classes_`. + - if a list or an array, then `class_names` will be used as class names + instead of `decision_tree.classes_`. The length of `class_names` must + match the length of `decision_tree.classes_`. .. versionadded:: 1.3 From f9a91e09e54fc47f30eeac8ca0af9746fa74e7f2 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Sat, 29 Apr 2023 00:15:10 +0800 Subject: [PATCH 06/10] apply check_array to support all array-like --- sklearn/tree/_export.py | 35 ++++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py index bb86864e784f8..d208aa52562a4 100644 --- a/sklearn/tree/_export.py +++ b/sklearn/tree/_export.py @@ -16,7 +16,7 @@ import numpy as np -from ..utils.validation import check_is_fitted +from ..utils.validation import check_is_fitted, check_array from ..utils._param_validation import Interval, validate_params, StrOptions from ..base import is_classifier @@ -788,7 +788,7 @@ def export_graphviz( The maximum depth of the representation. If None, the tree is fully generated. - feature_names : array-like of shape(n_features,), default=None + feature_names : array-like of shape (n_features,), default=None An array containing the feature names. If None, generic names will be used ("x[0]", "x[1]", ...). @@ -857,6 +857,14 @@ def export_graphviz( >>> tree.export_graphviz(clf) 'digraph Tree {... """ + if feature_names is not None: + feature_names = check_array( + feature_names, ensure_2d=False, dtype=None, ensure_min_samples=0 + ) + if class_names is not None: + class_names = check_array( + class_names, ensure_2d=False, dtype=None, ensure_min_samples=0 + ) check_is_fitted(decision_tree) own_file = False @@ -924,8 +932,8 @@ def compute_depth_( @validate_params( { "decision_tree": [DecisionTreeClassifier, DecisionTreeRegressor], - "feature_names": [list, np.ndarray, None], - "class_names": [list, np.ndarray, None], + "feature_names": ["array-like", None], + "class_names": ["array-like", None], "max_depth": [Interval(Integral, 0, None, closed="left"), None], "spacing": [Interval(Integral, 1, None, closed="left"), None], "decimals": [Interval(Integral, 0, None, closed="left"), None], @@ -953,8 +961,8 @@ def export_text( It can be an instance of DecisionTreeClassifier or DecisionTreeRegressor. - feature_names : list or ndarray of shape (n_features,), default=None - A list or an array containing the feature names. + feature_names : array-like of shape (n_features,), default=None + An array containing the feature names. If None generic names will be used ("feature_0", "feature_1", ...). class_names : list, ndarray, or None, default=None @@ -962,9 +970,9 @@ def export_text( Only relevant for classification and not supported for multi-output. - if `None`, the class names are delegated to `decision_tree.classes_`; - - if a list or an array, then `class_names` will be used as class names - instead of `decision_tree.classes_`. The length of `class_names` must - match the length of `decision_tree.classes_`. + - otherwise, `class_names` will be used as class names instead of + `decision_tree.classes_`. The length of `class_names` must match + the length of `decision_tree.classes_`. .. versionadded:: 1.3 @@ -1008,6 +1016,15 @@ def export_text( | |--- petal width (cm) > 1.75 | | |--- class: 2 """ + if feature_names is not None: + feature_names = check_array( + feature_names, ensure_2d=False, dtype=None, ensure_min_samples=0 + ) + if class_names is not None: + class_names = check_array( + class_names, ensure_2d=False, dtype=None, ensure_min_samples=0 + ) + check_is_fitted(decision_tree) tree_ = decision_tree.tree_ if is_classifier(decision_tree): From ff795c195886a05c4c1a475d16897151ca474427 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Sat, 29 Apr 2023 03:09:28 +0800 Subject: [PATCH 07/10] fixed graphviz can take bool --- sklearn/tree/_export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py index d208aa52562a4..5264101e3eb0d 100644 --- a/sklearn/tree/_export.py +++ b/sklearn/tree/_export.py @@ -861,7 +861,7 @@ def export_graphviz( feature_names = check_array( feature_names, ensure_2d=False, dtype=None, ensure_min_samples=0 ) - if class_names is not None: + if class_names is not None and not isinstance(class_names, bool): class_names = check_array( class_names, ensure_2d=False, dtype=None, ensure_min_samples=0 ) From 664bf08fd4d85279251748060ab96115d2d5b7d9 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Thu, 4 May 2023 22:48:32 +0800 Subject: [PATCH 08/10] modified docstrings --- sklearn/tree/_export.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py index 5264101e3eb0d..1d8e590f73ea6 100644 --- a/sklearn/tree/_export.py +++ b/sklearn/tree/_export.py @@ -792,7 +792,7 @@ def export_graphviz( An array containing the feature names. If None, generic names will be used ("x[0]", "x[1]", ...). - class_names : array-like or bool, default=None + class_names : array-like of shape (n_classes,) or bool, default=None Names of each of the target classes in ascending numerical order. Only relevant for classification and not supported for multi-output. If ``True``, shows a symbolic representation of the class name. @@ -965,7 +965,7 @@ def export_text( An array containing the feature names. If None generic names will be used ("feature_0", "feature_1", ...). - class_names : list, ndarray, or None, default=None + class_names : array-like of shape (n_classes,), default=None Names of each of the target classes in ascending numerical order. Only relevant for classification and not supported for multi-output. From fac4e7003e7054836ab513cab77d3a36da04dc69 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Thu, 4 May 2023 22:49:42 +0800 Subject: [PATCH 09/10] decoupled tests --- sklearn/tree/tests/test_export.py | 97 ++++++++++++++++++++++++++----- 1 file changed, 82 insertions(+), 15 deletions(-) diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py index 65bfd6d2ff009..001e63d631edc 100644 --- a/sklearn/tree/tests/test_export.py +++ b/sklearn/tree/tests/test_export.py @@ -22,9 +22,6 @@ w = [1, 1, 1, 0.5, 0.5, 0.5] y_degraded = [1, 1, 1, 1, 1, 1] -# constructors for feature names and class names -constructors = [list, np.array] - def test_graphviz_toy(): # Check correctness of export_graphviz @@ -53,6 +50,9 @@ def test_graphviz_toy(): assert contents1 == contents2 # Test with feature_names + contents1 = export_graphviz( + clf, feature_names=["feature0", "feature1"], out_file=None + ) contents2 = ( "digraph Tree {\n" 'node [shape=box, fontname="helvetica"] ;\n' @@ -68,13 +68,10 @@ def test_graphviz_toy(): "}" ) - for cons in constructors: - contents1 = export_graphviz( - clf, feature_names=cons(["feature0", "feature1"]), out_file=None - ) - assert contents1 == contents2 + assert contents1 == contents2 # Test with class_names + contents1 = export_graphviz(clf, class_names=["yes", "no"], out_file=None) contents2 = ( "digraph Tree {\n" 'node [shape=box, fontname="helvetica"] ;\n' @@ -92,9 +89,7 @@ def test_graphviz_toy(): "}" ) - for cons in constructors: - contents1 = export_graphviz(clf, class_names=cons(["yes", "no"]), out_file=None) - assert contents1 == contents2 + assert contents1 == contents2 # Test plot_options contents1 = export_graphviz( @@ -255,6 +250,57 @@ def test_graphviz_toy(): ) +def test_graphviz_array_support(): + # Check that export_graphviz supports feature names + # and class names as arrays + clf = DecisionTreeClassifier( + max_depth=3, min_samples_split=2, criterion="gini", random_state=2 + ) + clf.fit(X, y) + + # Test with feature_names + contents1 = export_graphviz( + clf, feature_names=np.array(["feature0", "feature1"]), out_file=None + ) + contents2 = ( + "digraph Tree {\n" + 'node [shape=box, fontname="helvetica"] ;\n' + 'edge [fontname="helvetica"] ;\n' + '0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n' + 'value = [3, 3]"] ;\n' + '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n' + "0 -> 1 [labeldistance=2.5, labelangle=45, " + 'headlabel="True"] ;\n' + '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n' + "0 -> 2 [labeldistance=2.5, labelangle=-45, " + 'headlabel="False"] ;\n' + "}" + ) + + assert contents1 == contents2 + + # Test with class_names + contents1 = export_graphviz(clf, class_names=np.array(["yes", "no"]), out_file=None) + contents2 = ( + "digraph Tree {\n" + 'node [shape=box, fontname="helvetica"] ;\n' + 'edge [fontname="helvetica"] ;\n' + '0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' + 'value = [3, 3]\\nclass = yes"] ;\n' + '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n' + 'class = yes"] ;\n' + "0 -> 1 [labeldistance=2.5, labelangle=45, " + 'headlabel="True"] ;\n' + '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]\\n' + 'class = no"] ;\n' + "0 -> 2 [labeldistance=2.5, labelangle=-45, " + 'headlabel="False"] ;\n' + "}" + ) + + assert contents1 == contents2 + + def test_graphviz_errors(): # Check for errors of export_graphviz clf = DecisionTreeClassifier(max_depth=3, min_samples_split=2) @@ -389,8 +435,7 @@ def test_export_text(): |--- b > 0.00 | |--- class: 1 """).lstrip() - for cons in constructors: - assert export_text(clf, feature_names=cons(["a", "b"])) == expected_report + assert export_text(clf, feature_names=["a", "b"]) == expected_report expected_report = dedent(""" |--- feature_1 <= 0.00 @@ -398,8 +443,7 @@ def test_export_text(): |--- feature_1 > 0.00 | |--- class: dog """).lstrip() - for cons in constructors: - assert export_text(clf, class_names=cons(["cat", "dog"])) == expected_report + assert export_text(clf, class_names=["cat", "dog"]) == expected_report expected_report = dedent(""" |--- feature_1 <= 0.00 @@ -461,6 +505,29 @@ def test_export_text(): ) +def test_export_text_graphviz_support(): + # Check that export_text supports feature names + # and class names as arrays + clf = DecisionTreeClassifier(max_depth=2, random_state=0) + clf.fit(X, y) + + expected_report = dedent(""" + |--- b <= 0.00 + | |--- class: -1 + |--- b > 0.00 + | |--- class: 1 + """).lstrip() + assert export_text(clf, feature_names=np.array(["a", "b"])) == expected_report + + expected_report = dedent(""" + |--- feature_1 <= 0.00 + | |--- class: cat + |--- feature_1 > 0.00 + | |--- class: dog + """).lstrip() + assert export_text(clf, class_names=np.array(["cat", "dog"])) == expected_report + + def test_plot_tree_entropy(pyplot): # mostly smoke tests # Check correctness of export_graphviz for criterion = entropy From 58ce3d9065849a2ed3a2b93f1f496f1e503b959f Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Fri, 5 May 2023 21:39:53 +0800 Subject: [PATCH 10/10] parametrized tests for array support, removed original tests only with list --- sklearn/tree/tests/test_export.py | 82 ++++++------------------------- 1 file changed, 14 insertions(+), 68 deletions(-) diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py index 001e63d631edc..a37c236b23def 100644 --- a/sklearn/tree/tests/test_export.py +++ b/sklearn/tree/tests/test_export.py @@ -49,48 +49,6 @@ def test_graphviz_toy(): assert contents1 == contents2 - # Test with feature_names - contents1 = export_graphviz( - clf, feature_names=["feature0", "feature1"], out_file=None - ) - contents2 = ( - "digraph Tree {\n" - 'node [shape=box, fontname="helvetica"] ;\n' - 'edge [fontname="helvetica"] ;\n' - '0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n' - 'value = [3, 3]"] ;\n' - '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n' - "0 -> 1 [labeldistance=2.5, labelangle=45, " - 'headlabel="True"] ;\n' - '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n' - "0 -> 2 [labeldistance=2.5, labelangle=-45, " - 'headlabel="False"] ;\n' - "}" - ) - - assert contents1 == contents2 - - # Test with class_names - contents1 = export_graphviz(clf, class_names=["yes", "no"], out_file=None) - contents2 = ( - "digraph Tree {\n" - 'node [shape=box, fontname="helvetica"] ;\n' - 'edge [fontname="helvetica"] ;\n' - '0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' - 'value = [3, 3]\\nclass = yes"] ;\n' - '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n' - 'class = yes"] ;\n' - "0 -> 1 [labeldistance=2.5, labelangle=45, " - 'headlabel="True"] ;\n' - '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]\\n' - 'class = no"] ;\n' - "0 -> 2 [labeldistance=2.5, labelangle=-45, " - 'headlabel="False"] ;\n' - "}" - ) - - assert contents1 == contents2 - # Test plot_options contents1 = export_graphviz( clf, @@ -250,9 +208,10 @@ def test_graphviz_toy(): ) -def test_graphviz_array_support(): - # Check that export_graphviz supports feature names - # and class names as arrays +@pytest.mark.parametrize("constructor", [list, np.array]) +def test_graphviz_feature_class_names_array_support(constructor): + # Check that export_graphviz treats feature names + # and class names correctly and supports arrays clf = DecisionTreeClassifier( max_depth=3, min_samples_split=2, criterion="gini", random_state=2 ) @@ -260,7 +219,7 @@ def test_graphviz_array_support(): # Test with feature_names contents1 = export_graphviz( - clf, feature_names=np.array(["feature0", "feature1"]), out_file=None + clf, feature_names=constructor(["feature0", "feature1"]), out_file=None ) contents2 = ( "digraph Tree {\n" @@ -280,7 +239,9 @@ def test_graphviz_array_support(): assert contents1 == contents2 # Test with class_names - contents1 = export_graphviz(clf, class_names=np.array(["yes", "no"]), out_file=None) + contents1 = export_graphviz( + clf, class_names=constructor(["yes", "no"]), out_file=None + ) contents2 = ( "digraph Tree {\n" 'node [shape=box, fontname="helvetica"] ;\n' @@ -429,22 +390,6 @@ def test_export_text(): # testing that the rest of the tree is truncated assert export_text(clf, max_depth=10) == expected_report - expected_report = dedent(""" - |--- b <= 0.00 - | |--- class: -1 - |--- b > 0.00 - | |--- class: 1 - """).lstrip() - assert export_text(clf, feature_names=["a", "b"]) == expected_report - - expected_report = dedent(""" - |--- feature_1 <= 0.00 - | |--- class: cat - |--- feature_1 > 0.00 - | |--- class: dog - """).lstrip() - assert export_text(clf, class_names=["cat", "dog"]) == expected_report - expected_report = dedent(""" |--- feature_1 <= 0.00 | |--- weights: [3.00, 0.00] class: -1 @@ -505,9 +450,10 @@ def test_export_text(): ) -def test_export_text_graphviz_support(): - # Check that export_text supports feature names - # and class names as arrays +@pytest.mark.parametrize("constructor", [list, np.array]) +def test_export_text_feature_class_names_array_support(constructor): + # Check that export_graphviz treats feature names + # and class names correctly and supports arrays clf = DecisionTreeClassifier(max_depth=2, random_state=0) clf.fit(X, y) @@ -517,7 +463,7 @@ def test_export_text_graphviz_support(): |--- b > 0.00 | |--- class: 1 """).lstrip() - assert export_text(clf, feature_names=np.array(["a", "b"])) == expected_report + assert export_text(clf, feature_names=constructor(["a", "b"])) == expected_report expected_report = dedent(""" |--- feature_1 <= 0.00 @@ -525,7 +471,7 @@ def test_export_text_graphviz_support(): |--- feature_1 > 0.00 | |--- class: dog """).lstrip() - assert export_text(clf, class_names=np.array(["cat", "dog"])) == expected_report + assert export_text(clf, class_names=constructor(["cat", "dog"])) == expected_report def test_plot_tree_entropy(pyplot):