
SequentialFeatureSelector fails on text features even though the estimator supports them #30785

Open
@sktin


Describe the bug

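When the estimator can handle a given data type (be it text/categorical values or NaN), SequentialFeatureSelector appears to run its own input validation, ignoring what the estimator supports and effectively insisting that every feature be numeric. cross_val_score works with the same estimator and data, so it is SequentialFeatureSelector that rejects the data. The traceback below points to the likely cause: candidate subsets are built with NumPy indexing (X_new = X[:, candidate_mask]). This suggests the DataFrame has already been converted to a plain array at that point, so the category dtype of the sex column is lost and XGBoost receives the raw strings 'M' and 'F'.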
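A minimal pandas/NumPy sketch of the suspected mechanism, assuming SequentialFeatureSelector's conversion behaves like np.asarray on the DataFrame (this is an assumption based on the X[:, candidate_mask] line in the traceback, not a reading of scikit-learn's source). The toy frame stands in for the diabetes data; the column values are made up for illustration.

```python
import numpy as np
import pandas as pd

# Toy stand-in for the diabetes frame: one float column, one categorical column.
X = pd.DataFrame(
    {
        "age": [59.0, 48.0, 72.0],
        "sex": pd.Categorical(["F", "M", "F"]),
    }
)

# A plain ndarray has a single dtype, so the float and categorical columns
# are merged into one object array holding both floats and strings.
arr = np.asarray(X)

# Positional column selection, as in X[:, candidate_mask], keeps the raw strings
# but no longer carries any "category" information.
candidate_mask = np.array([False, True])
X_new = arr[:, candidate_mask]

# A consumer that casts to float (XGBoost's _ensure_np_dtype does data.astype(...))
# then fails on the string values.
try:
    X_new.astype(np.float64)
    message = None
except ValueError as exc:
    message = str(exc)

print(arr.dtype, X_new.shape, message)
```

The error text produced here has the same form as the one in the Actual Results, which is consistent with (though not proof of) the array conversion being the culprit.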

Steps/Code to Reproduce

from sklearn.datasets import load_diabetes
from sklearn.feature_selection import SequentialFeatureSelector
from xgboost import XGBRegressor
from sklearn.model_selection import cross_val_score

import sklearn; print(F'{sklearn.__version__=}')
import xgboost; print(F'{xgboost.__version__=}')

X, y = load_diabetes(return_X_y=True, as_frame=True, scaled=False)
X['sex'] = X['sex'].apply(lambda x: 'M' if x==1.0 else 'F').astype('category')
model = XGBRegressor(enable_categorical=True, random_state=0)
print('Testing cross_val_score begins')
cross_val_score(model, X, y, error_score='raise') # no error
print('Testing cross_val_score ends')
print('Testing SequentialFeatureSelector begins')
SequentialFeatureSelector(model, tol=0).fit(X, y)
print('Testing SequentialFeatureSelector ends')

Expected Results

sklearn.__version__='1.6.1'
xgboost.__version__='2.1.4'
Testing cross_val_score begins
Testing cross_val_score ends
Testing SequentialFeatureSelector begins
Testing SequentialFeatureSelector ends

(No errors)
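For comparison, the expected results amount to forward selection that scores DataFrame column subsets directly, the way the working cross_val_score call receives the full DataFrame. The sketch below assumes that slicing with X[cols] is enough to keep dtypes such as category intact. forward_select_columns is a hypothetical helper, not a scikit-learn API, and unlike SequentialFeatureSelector it has no tol/"auto" stopping rule and no backward direction.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score


def forward_select_columns(estimator, X, y, n_features_to_select, cv=5, scoring=None):
    """Hypothetical greedy forward selection over DataFrame columns.

    Each candidate subset is taken with X[cols], which keeps pandas dtypes
    (e.g. "category"), so estimators that understand them still see them.
    cross_val_score clones the estimator internally for every fold.
    """
    selected = []
    remaining = list(X.columns)
    while len(selected) < n_features_to_select and remaining:
        # Mean CV score of the current selection extended by each remaining column.
        scores = {
            col: cross_val_score(
                estimator, X[selected + [col]], y, cv=cv, scoring=scoring
            ).mean()
            for col in remaining
        }
        best = max(scores, key=scores.get)
        selected.append(best)
        remaining.remove(best)
    return selected


# Toy usage: the target depends only on column "a".
rng = np.random.RandomState(0)
X_demo = pd.DataFrame(rng.rand(50, 3), columns=["a", "b", "c"])
y_demo = 3.0 * X_demo["a"]
print(forward_select_columns(LinearRegression(), X_demo, y_demo, n_features_to_select=1))
```

With the report's objects this would be called as forward_select_columns(model, X, y, n_features_to_select=5); whether XGBoost then trains without error on every subset is an expectation, not something shown here.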

Actual Results

sklearn.__version__='1.6.1'
xgboost.__version__='2.1.4'
Testing cross_val_score begins
Testing cross_val_score ends
Testing SequentialFeatureSelector begins

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-29-fb1642c5f9e7> in <cell line: 16>()
     14 print('Testing cross_val_score ends')
     15 print('Testing SequentialFeatureSelector begins')
---> 16 SequentialFeatureSelector(model, tol=0).fit(X, y)
     17 print('Testing SequentialFeatureSelector ends')

/usr/local/lib/python3.10/dist-packages/sklearn/base.py in wrapper(estimator, *args, **kwargs)
   1387                 )
   1388             ):
-> 1389                 return fit_method(estimator, *args, **kwargs)
   1390 
   1391         return wrapper

/usr/local/lib/python3.10/dist-packages/sklearn/feature_selection/_sequential.py in fit(self, X, y, **params)
    280             process_routing(self, "fit", **params)
    281         for _ in range(n_iterations):
--> 282             new_feature_idx, new_score = self._get_best_new_feature_score(
    283                 cloned_estimator, X, y, cv, current_mask, **params
    284             )

/usr/local/lib/python3.10/dist-packages/sklearn/feature_selection/_sequential.py in _get_best_new_feature_score(self, estimator, X, y, cv, current_mask, **params)
    311                 candidate_mask = ~candidate_mask
    312             X_new = X[:, candidate_mask]
--> 313             scores[feature_idx] = cross_val_score(
    314                 estimator,
    315                 X_new,

/usr/local/lib/python3.10/dist-packages/sklearn/utils/_param_validation.py in wrapper(*args, **kwargs)
    214                     )
    215                 ):
--> 216                     return func(*args, **kwargs)
    217             except InvalidParameterError as e:
    218                 # When the function is just a wrapper around an estimator, we allow

/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_validation.py in cross_val_score(estimator, X, y, groups, scoring, cv, n_jobs, verbose, params, pre_dispatch, error_score)
    682     scorer = check_scoring(estimator, scoring=scoring)
    683 
--> 684     cv_results = cross_validate(
    685         estimator=estimator,
    686         X=X,

/usr/local/lib/python3.10/dist-packages/sklearn/utils/_param_validation.py in wrapper(*args, **kwargs)
    214                     )
    215                 ):
--> 216                     return func(*args, **kwargs)
    217             except InvalidParameterError as e:
    218                 # When the function is just a wrapper around an estimator, we allow

/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_validation.py in cross_validate(estimator, X, y, groups, scoring, cv, n_jobs, verbose, params, pre_dispatch, return_train_score, return_estimator, return_indices, error_score)
    429     )
    430 
--> 431     _warn_or_raise_about_fit_failures(results, error_score)
    432 
    433     # For callable scoring, the return type is only know after calling. If the

/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_validation.py in _warn_or_raise_about_fit_failures(results, error_score)
    515                 f"Below are more details about the failures:\n{fit_errors_summary}"
    516             )
--> 517             raise ValueError(all_fits_failed_message)
    518 
    519         else:

ValueError: 
All the 5 fits failed.
It is very likely that your model is misconfigured.
You can try to debug the error by setting error_score='raise'.

Below are more details about the failures:
--------------------------------------------------------------------------------
1 fits failed with the following error:
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_validation.py", line 866, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "/usr/local/lib/python3.10/dist-packages/xgboost/core.py", line 726, in inner_f
    return func(**kwargs)
  File "/usr/local/lib/python3.10/dist-packages/xgboost/sklearn.py", line 1143, in fit
    train_dmatrix, evals = _wrap_evaluation_matrices(
  File "/usr/local/lib/python3.10/dist-packages/xgboost/sklearn.py", line 603, in _wrap_evaluation_matrices
    train_dmatrix = create_dmatrix(
  File "/usr/local/lib/python3.10/dist-packages/xgboost/sklearn.py", line 1065, in _create_dmatrix
    return QuantileDMatrix(
  File "/usr/local/lib/python3.10/dist-packages/xgboost/core.py", line 726, in inner_f
    return func(**kwargs)
  File "/usr/local/lib/python3.10/dist-packages/xgboost/core.py", line 1573, in __init__
    self._init(
  File "/usr/local/lib/python3.10/dist-packages/xgboost/core.py", line 1632, in _init
    it.reraise()
  File "/usr/local/lib/python3.10/dist-packages/xgboost/core.py", line 569, in reraise
    raise exc  # pylint: disable=raising-bad-type
  File "/usr/local/lib/python3.10/dist-packages/xgboost/core.py", line 550, in _handle_exception
    return fn()
  File "/usr/local/lib/python3.10/dist-packages/xgboost/core.py", line 637, in <lambda>
    return self._handle_exception(lambda: self.next(input_data), 0)
  File "/usr/local/lib/python3.10/dist-packages/xgboost/data.py", line 1402, in next
    input_data(**self.kwargs)
  File "/usr/local/lib/python3.10/dist-packages/xgboost/core.py", line 726, in inner_f
    return func(**kwargs)
  File "/usr/local/lib/python3.10/dist-packages/xgboost/core.py", line 617, in input_data
    new, cat_codes, feature_names, feature_types = _proxy_transform(
  File "/usr/local/lib/python3.10/dist-packages/xgboost/data.py", line 1429, in _proxy_transform
    data, _ = _ensure_np_dtype(data, data.dtype)
  File "/usr/local/lib/python3.10/dist-packages/xgboost/data.py", line 224, in _ensure_np_dtype
    data = data.astype(dtype, copy=False)
ValueError: could not convert string to float: 'M'

--------------------------------------------------------------------------------
4 fits failed with the following error:
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_validation.py", line 866, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "/usr/local/lib/python3.10/dist-packages/xgboost/core.py", line 726, in inner_f
    return func(**kwargs)
  File "/usr/local/lib/python3.10/dist-packages/xgboost/sklearn.py", line 1143, in fit
    train_dmatrix, evals = _wrap_evaluation_matrices(
  File "/usr/local/lib/python3.10/dist-packages/xgboost/sklearn.py", line 603, in _wrap_evaluation_matrices
    train_dmatrix = create_dmatrix(
  File "/usr/local/lib/python3.10/dist-packages/xgboost/sklearn.py", line 1065, in _create_dmatrix
    return QuantileDMatrix(
  File "/usr/local/lib/python3.10/dist-packages/xgboost/core.py", line 726, in inner_f
    return func(**kwargs)
  File "/usr/local/lib/python3.10/dist-packages/xgboost/core.py", line 1573, in __init__
    self._init(
  File "/usr/local/lib/python3.10/dist-packages/xgboost/core.py", line 1632, in _init
    it.reraise()
  File "/usr/local/lib/python3.10/dist-packages/xgboost/core.py", line 569, in reraise
    raise exc  # pylint: disable=raising-bad-type
  File "/usr/local/lib/python3.10/dist-packages/xgboost/core.py", line 550, in _handle_exception
    return fn()
  File "/usr/local/lib/python3.10/dist-packages/xgboost/core.py", line 637, in <lambda>
    return self._handle_exception(lambda: self.next(input_data), 0)
  File "/usr/local/lib/python3.10/dist-packages/xgboost/data.py", line 1402, in next
    input_data(**self.kwargs)
  File "/usr/local/lib/python3.10/dist-packages/xgboost/core.py", line 726, in inner_f
    return func(**kwargs)
  File "/usr/local/lib/python3.10/dist-packages/xgboost/core.py", line 617, in input_data
    new, cat_codes, feature_names, feature_types = _proxy_transform(
  File "/usr/local/lib/python3.10/dist-packages/xgboost/data.py", line 1429, in _proxy_transform
    data, _ = _ensure_np_dtype(data, data.dtype)
  File "/usr/local/lib/python3.10/dist-packages/xgboost/data.py", line 224, in _ensure_np_dtype
    data = data.astype(dtype, copy=False)
ValueError: could not convert string to float: 'F'
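One way to check, without XGBoost, what type of X each code path hands to the estimator: a hypothetical diagnostic regressor that records type(X) in fit. InputTypeRecorder and input_types_seen are illustrative helpers, not scikit-learn APIs, and the sketch assumes that what fit receives is representative of what XGBoost receives. If SequentialFeatureSelector converts X to an array before cross-validating candidates, the second set will contain only ndarray, while cross_val_score should report DataFrame.

```python
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.feature_selection import SequentialFeatureSelector
from sklearn.model_selection import cross_val_score


class InputTypeRecorder(RegressorMixin, BaseEstimator):
    """Hypothetical diagnostic regressor: predicts the training mean and
    logs the type name of X passed to fit. The log is a class attribute so
    that the clones made during cross-validation all write to one list."""

    seen_types = []

    def fit(self, X, y):
        InputTypeRecorder.seen_types.append(type(X).__name__)
        self.mean_ = float(np.mean(y))
        return self

    def predict(self, X):
        return np.full(len(X), self.mean_)


def input_types_seen(run):
    """Clear the log, call run(), and return the set of X type names seen in fit."""
    InputTypeRecorder.seen_types.clear()
    run()
    return set(InputTypeRecorder.seen_types)


rng = np.random.RandomState(0)
X = pd.DataFrame(rng.rand(20, 3), columns=["a", "b", "c"])
y = rng.rand(20)

cv_types = input_types_seen(lambda: cross_val_score(InputTypeRecorder(), X, y, cv=2))
sfs_types = input_types_seen(
    lambda: SequentialFeatureSelector(
        InputTypeRecorder(), n_features_to_select=1, cv=2
    ).fit(X, y)
)
print("cross_val_score passed:", cv_types)
print("SequentialFeatureSelector passed:", sfs_types)
```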

Versions

Python dependencies:
      sklearn: 1.6.1
          pip: 24.1.2
   setuptools: 75.1.0
        numpy: 1.26.4
        scipy: 1.13.1
       Cython: 3.0.11
       pandas: 2.2.2
   matplotlib: 3.7.5
       joblib: 1.4.2
threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: mkl
    num_threads: 2
         prefix: libmkl_rt
       filepath: /usr/local/lib/libmkl_rt.so.2
        version: 2025.0.1-Product
threading_layer: gnu

       user_api: blas
   internal_api: mkl
    num_threads: 2
         prefix: libmkl_rt
       filepath: /usr/local/lib/python3.10/dist-packages/mkl_fft.libs/libmkl_rt-089e6a60.so.2
        version: 2025.0.1-Product
threading_layer: not specified

       user_api: openmp
   internal_api: openmp
    num_threads: 4
         prefix: libgomp
       filepath: /usr/lib/x86_64-linux-gnu/libgomp.so.1.0.0
        version: None

       user_api: blas
   internal_api: openblas
    num_threads: 4
         prefix: libopenblas
       filepath: /usr/local/lib/python3.10/dist-packages/scipy.libs/libopenblasp-r0-01191904.3.27.so
        version: 0.3.27
threading_layer: pthreads
   architecture: Haswell

       user_api: openmp
   internal_api: openmp
    num_threads: 4
         prefix: libgomp
       filepath: /usr/local/lib/python3.10/dist-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
        version: None

       user_api: openmp
   internal_api: openmp
    num_threads: 4
         prefix: libgomp
       filepath: /usr/local/lib/python3.10/dist-packages/xgboost.libs/libgomp-24e2ab19.so.1.0.0
        version: None
