Describe the bug
When using `GridSearchCV` with a custom estimator whose parameter grid contains nested values (dictionaries as candidate values for a parameter), scikit-learn 1.5.0 raises `ValueError: entry not a 2- or 3- tuple`. The same grid search completes successfully in scikit-learn 1.4.0.
Steps/Code to Reproduce
```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.base import BaseEstimator, ClassifierMixin


class SimpleEstimator(BaseEstimator, ClassifierMixin):
    def __init__(self, base_clf, param1=None, param2=True):
        self.base_clf = base_clf
        self.param1 = param1
        self.param2 = param2

    def fit(self, X, y=None):
        # Simulate using the parameters in the fitting process
        if self.param1:
            pass  # Simulate using param1
        if self.param2:
            pass  # Simulate using param2
        self.base_clf.fit(X, y)
        return self

    def predict(self, X):
        return self.base_clf.predict(X)

    def score(self, X, y):
        return self.base_clf.score(X, y)


def test_gridsearchcv_with_custom_estimator():
    # param1 takes dictionaries as candidate values; this triggers the error in 1.5.0
    param_grid = {
        "param1": [None, {"option": "A"}, {"option": "B"}],
        "param2": [True, False],
    }
    base_clf = LogisticRegression()
    grid_search = GridSearchCV(
        estimator=SimpleEstimator(base_clf),
        param_grid=param_grid,
        cv=3,
    )
    X_train = np.random.rand(20, 2)
    y_train = np.random.randint(0, 2, 20)
    grid_search.fit(X_train, y_train)
    print("Best params:", grid_search.best_params_)
    print("Best score:", grid_search.best_score_)


test_gridsearchcv_with_custom_estimator()
```
Expected Results
`GridSearchCV` should complete without errors, evaluating every combination of the parameters in `param_grid`. Example output (the exact values depend on the unseeded random training data):

```
Best params: {'param1': None, 'param2': True}
Best score: 0.5
```
Actual Results
In scikit-learn 1.5.0, the following error occurs:
```
Traceback (most recent call last):
  File "/workspaces/mwe.py", line 49, in <module>
    test_gridsearchcv_with_custom_estimator()
  File "/workspaces/mwe.py", line 45, in test_gridsearchcv_with_custom_estimator
    grid_search.fit(X_train, y_train)
  File "/home/vscode/.local/lib/python3.11/site-packages/sklearn/base.py", line 1473, in wrapper
    return fit_method(estimator, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/sklearn/model_selection/_search.py", line 968, in fit
    self._run_search(evaluate_candidates)
  File "/home/vscode/.local/lib/python3.11/site-packages/sklearn/model_selection/_search.py", line 1543, in _run_search
    evaluate_candidates(ParameterGrid(self.param_grid))
  File "/home/vscode/.local/lib/python3.11/site-packages/sklearn/model_selection/_search.py", line 962, in evaluate_candidates
    results = self._format_results(
              ^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/sklearn/model_selection/_search.py", line 1092, in _format_results
    arr_dtype = np.result_type(*param_list)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<__array_function__ internals>", line 200, in result_type
  File "/home/vscode/.local/lib/python3.11/site-packages/numpy/core/_internal.py", line 61, in _usefields
    names, formats, offsets, titles = _makenames_list(adict, align)
                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/numpy/core/_internal.py", line 31, in _makenames_list
    raise ValueError("entry not a 2- or 3- tuple")
ValueError: entry not a 2- or 3- tuple
```
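The last frames show that the failure comes from NumPy, not from the estimator: `_format_results` calls `np.result_type(*param_list)` on the candidate values of a parameter, and NumPy interprets a `dict` argument as a structured dtype description. A minimal sketch that reproduces the same `ValueError` with NumPy alone, assuming NumPy converts non-numeric arguments of `np.result_type` to dtypes as the traceback indicates (`param1_values` is a hypothetical name mirroring the `param1` grid above):

```python
import numpy as np

# Candidate values of "param1" from the reproducer's param_grid.
param1_values = [None, {"option": "A"}, {"option": "B"}]

try:
    # NumPy reads a dict argument as a dtype description of the form
    # {field_name: (dtype, offset[, title])}; the value "A" is not such a
    # tuple, so dtype construction fails with a ValueError.
    np.result_type(*param1_values)
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(error_message)
```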
In scikit-learn 1.4.0, the grid search completes successfully:

```
Best params: {'param1': None, 'param2': True}
Best score: 0.5
```
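For illustration only, a defensive variant of the dtype inference shown in the traceback could catch the NumPy error and fall back to `object` dtype, which can hold any Python value. `infer_param_dtype` below is a hypothetical helper, not scikit-learn code and not necessarily how the regression is fixed upstream:

```python
import numpy as np


def infer_param_dtype(values):
    """Hypothetical helper (not scikit-learn API): choose a dtype for one
    parameter's candidate values, falling back to object dtype when NumPy
    cannot interpret the values (e.g. dicts or arbitrary strings)."""
    try:
        return np.result_type(*values)
    except (TypeError, ValueError):
        # Dicts raise ValueError; unrecognised strings such as "rbf" raise
        # TypeError because NumPy tries to parse them as dtype names.
        return np.dtype(object)


param_grid = {
    "param1": [None, {"option": "A"}, {"option": "B"}],
    "param2": [True, False],
}
for name, values in param_grid.items():
    print(name, infer_param_dtype(values))
```

Catching `TypeError` as well matters because plain string values are also passed through dtype parsing and are generally not valid dtype names.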
Versions
```
System:
    python: 3.11.7 (main, Dec 19 2023, 20:33:49) [GCC 12.2.0]
executable: /usr/local/bin/python
   machine: Linux-5.15.146.1-microsoft-standard-WSL2-x86_64-with-glibc2.36

Python dependencies:
      sklearn: 1.5.0
          pip: 23.2.1
   setuptools: 69.0.3
        numpy: 1.24.3
        scipy: 1.12.0
       Cython: None
       pandas: 2.2.0
   matplotlib: 3.7.4
       joblib: 1.3.2
threadpoolctl: 3.2.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 32
         prefix: libopenblas
       filepath: /home/vscode/.local/lib/python3.11/site-packages/numpy.libs/libopenblas64_p-r0-15028c96.3.21.so
        version: 0.3.21
threading_layer: pthreads
   architecture: Zen

       user_api: blas
   internal_api: openblas
    num_threads: 32
         prefix: libopenblas
       filepath: /home/vscode/.local/lib/python3.11/site-packages/scipy.libs/libopenblasp-r0-23e5df77.3.21.dev.so
        version: 0.3.21.dev
threading_layer: pthreads
   architecture: Zen

       user_api: openmp
   internal_api: openmp
    num_threads: 32
         prefix: libgomp
       filepath: /home/vscode/.local/lib/python3.11/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
        version: None
```