Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

BUG: ensure find-like ufuncs convert arguments to common dtypes #26198

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 47 additions & 19 deletions 66 numpy/_core/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,8 @@ def find(a, sub, start=0, end=None):

"""
end = end if end is not None else MAX
a = np.asanyarray(a)
sub = np.asanyarray(sub, dtype=getattr(sub, "dtype", a.dtype))
return _find_ufunc(a, sub, start, end)


Expand Down Expand Up @@ -265,6 +267,8 @@ def rfind(a, sub, start=0, end=None):

"""
end = end if end is not None else MAX
a = np.asanyarray(a)
sub = np.asanyarray(sub, dtype=getattr(sub, "dtype", a.dtype))
return _rfind_ufunc(a, sub, start, end)


Expand All @@ -277,6 +281,7 @@ def index(a, sub, start=0, end=None):
a : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype

sub : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
The substring to search for.

start, end : array_like, with any integer dtype, optional

Expand All @@ -297,6 +302,8 @@ def index(a, sub, start=0, end=None):

"""
end = end if end is not None else MAX
a = np.asanyarray(a)
sub = np.asanyarray(sub, dtype=getattr(sub, "dtype", a.dtype))
return _index_ufunc(a, sub, start, end)


Expand All @@ -307,9 +314,10 @@ def rindex(a, sub, start=0, end=None):

Parameters
----------
a : array-like, with `np.bytes_` or `np.str_` dtype
a : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype

sub : array-like, with `np.bytes_` or `np.str_` dtype
sub : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
The substring to search for.

start, end : array-like, with any integer dtype, optional

Expand All @@ -327,9 +335,11 @@ def rindex(a, sub, start=0, end=None):
>>> a = np.array(["Computer Science"])
>>> np.strings.rindex(a, "Science", start=0, end=None)
array([9])

"""
end = end if end is not None else MAX
a = np.asanyarray(a)
sub = np.asanyarray(sub, dtype=getattr(sub, "dtype", a.dtype))
return _rindex_ufunc(a, sub, start, end)


Expand Down Expand Up @@ -373,6 +383,8 @@ def count(a, sub, start=0, end=None):

"""
end = end if end is not None else MAX
a = np.asanyarray(a)
sub = np.asanyarray(sub, dtype=getattr(sub, "dtype", a.dtype))
return _count_ufunc(a, sub, start, end)


Expand All @@ -386,6 +398,7 @@ def startswith(a, prefix, start=0, end=None):
a : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype

prefix : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
The substring to search for.

start, end : array_like, with any integer dtype
With ``start``, test beginning at that position. With ``end``,
Expand All @@ -402,6 +415,8 @@ def startswith(a, prefix, start=0, end=None):

"""
end = end if end is not None else MAX
a = np.asanyarray(a)
prefix = np.asanyarray(prefix, dtype=getattr(prefix, "dtype", a.dtype))
return _startswith_ufunc(a, prefix, start, end)


Expand All @@ -415,6 +430,7 @@ def endswith(a, suffix, start=0, end=None):
a : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype

suffix : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
The substring to search for.

start, end : array_like, with any integer dtype
With ``start``, test beginning at that position. With ``end``,
Expand All @@ -441,6 +457,8 @@ def endswith(a, suffix, start=0, end=None):

"""
end = end if end is not None else MAX
a = np.asanyarray(a)
suffix = np.asanyarray(suffix, dtype=getattr(suffix, "dtype", a.dtype))
return _endswith_ufunc(a, suffix, start, end)


Expand Down Expand Up @@ -595,7 +613,9 @@ def center(a, width, fillchar=' '):
width : array_like, with any integer dtype
The length of the resulting strings, unless ``width < str_len(a)``.
fillchar : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
Optional padding character to use (default is space).
Optional padding character to use (default is space). If ``a`` and
``fillchar`` have fixed-width dtypes, then ``fillchar`` will be
truncated to the length of ``a``.

Returns
-------
Expand Down Expand Up @@ -627,7 +647,7 @@ def center(a, width, fillchar=' '):
"""
a = np.asanyarray(a)
width = np.maximum(str_len(a), width)
fillchar = np.asanyarray(fillchar, dtype=a.dtype)
fillchar = np.asanyarray(fillchar, a.dtype)

