Description
Describe the bug
Hello, while using metadata routing I encountered what seems to be a bug. I do not understand metadata routing well enough to tell whether it is actually a bug or incorrect usage on my part.
Below is an example where I use a meta-estimator (BaggingRegressor) around a base estimator (DecisionTreeRegressor). In my use case, I need to dynamically wrap the base estimator in an Adapter that does some work before calling the fit method of the base estimator. This work depends on an extra parameter, extra_param, which I request using the set_fit_request method. The parameter is passed successfully, but its type is changed from string to list in one edge case: when the length of the string equals the number of samples in X.
Steps/Code to Reproduce
import numpy as np
import sklearn
from sklearn import base, ensemble, tree

sklearn.set_config(enable_metadata_routing=True)


class Adapter(base.BaseEstimator):
    def __init__(self, wrapped_estimator):
        self.wrapped_estimator = wrapped_estimator

    def fit(self, X, y, extra_param: str):
        # Do some work before delegating to the wrapped_estimator's fit method
        print(extra_param)
        assert isinstance(extra_param, str)
        return self.wrapped_estimator.fit(X, y)

    # Delegate other methods
    def __getattr__(self, name):
        return getattr(self.wrapped_estimator, name)


n, p = 10, 2
rng = np.random.default_rng(0)
x = rng.random((n, p))
y = rng.integers(0, 2, n)

estimator = tree.DecisionTreeRegressor()
adapter = Adapter(estimator)
adapter.set_fit_request(extra_param=True)
meta_estimator = ensemble.BaggingRegressor(adapter, n_estimators=1)

meta_estimator.fit(x, y, extra_param="a" * (n - 1)) # Pass
meta_estimator.fit(x, y, extra_param="a" * (n + 1)) # Pass
meta_estimator.fit(x, y, extra_param="a" * n) # Fail
Expected Results
No error is thrown. The extra_param string parameter passed to Adapter.fit should always be a string, and thus the assertion should not fail.
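Until the cause is confirmed, one possible stopgap is to hide the string's length from the meta-estimator by wrapping it in a small object. The sketch below assumes the conversion is triggered by the value having a length equal to the number of samples, which the reproduction suggests but does not prove. OpaqueParam and unwrap are hypothetical names, and this has not been tested against BaggingRegressor.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class OpaqueParam:
    """Hypothetical wrapper that deliberately defines no __len__."""

    value: str


def unwrap(extra_param):
    """Return the plain string, whether or not it arrived wrapped."""
    if isinstance(extra_param, OpaqueParam):
        return extra_param.value
    return extra_param


n = 10
wrapped = OpaqueParam("a" * n)
# The wrapper has no length, so a length-based check cannot match n.
print(hasattr(wrapped, "__len__"))
print(unwrap(wrapped) == "a" * n)
```

Under that assumption, the call would pass extra_param=OpaqueParam("a" * n), and Adapter.fit would call unwrap(extra_param) before doing its own work.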
Actual Results
When the string length matches the number of samples, the string becomes a list, and the assertion fails.
aaaaaaaaa
aaaaaaaaaaa
['a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a']
Traceback (most recent call last):
  File "minimal.py", line 35, in <module>
    meta_estimator.fit(x, y, extra_param="a" * n) # Fail
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../miniforge3/envs/test/lib/python3.12/site-packages/sklearn/utils/validation.py", line 63, in inner_f
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File ".../miniforge3/envs/test/lib/python3.12/site-packages/sklearn/base.py", line 1389, in wrapper
    return fit_method(estimator, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../miniforge3/envs/test/lib/python3.12/site-packages/sklearn/ensemble/_bagging.py", line 389, in fit
    return self._fit(X, y, max_samples=self.max_samples, **fit_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../miniforge3/envs/test/lib/python3.12/site-packages/sklearn/ensemble/_bagging.py", line 532, in _fit
    all_results = Parallel(
                  ^^^^^^^^^
  File ".../miniforge3/envs/test/lib/python3.12/site-packages/sklearn/utils/parallel.py", line 77, in __call__
    return super().__call__(iterable_with_config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../miniforge3/envs/test/lib/python3.12/site-packages/joblib/parallel.py", line 1918, in __call__
    return output if self.return_generator else list(output)
                                                ^^^^^^^^^^^^
  File ".../miniforge3/envs/test/lib/python3.12/site-packages/joblib/parallel.py", line 1847, in _get_sequential_output
    res = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File ".../miniforge3/envs/test/lib/python3.12/site-packages/sklearn/utils/parallel.py", line 139, in __call__
    return self.function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../miniforge3/envs/test/lib/python3.12/site-packages/sklearn/ensemble/_bagging.py", line 197, in _parallel_build_estimators
    estimator_fit(X_, y_, **fit_params_)
  File "minimal.py", line 15, in fit
    assert isinstance(extra_param, str)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
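A plausible mechanism, offered as an assumption rather than a confirmed diagnosis: the traceback ends in _parallel_build_estimators, where the fit parameters are prepared for each bootstrap sample. If that step treats any object with a length equal to the number of samples as per-sample metadata and indexes it row by row, a string of length n would come out as a list of characters. The pure-Python sketch below mimics such a heuristic without importing scikit-learn; looks_per_sample and route_metadata are hypothetical stand-ins, not scikit-learn functions.

```python
def looks_per_sample(value, n_samples):
    """Hypothetical heuristic: a sized object of length n_samples is per-sample."""
    return hasattr(value, "__len__") and len(value) == n_samples


def route_metadata(value, indices, n_samples):
    """Hypothetical stand-in for how a meta-estimator might subset metadata."""
    if looks_per_sample(value, n_samples):
        # Indexing a str element by element yields a list of characters.
        return [value[i] for i in indices]
    return value


n = 10
indices = list(range(n))
shorter = route_metadata("a" * (n - 1), indices, n)
longer = route_metadata("a" * (n + 1), indices, n)
exact = route_metadata("a" * n, indices, n)

print(type(shorter).__name__, type(longer).__name__, type(exact).__name__)
# str str list
```

This matches the three observed outcomes: the strings of length n - 1 and n + 1 pass through unchanged, while the string of length n is converted.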
Versions
System:
    python: 3.12.8 | packaged by conda-forge | (main, Dec 5 2024, 14:19:53) [Clang 18.1.8 ]
    executable: /Users/alexandreperez/dev/lib/miniforge3/envs/fun-ltm1/bin/python
    machine: macOS-15.2-arm64-arm-64bit

Python dependencies:
    sklearn: 1.6.1
    pip: 24.3.1
    setuptools: 75.8.0
    numpy: 1.26.4
    scipy: 1.15.1
    Cython: None
    pandas: 2.2.3
    matplotlib: 3.10.0
    joblib: 1.4.2
    threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
    user_api: blas
    internal_api: openblas
    num_threads: 14
    prefix: libopenblas
    filepath: /Users/alexandreperez/dev/lib/miniforge3/envs/fun-ltm1/lib/python3.12/site-packages/numpy/.dylibs/libopenblas64_.0.dylib
    version: 0.3.23.dev
    threading_layer: pthreads
    architecture: armv8

    user_api: openmp
    internal_api: openmp
    num_threads: 14
    prefix: libomp
    filepath: /Users/alexandreperez/dev/lib/miniforge3/envs/fun-ltm1/lib/python3.12/site-packages/sklearn/.dylibs/libomp.dylib
    version: None