Edge case bug in metadata routing #30739

Open
@aperezlebel

Describe the bug

Hello, while using metadata routing I encountered what seems to be a bug. I don't understand metadata routing well enough to tell whether it is actually a bug or incorrect usage on my part.

Below is an example where I use a meta-estimator (BaggingRegressor) around a base estimator (DecisionTreeRegressor). In my use case, I need to dynamically wrap the base estimator in an Adapter that does some work before calling the base estimator's fit method. This work depends on an extra parameter, extra_param, which I request using the set_fit_request method. The parameter is passed successfully, but in one edge case (when the length of the string equals the number of samples in X) its type is changed from a string to a list.
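The behaviour looks consistent with the meta-estimator treating any fit parameter whose length equals the number of samples as per-sample data and resampling it with the bootstrap indices. A str has a length and supports integer indexing, so a string of exactly n characters would be sliced character by character into a list. The pure-Python sketch below illustrates such a length-based heuristic under that assumption; slice_fit_params is a hypothetical helper, not scikit-learn's actual implementation.

```python
# Hypothetical sketch of a length-based slicing heuristic. It mimics the
# observed behaviour but is NOT scikit-learn's code.
def slice_fit_params(n_samples, indices, **fit_params):
    sliced = {}
    for key, value in fit_params.items():
        # A str has __len__, so a check that only compares len(value)
        # with n_samples cannot tell it apart from per-sample data.
        if hasattr(value, "__len__") and len(value) == n_samples:
            # Element-wise indexing turns the string into a list of characters.
            sliced[key] = [value[i] for i in indices]
        else:
            # Anything whose length does not match is passed through untouched.
            sliced[key] = value
    return sliced


n = 10
indices = list(range(n))  # identity indices for clarity; bootstrap indices may repeat

print(slice_fit_params(n, indices, extra_param="a" * (n - 1)))
# {'extra_param': 'aaaaaaaaa'}
print(slice_fit_params(n, indices, extra_param="a" * n))
# {'extra_param': ['a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a']}
```

With bootstrap resampling the indices can repeat and be reordered, so for a general string (not all "a") the resulting list would not even preserve the original characters in order.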

Steps/Code to Reproduce

import numpy as np
import sklearn
from sklearn import base, ensemble, tree

sklearn.set_config(enable_metadata_routing=True)


class Adapter(base.BaseEstimator):
    def __init__(self, wrapped_estimator):
        self.wrapped_estimator = wrapped_estimator

    def fit(self, X, y, extra_param: str):
        # Do some work before delegating to the wrapped_estimator's fit method
        print(extra_param)
        assert isinstance(extra_param, str)
        return self.wrapped_estimator.fit(X, y)

    # Delegate other methods
    def __getattr__(self, name):
        return getattr(self.wrapped_estimator, name)


n, p = 10, 2
rng = np.random.default_rng(0)
x = rng.random((n, p))
y = rng.integers(0, 2, n)

estimator = tree.DecisionTreeRegressor()
adapter = Adapter(estimator)
adapter.set_fit_request(extra_param=True)
meta_estimator = ensemble.BaggingRegressor(adapter, n_estimators=1)

meta_estimator.fit(x, y, extra_param="a" * (n - 1))  # Pass
meta_estimator.fit(x, y, extra_param="a" * (n + 1))  # Pass
meta_estimator.fit(x, y, extra_param="a" * n)  # Fail

Expected Results

No error is raised: extra_param should always reach Adapter.fit as a string, so the assertion should never fail.

Actual Results

When the length of the string equals the number of samples, the string arrives as a list of single characters and the assertion fails:

aaaaaaaaa
aaaaaaaaaaa
['a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a']

Traceback (most recent call last):
  File "minimal.py", line 35, in <module>
    meta_estimator.fit(x, y, extra_param="a" * n)  # Fail
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../miniforge3/envs/test/lib/python3.12/site-packages/sklearn/utils/validation.py", line 63, in inner_f
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File ".../miniforge3/envs/test/lib/python3.12/site-packages/sklearn/base.py", line 1389, in wrapper
    return fit_method(estimator, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../miniforge3/envs/test/lib/python3.12/site-packages/sklearn/ensemble/_bagging.py", line 389, in fit
    return self._fit(X, y, max_samples=self.max_samples, **fit_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../miniforge3/envs/test/lib/python3.12/site-packages/sklearn/ensemble/_bagging.py", line 532, in _fit
    all_results = Parallel(
                  ^^^^^^^^^
  File ".../miniforge3/envs/test/lib/python3.12/site-packages/sklearn/utils/parallel.py", line 77, in __call__
    return super().__call__(iterable_with_config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../miniforge3/envs/test/lib/python3.12/site-packages/joblib/parallel.py", line 1918, in __call__
    return output if self.return_generator else list(output)
                                                ^^^^^^^^^^^^
  File ".../miniforge3/envs/test/lib/python3.12/site-packages/joblib/parallel.py", line 1847, in _get_sequential_output
    res = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File ".../miniforge3/envs/test/lib/python3.12/site-packages/sklearn/utils/parallel.py", line 139, in __call__
    return self.function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../miniforge3/envs/test/lib/python3.12/site-packages/sklearn/ensemble/_bagging.py", line 197, in _parallel_build_estimators
    estimator_fit(X_, y_, **fit_params_)
  File "minimal.py", line 15, in fit
    assert isinstance(extra_param, str)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
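A possible user-side workaround, sketched under the assumption that fit parameters without a length are passed through unchanged: wrap the string in a small object that has no __len__, and unwrap it inside Adapter.fit. Opaque, slice_fit_params and adapter_fit are hypothetical names for illustration; slice_fit_params only stands in for the meta-estimator's slicing and is not scikit-learn code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Opaque:
    # Hypothetical wrapper: it defines no __len__, so a check based on
    # len(value) == n_samples can never mistake it for per-sample data.
    value: str


def slice_fit_params(n_samples, indices, **fit_params):
    # Hypothetical stand-in for a meta-estimator's length-based slicing.
    sliced = {}
    for key, value in fit_params.items():
        if hasattr(value, "__len__") and len(value) == n_samples:
            sliced[key] = [value[i] for i in indices]
        else:
            sliced[key] = value
    return sliced


def adapter_fit(extra_param):
    # What Adapter.fit would do first: unwrap so the rest of the code sees a str.
    if isinstance(extra_param, Opaque):
        extra_param = extra_param.value
    assert isinstance(extra_param, str)
    return extra_param


n = 10
params = slice_fit_params(n, list(range(n)), extra_param=Opaque("a" * n))
print(adapter_fit(**params))  # aaaaaaaaaa
```

Applied to the reproducer, this would mean passing extra_param=Opaque("a" * n) to meta_estimator.fit and unwrapping it at the top of Adapter.fit. The trade-off is that Adapter.fit has to know about the wrapper type.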

Versions

System:
    python: 3.12.8 | packaged by conda-forge | (main, Dec  5 2024, 14:19:53) [Clang 18.1.8 ]
executable: /Users/alexandreperez/dev/lib/miniforge3/envs/fun-ltm1/bin/python
   machine: macOS-15.2-arm64-arm-64bit

Python dependencies:
      sklearn: 1.6.1
          pip: 24.3.1
   setuptools: 75.8.0
        numpy: 1.26.4
        scipy: 1.15.1
       Cython: None
       pandas: 2.2.3
   matplotlib: 3.10.0
       joblib: 1.4.2
threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 14
         prefix: libopenblas
       filepath: /Users/alexandreperez/dev/lib/miniforge3/envs/fun-ltm1/lib/python3.12/site-packages/numpy/.dylibs/libopenblas64_.0.dylib
        version: 0.3.23.dev
threading_layer: pthreads
   architecture: armv8

       user_api: openmp
   internal_api: openmp
    num_threads: 14
         prefix: libomp
       filepath: /Users/alexandreperez/dev/lib/miniforge3/envs/fun-ltm1/lib/python3.12/site-packages/sklearn/.dylibs/libomp.dylib
        version: None

Labels: Bug, Documentation, Metadata Routing, wontfix
