Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit ff1c6f3

Browse filesBrowse files
iofallarisayosh
andauthored
FIX Remove validation from __init__ and set_params for ColumnTransformer (#22537)
Co-authored-by: iofall <50991099+iofall@users.noreply.github.com> Co-authored-by: arisayosh <15692997+arisayosh@users.noreply.github.com>
1 parent 6ab950e commit ff1c6f3
Copy full SHA for ff1c6f3

File tree

Expand file treeCollapse file tree

4 files changed

+27
-11
lines changed
Filter options
Expand file treeCollapse file tree

4 files changed

+27
-11
lines changed

‎doc/whats_new/v1.1.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.1.rst
+7Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,13 @@ Changelog
171171
`-1` and the original warning message is shown.
172172
:pr:`22217` by :user:`Meekail Zain <micky774>`.
173173

174+
:mod:`sklearn.compose`
175+
......................
176+
177+
- |Fix| :class:`compose.ColumnTransformer` now removes validation errors from
178+
`__init__` and `set_params` methods.
179+
:pr:`22537` by :user:`iofall <iofall>` and :user:`Arisa Y. <arisayosh>`.
180+
174181
:mod:`sklearn.cross_decomposition`
175182
..................................
176183

‎sklearn/compose/_column_transformer.py

Copy file name to clipboardExpand all lines: sklearn/compose/_column_transformer.py
+11-5Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -222,14 +222,20 @@ def _transformers(self):
222222
of get_params via BaseComposition._get_params which expects lists
223223
of tuples of len 2.
224224
"""
225-
return [(name, trans) for name, trans, _ in self.transformers]
225+
try:
226+
return [(name, trans) for name, trans, _ in self.transformers]
227+
except (TypeError, ValueError):
228+
return self.transformers
226229

227230
@_transformers.setter
228231
def _transformers(self, value):
229-
self.transformers = [
230-
(name, trans, col)
231-
for ((name, trans), (_, _, col)) in zip(value, self.transformers)
232-
]
232+
try:
233+
self.transformers = [
234+
(name, trans, col)
235+
for ((name, trans), (_, _, col)) in zip(value, self.transformers)
236+
]
237+
except (TypeError, ValueError):
238+
self.transformers = value
233239

234240
def get_params(self, deep=True):
235241
"""Get parameters for this estimator.

‎sklearn/tests/test_common.py

Copy file name to clipboardExpand all lines: sklearn/tests/test_common.py
+1-2Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,6 @@ def test_transformers_get_feature_names_out(transformer):
413413

414414

415415
VALIDATE_ESTIMATOR_INIT = [
416-
"ColumnTransformer",
417416
"SGDOneClassSVM",
418417
"TheilSenRegressor",
419418
"TweedieRegressor",
@@ -436,7 +435,7 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
436435
if param.kind != Parameter.VAR_KEYWORD
437436
]
438437

439-
smoke_test_values = [-1, 3.0, "helloworld", np.array([1.0, 4.0]), {}, []]
438+
smoke_test_values = [-1, 3.0, "helloworld", np.array([1.0, 4.0]), [1], {}, []]
440439
for value in smoke_test_values:
441440
new_params = {key: value for key in params}
442441

‎sklearn/utils/metaestimators.py

Copy file name to clipboardExpand all lines: sklearn/utils/metaestimators.py
+8-4Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from operator import attrgetter
99
from functools import update_wrapper
1010
import numpy as np
11+
from contextlib import suppress
1112

1213
from ..utils import _safe_indexing
1314
from ..utils._tags import _safe_tags
@@ -56,10 +57,13 @@ def _set_params(self, attr, **params):
5657
items = getattr(self, attr)
5758
if isinstance(items, list) and items:
5859
# Get item names used to identify valid names in params
59-
item_names, _ = zip(*items)
60-
for name in list(params.keys()):
61-
if "__" not in name and name in item_names:
62-
self._replace_estimator(attr, name, params.pop(name))
60+
# `zip` raises a TypeError when `items` does not contains
61+
# elements of length 2
62+
with suppress(TypeError):
63+
item_names, _ = zip(*items)
64+
for name in list(params.keys()):
65+
if "__" not in name and name in item_names:
66+
self._replace_estimator(attr, name, params.pop(name))
6367

6468
# 3. Step parameters and other initialisation arguments
6569
super().set_params(**params)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.