Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit ea03b3e

Browse filesBrowse files
mario-at-intercomthomasjpfanlorentzenchrglemaitrecmarmo
committed
ENH Add feature_name_combiner to OneHotEncoder (scikit-learn#22506)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: Chiara Marmo <cmarmo@users.noreply.github.com>
1 parent 4af933b commit ea03b3e
Copy full SHA for ea03b3e

File tree

Expand file treeCollapse file tree

3 files changed

+77
-1
lines changed
Filter options
Expand file treeCollapse file tree

3 files changed

+77
-1
lines changed

‎doc/whats_new/v1.3.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.3.rst
+7Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,13 @@ Changelog
144144

145145
:mod:`sklearn.preprocessing`
146146
............................
147+
148+
- |Enhancement| Adds a `feature_name_combiner` parameter to
149+
:class:`preprocessing.OneHotEncoder`. This specifies a custom callable to create
150+
feature names to be returned by :meth:`get_feature_names_out`.
151+
The callable combines the input arguments `(input_feature, category)` into a string.
152+
:pr:`22506` by :user:`Mario Kostelac <mariokostelac>`.
153+
147154
- |Enhancement| Added support for `sample_weight` in
148155
:class:`preprocessing.KBinsDiscretizer`. This allows specifying the parameter
149156
`sample_weight` for each sample to be used while fitting. The option is only

‎sklearn/preprocessing/_encoders.py

Copy file name to clipboardExpand all lines: sklearn/preprocessing/_encoders.py
+44-1Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,17 @@ class OneHotEncoder(_BaseEncoder):
343343
.. versionadded:: 1.1
344344
Read more in the :ref:`User Guide <one_hot_encoder_infrequent_categories>`.
345345
346+
feature_name_combiner : "concat" or callable, default="concat"
347+
Callable with signature `def callable(input_feature, category)` that returns a
348+
string. This is used to create feature names to be returned by
349+
:meth:`get_feature_names_out`.
350+
351+
`"concat"` concatenates encoded feature name and category with
352+
`feature + "_" + str(category)`. E.g. feature X with values 1, 6, 7 creates
353+
feature names `X_1, X_6, X_7`.
354+
355+
.. versionadded:: 1.3
356+
346357
Attributes
347358
----------
348359
categories_ : list of arrays
@@ -388,6 +399,13 @@ class OneHotEncoder(_BaseEncoder):
388399
389400
.. versionadded:: 1.0
390401
402+
feature_name_combiner : "concat" or callable
403+
Callable with signature `def callable(input_feature, category)` that returns a
404+
string. This is used to create feature names to be returned by
405+
:meth:`get_feature_names_out`.
406+
407+
.. versionadded:: 1.3
408+
391409
See Also
392410
--------
393411
OrdinalEncoder : Performs an ordinal (integer)
@@ -442,6 +460,15 @@ class OneHotEncoder(_BaseEncoder):
442460
array([[0., 1., 0., 0.],
443461
[1., 0., 1., 0.]])
444462
463+
One can change the way feature names are created.
464+
465+
>>> def custom_combiner(feature, category):
466+
... return str(feature) + "_" + type(category).__name__ + "_" + str(category)
467+
>>> custom_fnames_enc = OneHotEncoder(feature_name_combiner=custom_combiner).fit(X)
468+
>>> custom_fnames_enc.get_feature_names_out()
469+
array(['x0_str_Female', 'x0_str_Male', 'x1_int_1', 'x1_int_2', 'x1_int_3'],
470+
dtype=object)
471+
445472
Infrequent categories are enabled by setting `max_categories` or `min_frequency`.
446473
447474
>>> import numpy as np
@@ -467,6 +494,7 @@ class OneHotEncoder(_BaseEncoder):
467494
],
468495
"sparse": [Hidden(StrOptions({"deprecated"})), "boolean"], # deprecated
469496
"sparse_output": ["boolean"],
497+
"feature_name_combiner": [StrOptions({"concat"}), callable],
470498
}
471499

472500
def __init__(
@@ -480,6 +508,7 @@ def __init__(
480508
handle_unknown="error",
481509
min_frequency=None,
482510
max_categories=None,
511+
feature_name_combiner="concat",
483512
):
484513
self.categories = categories
485514
# TODO(1.4): Remove self.sparse
@@ -490,6 +519,7 @@ def __init__(
490519
self.drop = drop
491520
self.min_frequency = min_frequency
492521
self.max_categories = max_categories
522+
self.feature_name_combiner = feature_name_combiner
493523

494524
@property
495525
def infrequent_categories_(self):
@@ -1060,13 +1090,26 @@ def get_feature_names_out(self, input_features=None):
10601090
for i, _ in enumerate(self.categories_)
10611091
]
10621092

1093+
name_combiner = self._check_get_feature_name_combiner()
10631094
feature_names = []
10641095
for i in range(len(cats)):
1065-
names = [input_features[i] + "_" + str(t) for t in cats[i]]
1096+
names = [name_combiner(input_features[i], t) for t in cats[i]]
10661097
feature_names.extend(names)
10671098

10681099
return np.array(feature_names, dtype=object)
10691100

1101+
def _check_get_feature_name_combiner(self):
1102+
if self.feature_name_combiner == "concat":
1103+
return lambda feature, category: feature + "_" + str(category)
1104+
else: # callable
1105+
dry_run_combiner = self.feature_name_combiner("feature", "category")
1106+
if not isinstance(dry_run_combiner, str):
1107+
raise TypeError(
1108+
"When `feature_name_combiner` is a callable, it should return a "
1109+
f"Python string. Got {type(dry_run_combiner)} instead."
1110+
)
1111+
return self.feature_name_combiner
1112+
10701113

10711114
class OrdinalEncoder(OneToOneFeatureMixin, _BaseEncoder):
10721115
"""

‎sklearn/preprocessing/tests/test_encoders.py

Copy file name to clipboardExpand all lines: sklearn/preprocessing/tests/test_encoders.py
+26Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,32 @@ def test_one_hot_encoder_feature_names_unicode():
193193
assert_array_equal(["n👍me_c❤t1", "n👍me_dat2"], feature_names)
194194

195195

196+
def test_one_hot_encoder_custom_feature_name_combiner():
197+
"""Check the behaviour of `feature_name_combiner` as a callable."""
198+
199+
def name_combiner(feature, category):
200+
return feature + "_" + repr(category)
201+
202+
enc = OneHotEncoder(feature_name_combiner=name_combiner)
203+
X = np.array([["None", None]], dtype=object).T
204+
enc.fit(X)
205+
feature_names = enc.get_feature_names_out()
206+
assert_array_equal(["x0_'None'", "x0_None"], feature_names)
207+
feature_names = enc.get_feature_names_out(input_features=["a"])
208+
assert_array_equal(["a_'None'", "a_None"], feature_names)
209+
210+
def wrong_combiner(feature, category):
211+
# we should be returning a Python string
212+
return 0
213+
214+
enc = OneHotEncoder(feature_name_combiner=wrong_combiner).fit(X)
215+
err_msg = (
216+
"When `feature_name_combiner` is a callable, it should return a Python string."
217+
)
218+
with pytest.raises(TypeError, match=err_msg):
219+
enc.get_feature_names_out()
220+
221+
196222
def test_one_hot_encoder_set_params():
197223
X = np.array([[1, 2]]).T
198224
oh = OneHotEncoder()

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.