Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit d95ecbb

Browse filesBrowse files
authored
FIX Keeps namedtuple's class when transform returns a tuple (#26121)
1 parent 379d54e commit d95ecbb
Copy full SHA for d95ecbb

File tree

3 files changed

+30
-1
lines changed
Filter options

3 files changed

+30
-1
lines changed

‎doc/whats_new/v1.3.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.3.rst
+3Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,9 @@ Changelog
166166
- |Feature| A `__sklearn_clone__` protocol is now available to override the
167167
default behavior of :func:`base.clone`. :pr:`24568` by `Thomas Fan`_.
168168

169+
- |Fix| :class:`base.TransformerMixin` now currently keeps a namedtuple's class
170+
if `transform` returns a namedtuple. :pr:`26121` by `Thomas Fan`_.
171+
169172
:mod:`sklearn.calibration`
170173
..........................
171174

‎sklearn/utils/_set_output.py

Copy file name to clipboardExpand all lines: sklearn/utils/_set_output.py
+6-1Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,15 @@ def wrapped(self, X, *args, **kwargs):
140140
data_to_wrap = f(self, X, *args, **kwargs)
141141
if isinstance(data_to_wrap, tuple):
142142
# only wrap the first output for cross decomposition
143-
return (
143+
return_tuple = (
144144
_wrap_data_with_container(method, data_to_wrap[0], X, self),
145145
*data_to_wrap[1:],
146146
)
147+
# Support for namedtuples `_make` is a documented API for namedtuples:
148+
# https://docs.python.org/3/library/collections.html#collections.somenamedtuple._make
149+
if hasattr(type(data_to_wrap), "_make"):
150+
return type(data_to_wrap)._make(return_tuple)
151+
return return_tuple
147152

148153
return _wrap_data_with_container(method, data_to_wrap, X, self)
149154

‎sklearn/utils/tests/test_set_output.py

Copy file name to clipboardExpand all lines: sklearn/utils/tests/test_set_output.py
+21Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
from collections import namedtuple
23

34
import numpy as np
45
from scipy.sparse import csr_matrix
@@ -292,3 +293,23 @@ def test_set_output_pandas_keep_index():
292293

293294
X_trans = est.transform(X)
294295
assert_array_equal(X_trans.index, ["s0", "s1"])
296+
297+
298+
class EstimatorReturnTuple(_SetOutputMixin):
299+
def __init__(self, OutputTuple):
300+
self.OutputTuple = OutputTuple
301+
302+
def transform(self, X, y=None):
303+
return self.OutputTuple(X, 2 * X)
304+
305+
306+
def test_set_output_named_tuple_out():
307+
"""Check that namedtuples are kept by default."""
308+
Output = namedtuple("Output", "X, Y")
309+
X = np.asarray([[1, 2, 3]])
310+
est = EstimatorReturnTuple(OutputTuple=Output)
311+
X_trans = est.transform(X)
312+
313+
assert isinstance(X_trans, Output)
314+
assert_array_equal(X_trans.X, X)
315+
assert_array_equal(X_trans.Y, 2 * X)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.