Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 65dfab0

Browse filesBrowse files
authored
FIX Fixes pandas extension arrays in check_array (#25813)
1 parent e75d8a6 commit 65dfab0
Copy full SHA for 65dfab0

File tree

4 files changed

+40
-3
lines changed
Filter options

4 files changed

+40
-3
lines changed

‎doc/whats_new/v1.3.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.3.rst
+3Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,9 @@ Changelog
407407
:pr:`25733` by :user:`Brigitta Sipőcz <bsipocz>` and
408408
:user:`Jérémie du Boisberranger <jeremiedbb>`.
409409

410+
- |FIX| Fixes :func:`utils.validation.check_array` to properly convert pandas
411+
extension arrays. :pr:`25813` by `Thomas Fan`_.
412+
410413
:mod:`sklearn.semi_supervised`
411414
..............................
412415

‎sklearn/preprocessing/tests/test_label.py

Copy file name to clipboardExpand all lines: sklearn/preprocessing/tests/test_label.py
+6-2Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,19 @@ def test_label_binarizer_set_label_encoding():
118118

119119

120120
@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"])
121-
def test_label_binarizer_pandas_nullable(dtype):
121+
@pytest.mark.parametrize("unique_first", [True, False])
122+
def test_label_binarizer_pandas_nullable(dtype, unique_first):
122123
"""Checks that LabelBinarizer works with pandas nullable dtypes.
123124
124125
Non-regression test for gh-25637.
125126
"""
126127
pd = pytest.importorskip("pandas")
127-
from sklearn.preprocessing import LabelBinarizer
128128

129129
y_true = pd.Series([1, 0, 0, 1, 0, 1, 1, 0, 1], dtype=dtype)
130+
if unique_first:
131+
# Calling unique creates a pandas array which has a different interface
132+
# compared to a pandas Series. Specifically, pandas arrays do not have "iloc".
133+
y_true = y_true.unique()
130134
lb = LabelBinarizer().fit(y_true)
131135
y_out = lb.transform([1, 0])
132136

‎sklearn/utils/tests/test_validation.py

Copy file name to clipboardExpand all lines: sklearn/utils/tests/test_validation.py
+19Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1762,6 +1762,25 @@ def test_boolean_series_remains_boolean():
17621762
assert_array_equal(res, expected)
17631763

17641764

1765+
@pytest.mark.parametrize("input_values", [[0, 1, 0, 1, 0, np.nan], [0, 1, 0, 1, 0, 1]])
1766+
def test_pandas_array_returns_ndarray(input_values):
1767+
"""Check pandas array with extensions dtypes returns a numeric ndarray.
1768+
1769+
Non-regression test for gh-25637.
1770+
"""
1771+
pd = importorskip("pandas")
1772+
input_series = pd.array(input_values, dtype="Int32")
1773+
result = check_array(
1774+
input_series,
1775+
dtype=None,
1776+
ensure_2d=False,
1777+
allow_nd=False,
1778+
force_all_finite=False,
1779+
)
1780+
assert np.issubdtype(result.dtype.kind, np.floating)
1781+
assert_allclose(result, input_values)
1782+
1783+
17651784
@pytest.mark.parametrize("array_namespace", ["numpy.array_api", "cupy.array_api"])
17661785
def test_check_array_array_api_has_non_finite(array_namespace):
17671786
"""Checks that Array API arrays checks non-finite correctly."""

‎sklearn/utils/validation.py

Copy file name to clipboardExpand all lines: sklearn/utils/validation.py
+12-1Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,15 @@ def _pandas_dtype_needs_early_conversion(pd_dtype):
626626
return False
627627

628628

629+
def _is_extension_array_dtype(array):
630+
try:
631+
from pandas.api.types import is_extension_array_dtype
632+
633+
return is_extension_array_dtype(array)
634+
except ImportError:
635+
return False
636+
637+
629638
def check_array(
630639
array,
631640
accept_sparse=False,
@@ -777,7 +786,9 @@ def check_array(
777786
if all(isinstance(dtype_iter, np.dtype) for dtype_iter in dtypes_orig):
778787
dtype_orig = np.result_type(*dtypes_orig)
779788

780-
elif hasattr(array, "iloc") and hasattr(array, "dtype"):
789+
elif (_is_extension_array_dtype(array) or hasattr(array, "iloc")) and hasattr(
790+
array, "dtype"
791+
):
781792
# array is a pandas series
782793
pandas_requires_conversion = _pandas_dtype_needs_early_conversion(array.dtype)
783794
if isinstance(array.dtype, np.dtype):

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.