if np.any(str_len(fillchar) != 1):
raise TypeError(
Expand All @@ -654,7 +674,9 @@ def ljust(a, width, fillchar=' '):
width : array_like, with any integer dtype
The length of the resulting strings, unless ``width < str_len(a)``.
fillchar : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
Optional character to use for padding (default is space).
Optional character to use for padding (default is space). If ``a`` and
``fillchar`` have fixed-width dtypes, then ``fillchar`` will be
truncated to the length of ``a``.

Returns
-------
Expand Down Expand Up @@ -683,7 +705,7 @@ def ljust(a, width, fillchar=' '):
"""
a = np.asanyarray(a)
width = np.maximum(str_len(a), width)
fillchar = np.asanyarray(fillchar, dtype=a.dtype)
fillchar = np.asanyarray(fillchar, a.dtype)

if np.any(str_len(fillchar) != 1):
raise TypeError(
Expand All @@ -710,7 +732,9 @@ def rjust(a, width, fillchar=' '):
width : array_like, with any integer dtype
The length of the resulting strings, unless ``width < str_len(a)``.
fillchar : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
Optional padding character to use (default is space).
Optional padding character to use (default is space). If ``a`` and
``fillchar`` have fixed-width dtypes, then ``fillchar`` will be
truncated to the length of ``a``.

Returns
-------
Expand Down Expand Up @@ -739,7 +763,7 @@ def rjust(a, width, fillchar=' '):
"""
a = np.asanyarray(a)
width = np.maximum(str_len(a), width)
fillchar = np.asanyarray(fillchar, dtype=a.dtype)
fillchar = np.asanyarray(fillchar, a.dtype)

if np.any(str_len(fillchar) != 1):
raise TypeError(
Expand Down Expand Up @@ -838,7 +862,9 @@ def lstrip(a, chars=None):
"""
if chars is None:
return _lstrip_whitespace(a)
return _lstrip_chars(a, chars)
a = np.asanyarray(a)
return _lstrip_chars(
a, np.asanyarray(chars, getattr(chars, "dtype", a.dtype)))


def rstrip(a, chars=None):
Expand Down Expand Up @@ -879,7 +905,9 @@ def rstrip(a, chars=None):
"""
if chars is None:
return _rstrip_whitespace(a)
return _rstrip_chars(a, chars)
a = np.asanyarray(a)
return _rstrip_chars(
a, np.asanyarray(chars, getattr(chars, "dtype", a.dtype)))


def strip(a, chars=None):
Expand Down Expand Up @@ -924,7 +952,9 @@ def strip(a, chars=None):
"""
if chars is None:
return _strip_whitespace(a)
return _strip_chars(a, chars)
a = np.asanyarray(a)
return _strip_chars(
a, np.asanyarray(chars, getattr(chars, "dtype", a.dtype)))


def upper(a):
Expand Down Expand Up @@ -1120,9 +1150,9 @@ def replace(a, old, new, count=-1):

Parameters
----------
a : array_like, with ``bytes_`` or ``str_`` dtype
a : array_like, with ``StringDType``, ``bytes_`` or ``str_`` dtype

old, new : array_like, with ``bytes_`` or ``str_`` dtype
old, new : array_like, with ``StringDType``, ``bytes_`` or ``str_`` dtype

count : array_like, with ``int_`` dtype
If the optional argument ``count`` is given, only the first
Expand All @@ -1147,7 +1177,7 @@ def replace(a, old, new, count=-1):
>>> a = np.array(["The dish is fresh", "This is it"])
>>> np.strings.replace(a, 'is', 'was')
array(['The dwash was fresh', 'Thwas was it'], dtype='<U19')

"""
arr = np.asanyarray(a)
a_dt = arr.dtype
Expand Down Expand Up @@ -1361,8 +1391,7 @@ def partition(a, sep):

"""
a = np.asanyarray(a)
# TODO switch to copy=False when issues around views are fixed
sep = np.array(sep, dtype=a.dtype, copy=True, subok=True)
sep = np.asanyarray(sep, dtype=getattr(sep, "dtype", a.dtype))
if a.dtype.char == "T":
return _partition(a, sep)

Expand Down Expand Up @@ -1426,8 +1455,7 @@ def rpartition(a, sep):

"""
a = np.asanyarray(a)
# TODO switch to copy=False when issues around views are fixed
sep = np.array(sep, dtype=a.dtype, copy=True, subok=True)
sep = np.asanyarray(sep, dtype=getattr(sep, "dtype", a.dtype))
if a.dtype.char == "T":
return _rpartition(a, sep)

Expand Down
24 changes: 24 additions & 0 deletions 24 numpy/_core/tests/test_stringdtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,6 +1193,22 @@ def test_unary(string_array, unicode_array, function_name):
NULLS_ALWAYS_ERROR
)

STRING_SECOND_ARGUMENT = [
"find",
"rfind",
"index",
"rindex",
"count",
"startswith",
"endswith",
"lstrip",
"rstrip",
"strip",
"partition",
"rpartition",
"replace",
]


def call_func(func, args, array, sanitize=True):
if args == (None, None):
Expand Down Expand Up @@ -1221,6 +1237,14 @@ def test_binary(string_array, unicode_array, function_name, args):
ures = ures.astype(StringDType())
assert_array_equal(sres, ures)

if function_name in STRING_SECOND_ARGUMENT:
# call again with a non-default stringdtype instance, this should
# work even though the inferred dtype for the second argument is
# the default stringdtype instance
sres = call_func(func, args,
string_array.astype(StringDType(na_object="foobar")))
assert_array_equal(sres, ures)

dtype = string_array.dtype
if function_name not in SUPPORTS_NULLS or not hasattr(dtype, "na_object"):
return
Expand Down
Morty Proxy This is a proxified and sanitized view of the page, visit original site.