Describe the bug
The March 20, 2023 commit "MAINT validate_params for plot_tree (#25882)" to the file
scikit-learn/sklearn/tree/_export.py
added parameter validation to the plot_tree function. That validation does not agree with the documentation in the docstring or on the website: it omits the bool option for class_names that the docstring describes. Passing a bool was previously permitted. Has bool been removed as a valid option, or is the validation missing it?
Validation added by the decorator:

```python
@validate_params(
    {
        ...
        "class_names": [list, None],  # bool is not accepted here
        ...
    }
)
```
Docstring entry for the same parameter:

```
class_names : list of str or bool, default=None
    Names of each of the target classes in ascending numerical order.
    Only relevant for classification and not supported for multi-output.
    If ``True``, shows a symbolic representation of the class name.
```
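To make the mismatch concrete, here is a minimal sketch that compares the two specifications of class_names. The helper names are hypothetical and are not part of scikit-learn. The sketch assumes that a type constraint such as `list` is checked with a plain `isinstance` test, which is consistent with the error message below ("must be an instance of 'list' or None").

```python
# Hypothetical helpers (not part of scikit-learn) that mimic the two
# specifications of `class_names`: the decorator constraint and the docstring.

def matches_decorator_constraint(value):
    """Mimic the `[list, None]` constraint from @validate_params."""
    return value is None or isinstance(value, list)


def matches_docstring(value):
    """Mimic 'list of str or bool, default=None' from the docstring."""
    if value is None or isinstance(value, bool):
        return True
    return isinstance(value, list) and all(isinstance(v, str) for v in value)


for candidate in (None, ["class_0", "class_1"], True):
    # Only `True` disagrees: the docstring allows it, the decorator rejects it
    print(repr(candidate),
          matches_decorator_constraint(candidate),
          matches_docstring(candidate))
```

Under these assumptions, None and a list of strings satisfy both specifications. True satisfies only the docstring, which is the case that fails in 1.3.0.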
Steps/Code to Reproduce
```python
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

SEED = 42

# Fit a small classification tree on the wine dataset
data = datasets.load_wine()
X = data.data
y = data.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=SEED)
dt = DecisionTreeClassifier(max_depth=4, random_state=SEED)
dt.fit(X_train, y_train)

features = data.feature_names
classes = data.target_names.tolist()

# Passing an explicit list of class names works
plot_tree(dt, feature_names=features, class_names=classes)
plt.show()

# Works in 1.2.2, error in 1.3.0
plot_tree(dt, feature_names=features, class_names=True)
plt.show()
```
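As a stopgap on 1.3.0, you can pass an explicit list, which the current validation accepts, instead of True. The sketch below rests on an assumption: that the symbolic labels drawn for class_names=True have the form y[0], y[1], and so on. Compare against a 1.2.2 plot before relying on it. The helper name is hypothetical.

```python
# Hypothetical workaround helper (not part of scikit-learn): build an explicit
# list of symbolic labels so that `class_names` passes the `list` constraint.
# Assumes the symbolic form is "y[<class index>]".

def symbolic_class_names(n_classes):
    """Return labels 'y[0]', 'y[1]', ... for n_classes classes."""
    return [f"y[{i}]" for i in range(n_classes)]


print(symbolic_class_names(3))  # ['y[0]', 'y[1]', 'y[2]']
```

With the fitted tree from the reproducer, this would be called as plot_tree(dt, feature_names=features, class_names=symbolic_class_names(dt.n_classes_)).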
Expected Results
No error is thrown.
Actual Results
```
Traceback (most recent call last):

  File ~\Anaconda3\envs\py311\Lib\site-packages\spyder_kernels\py3compat.py:356 in compat_exec
    exec(code, globals, locals)

  File c:\temp\decisiontree.py:26
    plot_tree(dt, feature_names=features, class_names=True)

  File ~\Anaconda3\envs\py311\Lib\site-packages\sklearn\utils\_param_validation.py:201 in wrapper
    validate_parameter_constraints(

  File ~\Anaconda3\envs\py311\Lib\site-packages\sklearn\utils\_param_validation.py:95 in validate_parameter_constraints
    raise InvalidParameterError(

InvalidParameterError: The 'class_names' parameter of plot_tree must be an instance of 'list' or None. Got True instead.
```
Versions
```python
import sklearn; sklearn.show_versions()
```

```
System:
    python: 3.9.17 (main, Jul 5 2023, 21:22:06) [MSC v.1916 64 bit (AMD64)]
    executable: C:\Users\lance.endres\Anaconda3\python.exe
    machine: Windows-10-10.0.19044-SP0

Python dependencies:
    sklearn: 1.2.2
    pip: 23.2.1
    setuptools: 68.0.0
    numpy: 1.21.5
    scipy: 1.10.1
    Cython: 0.29.32
    pandas: 1.5.3
    matplotlib: 3.7.1
    joblib: 1.2.0
    threadpoolctl: 2.2.0

Built with OpenMP: True

threadpoolctl info:
    filepath: C:\Users\lance.endres\Anaconda3\Library\bin\mkl_rt.1.dll
    prefix: mkl_rt
    user_api: blas
    internal_api: mkl
    version: 2021.4-Product
    num_threads: 6
    threading_layer: intel

    filepath: C:\Users\lance.endres\Anaconda3\vcomp140.dll
    prefix: vcomp
    user_api: openmp
    internal_api: openmp
    version: None
    num_threads: 12

    filepath: C:\Users\lance.endres\Anaconda3\Library\bin\libiomp5md.dll
    prefix: libiomp
    user_api: openmp
    internal_api: openmp
    version: None
    num_threads: 12
```