Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 6894a9b

Browse filesBrowse files
authored
FIX Fixes OrdinalEncoder.inverse_tranform nan encoded values (#24087)
1 parent c8f68e8 commit 6894a9b
Copy full SHA for 6894a9b

File tree

3 files changed

+77
-5
lines changed
Filter options

3 files changed

+77
-5
lines changed

‎doc/whats_new/v1.1.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.1.rst
+7Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ Changelog
3636
a node if there are duplicates in the dataset.
3737
:pr:`23395` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
3838

39+
:mod:`sklearn.preprocessing`
40+
............................
41+
42+
- |Fix| :meth:`preprocessing.OrdinalEncoder.inverse_transform` correctly handles
43+
use cases where `unknown_value` or `encoded_missing_value` is `nan`. :pr:`24087`
44+
by `Thomas Fan`_.
45+
3946
.. _changes_1_1_1:
4047

4148
Version 1.1.1

‎sklearn/preprocessing/_encoders.py

Copy file name to clipboardExpand all lines: sklearn/preprocessing/_encoders.py
+9-5Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1371,19 +1371,23 @@ def inverse_transform(self, X):
13711371
found_unknown = {}
13721372

13731373
for i in range(n_features):
1374-
labels = X[:, i].astype("int64", copy=False)
1374+
labels = X[:, i]
13751375

13761376
# replace values of X[:, i] that were nan with actual indices
13771377
if i in self._missing_indices:
1378-
X_i_mask = _get_mask(X[:, i], self.encoded_missing_value)
1378+
X_i_mask = _get_mask(labels, self.encoded_missing_value)
13791379
labels[X_i_mask] = self._missing_indices[i]
13801380

13811381
if self.handle_unknown == "use_encoded_value":
1382-
unknown_labels = labels == self.unknown_value
1383-
X_tr[:, i] = self.categories_[i][np.where(unknown_labels, 0, labels)]
1382+
unknown_labels = _get_mask(labels, self.unknown_value)
1383+
1384+
known_labels = ~unknown_labels
1385+
X_tr[known_labels, i] = self.categories_[i][
1386+
labels[known_labels].astype("int64", copy=False)
1387+
]
13841388
found_unknown[i] = unknown_labels
13851389
else:
1386-
X_tr[:, i] = self.categories_[i][labels]
1390+
X_tr[:, i] = self.categories_[i][labels.astype("int64", copy=False)]
13871391

13881392
# insert None values for unknown values
13891393
if found_unknown:

‎sklearn/preprocessing/tests/test_encoders.py

Copy file name to clipboardExpand all lines: sklearn/preprocessing/tests/test_encoders.py
+61Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1837,6 +1837,15 @@ def test_ordinal_encoder_unknown_missing_interaction():
18371837
X_test_trans = oe.transform(X_test)
18381838
assert_allclose(X_test_trans, [[np.nan], [-3]])
18391839

1840+
# Non-regression test for #24082
1841+
X_roundtrip = oe.inverse_transform(X_test_trans)
1842+
1843+
# np.nan is unknown so it maps to None
1844+
assert X_roundtrip[0][0] is None
1845+
1846+
# -3 is the encoded missing value so it maps back to nan
1847+
assert np.isnan(X_roundtrip[1][0])
1848+
18401849

18411850
@pytest.mark.parametrize("with_pandas", [True, False])
18421851
def test_ordinal_encoder_encoded_missing_value_error(with_pandas):
@@ -1862,3 +1871,55 @@ def test_ordinal_encoder_encoded_missing_value_error(with_pandas):
18621871

18631872
with pytest.raises(ValueError, match=error_msg):
18641873
oe.fit(X)
1874+
1875+
1876+
@pytest.mark.parametrize(
1877+
"X_train, X_test_trans_expected, X_roundtrip_expected",
1878+
[
1879+
(
1880+
# missing value is not in training set
1881+
# inverse transform will considering encoded nan as unknown
1882+
np.array([["a"], ["1"]], dtype=object),
1883+
[[0], [np.nan], [np.nan]],
1884+
np.asarray([["1"], [None], [None]], dtype=object),
1885+
),
1886+
(
1887+
# missing value in training set,
1888+
# inverse transform will considering encoded nan as missing
1889+
np.array([[np.nan], ["1"], ["a"]], dtype=object),
1890+
[[0], [np.nan], [np.nan]],
1891+
np.asarray([["1"], [np.nan], [np.nan]], dtype=object),
1892+
),
1893+
],
1894+
)
1895+
def test_ordinal_encoder_unknown_missing_interaction_both_nan(
1896+
X_train, X_test_trans_expected, X_roundtrip_expected
1897+
):
1898+
"""Check transform when unknown_value and encoded_missing_value is nan.
1899+
1900+
Non-regression test for #24082.
1901+
"""
1902+
oe = OrdinalEncoder(
1903+
handle_unknown="use_encoded_value",
1904+
unknown_value=np.nan,
1905+
encoded_missing_value=np.nan,
1906+
).fit(X_train)
1907+
1908+
X_test = np.array([["1"], [np.nan], ["b"]])
1909+
X_test_trans = oe.transform(X_test)
1910+
1911+
# both nan and unknown are encoded as nan
1912+
assert_allclose(X_test_trans, X_test_trans_expected)
1913+
X_roundtrip = oe.inverse_transform(X_test_trans)
1914+
1915+
n_samples = X_roundtrip_expected.shape[0]
1916+
for i in range(n_samples):
1917+
expected_val = X_roundtrip_expected[i, 0]
1918+
val = X_roundtrip[i, 0]
1919+
1920+
if expected_val is None:
1921+
assert val is None
1922+
elif is_scalar_nan(expected_val):
1923+
assert np.isnan(val)
1924+
else:
1925+
assert val == expected_val

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.