Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 793fff5

Browse files
committed
ENH: ensure find-like ufuncs convert arguments to common dtypes
1 parent 6d4c2c4 commit 793fff5
Copy full SHA for 793fff5

File tree

Expand file tree / Collapse file tree

3 files changed

+68
-19
lines changed
Filter options
Expand file tree / Collapse file tree

3 files changed

+68
-19
lines changed

‎numpy/_core/strings.py

Copy file name to clipboard | Expand all lines: numpy/_core/strings.py
+38 -16 | Lines changed: 38 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -235,6 +235,8 @@ def find(a, sub, start=0, end=None):
235235
236236
"""
237237
end = end if end is not None else MAX
238+
a = np.asanyarray(a)
239+
sub = np.asanyarray(sub, dtype=getattr(sub, "dtype", a.dtype))
238240
return _find_ufunc(a, sub, start, end)
239241

240242

@@ -265,6 +267,8 @@ def rfind(a, sub, start=0, end=None):
265267
266268
"""
267269
end = end if end is not None else MAX
270+
a = np.asanyarray(a)
271+
sub = np.asanyarray(sub, dtype=getattr(sub, "dtype", a.dtype))
268272
return _rfind_ufunc(a, sub, start, end)
269273

270274

@@ -277,6 +281,7 @@ def index(a, sub, start=0, end=None):
277281
a : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
278282
279283
sub : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
284+
The substring to search for.
280285
281286
start, end : array_like, with any integer dtype, optional
282287
@@ -297,6 +302,8 @@ def index(a, sub, start=0, end=None):
297302
298303
"""
299304
end = end if end is not None else MAX
305+
a = np.asanyarray(a)
306+
sub = np.asanyarray(sub, dtype=getattr(sub, "dtype", a.dtype))
300307
return _index_ufunc(a, sub, start, end)
301308

302309

@@ -307,9 +314,10 @@ def rindex(a, sub, start=0, end=None):
307314
308315
Parameters
309316
----------
310-
a : array-like, with `np.bytes_` or `np.str_` dtype
317+
a : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
311318
312-
sub : array-like, with `np.bytes_` or `np.str_` dtype
319+
sub : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
320+
The substring to search for.
313321
314322
start, end : array-like, with any integer dtype, optional
315323
@@ -327,9 +335,11 @@ def rindex(a, sub, start=0, end=None):
327335
>>> a = np.array(["Computer Science"])
328336
>>> np.strings.rindex(a, "Science", start=0, end=None)
329337
array([9])
330-
338+
331339
"""
332340
end = end if end is not None else MAX
341+
a = np.asanyarray(a)
342+
sub = np.asanyarray(sub, dtype=getattr(sub, "dtype", a.dtype))
333343
return _rindex_ufunc(a, sub, start, end)
334344

335345

@@ -373,6 +383,8 @@ def count(a, sub, start=0, end=None):
373383
374384
"""
375385
end = end if end is not None else MAX
386+
a = np.asanyarray(a)
387+
sub = np.asanyarray(sub, dtype=getattr(sub, "dtype", a.dtype))
376388
return _count_ufunc(a, sub, start, end)
377389

378390

@@ -386,6 +398,7 @@ def startswith(a, prefix, start=0, end=None):
386398
a : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
387399
388400
prefix : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
401+
The substring to search for.
389402
390403
start, end : array_like, with any integer dtype
391404
With ``start``, test beginning at that position. With ``end``,
@@ -402,6 +415,8 @@ def startswith(a, prefix, start=0, end=None):
402415
403416
"""
404417
end = end if end is not None else MAX
418+
a = np.asanyarray(a)
419+
prefix = np.asanyarray(prefix, dtype=getattr(prefix, "dtype", a.dtype))
405420
return _startswith_ufunc(a, prefix, start, end)
406421

407422

@@ -415,6 +430,7 @@ def endswith(a, suffix, start=0, end=None):
415430
a : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
416431
417432
suffix : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
433+
The substring to search for.
418434
419435
start, end : array_like, with any integer dtype
420436
With ``start``, test beginning at that position. With ``end``,
@@ -441,6 +457,8 @@ def endswith(a, suffix, start=0, end=None):
441457
442458
"""
443459
end = end if end is not None else MAX
460+
a = np.asanyarray(a)
461+
suffix = np.asanyarray(suffix, dtype=getattr(suffix, "dtype", a.dtype))
444462
return _endswith_ufunc(a, suffix, start, end)
445463

446464

@@ -627,7 +645,7 @@ def center(a, width, fillchar=' '):
627645
"""
628646
a = np.asanyarray(a)
629647
width = np.maximum(str_len(a), width)
630-
fillchar = np.asanyarray(fillchar, dtype=a.dtype)
648+
fillchar = np.asanyarray(fillchar, getattr(fillchar, "dtype", a.dtype))
631649

632650
if np.any(str_len(fillchar) != 1):
633651
raise TypeError(
@@ -683,7 +701,7 @@ def ljust(a, width, fillchar=' '):
683701
"""
684702
a = np.asanyarray(a)
685703
width = np.maximum(str_len(a), width)
686-
fillchar = np.asanyarray(fillchar, dtype=a.dtype)
704+
fillchar = np.asanyarray(fillchar, getattr(fillchar, "dtype", a.dtype))
687705

688706
if np.any(str_len(fillchar) != 1):
689707
raise TypeError(
@@ -739,7 +757,7 @@ def rjust(a, width, fillchar=' '):
739757
"""
740758
a = np.asanyarray(a)
741759
width = np.maximum(str_len(a), width)
742-
fillchar = np.asanyarray(fillchar, dtype=a.dtype)
760+
fillchar = np.asanyarray(fillchar, getattr(fillchar, "dtype", a.dtype))
743761

744762
if np.any(str_len(fillchar) != 1):
745763
raise TypeError(
@@ -838,7 +856,9 @@ def lstrip(a, chars=None):
838856
"""
839857
if chars is None:
840858
return _lstrip_whitespace(a)
841-
return _lstrip_chars(a, chars)
859+
a = np.asanyarray(a)
860+
return _lstrip_chars(
861+
a, np.asanyarray(chars, getattr(chars, "dtype", a.dtype)))
842862

843863

844864
def rstrip(a, chars=None):
@@ -879,7 +899,9 @@ def rstrip(a, chars=None):
879899
"""
880900
if chars is None:
881901
return _rstrip_whitespace(a)
882-
return _rstrip_chars(a, chars)
902+
a = np.asanyarray(a)
903+
return _rstrip_chars(
904+
a, np.asanyarray(chars, getattr(chars, "dtype", a.dtype)))
883905

884906

885907
def strip(a, chars=None):
@@ -924,7 +946,9 @@ def strip(a, chars=None):
924946
"""
925947
if chars is None:
926948
return _strip_whitespace(a)
927-
return _strip_chars(a, chars)
949+
a = np.asanyarray(a)
950+
return _strip_chars(
951+
a, np.asanyarray(chars, getattr(chars, "dtype", a.dtype)))
928952

929953

930954
def upper(a):
@@ -1120,9 +1144,9 @@ def replace(a, old, new, count=-1):
11201144
11211145
Parameters
11221146
----------
1123-
a : array_like, with ``bytes_`` or ``str_`` dtype
1147+
a : array_like, with ``StringDType``, ``bytes_`` or ``str_`` dtype
11241148
1125-
old, new : array_like, with ``bytes_`` or ``str_`` dtype
1149+
old, new : array_like, with ``StringDType``, ``bytes_`` or ``str_`` dtype
11261150
11271151
count : array_like, with ``int_`` dtype
11281152
If the optional argument ``count`` is given, only the first
@@ -1147,7 +1171,7 @@ def replace(a, old, new, count=-1):
11471171
>>> a = np.array(["The dish is fresh", "This is it"])
11481172
>>> np.strings.replace(a, 'is', 'was')
11491173
array(['The dwash was fresh', 'Thwas was it'], dtype='<U19')
1150-
1174+
11511175
"""
11521176
arr = np.asanyarray(a)
11531177
a_dt = arr.dtype
@@ -1361,8 +1385,7 @@ def partition(a, sep):
13611385
13621386
"""
13631387
a = np.asanyarray(a)
1364-
# TODO switch to copy=False when issues around views are fixed
1365-
sep = np.array(sep, dtype=a.dtype, copy=True, subok=True)
1388+
sep = np.asanyarray(sep, dtype=getattr(sep, "dtype", a.dtype))
13661389
if a.dtype.char == "T":
13671390
return _partition(a, sep)
13681391

@@ -1426,8 +1449,7 @@ def rpartition(a, sep):
14261449
14271450
"""
14281451
a = np.asanyarray(a)
1429-
# TODO switch to copy=False when issues around views are fixed
1430-
sep = np.array(sep, dtype=a.dtype, copy=True, subok=True)
1452+
sep = np.asanyarray(sep, dtype=getattr(sep, "dtype", a.dtype))
14311453
if a.dtype.char == "T":
14321454
return _rpartition(a, sep)
14331455

‎numpy/_core/tests/test_stringdtype.py

Copy file name to clipboard | Expand all lines: numpy/_core/tests/test_stringdtype.py
+24 | Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1193,6 +1193,22 @@ def test_unary(string_array, unicode_array, function_name):
11931193
NULLS_ALWAYS_ERROR
11941194
)
11951195

1196+
STRING_SECOND_ARGUMENT = [
1197+
"find",
1198+
"rfind",
1199+
"index",
1200+
"rindex",
1201+
"count",
1202+
"startswith",
1203+
"endswith",
1204+
"lstrip",
1205+
"rstrip",
1206+
"strip",
1207+
"partition",
1208+
"rpartition",
1209+
"replace",
1210+
]
1211+
11961212

11971213
def call_func(func, args, array, sanitize=True):
11981214
if args == (None, None):
@@ -1221,6 +1237,14 @@ def test_binary(string_array, unicode_array, function_name, args):
12211237
ures = ures.astype(StringDType())
12221238
assert_array_equal(sres, ures)
12231239

1240+
if function_name in STRING_SECOND_ARGUMENT:
1241+
# call again with a non-default stringdtype instance, this should
1242+
# work even though the inferred dtype for the second argument is
1243+
# the default stringdtype instance
1244+
sres = call_func(func, args,
1245+
string_array.astype(StringDType(na_object="foobar")))
1246+
assert_array_equal(sres, ures)
1247+
12241248
dtype = string_array.dtype
12251249
if function_name not in SUPPORTS_NULLS or not hasattr(dtype, "na_object"):
12261250
return

‎numpy/_core/tests/test_strings.py

Copy file name to clipboard | Expand all lines: numpy/_core/tests/test_strings.py
+6 -3 | Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1076,7 +1076,8 @@ def test_center(self):
10761076
res = np.array("*s*", dtype="S")
10771077
assert_array_equal(np.strings.center(buf, 3, fill), res)
10781078

1079-
with pytest.raises(ValueError, match="'ascii' codec can't encode"):
1079+
with pytest.raises(
1080+
ValueError, match="non-ascii fill character is not allowed"):
10801081
buf = np.array("s", dtype="S")
10811082
fill = np.array("😊", dtype="U")
10821083
np.strings.center(buf, 3, fill)
@@ -1092,7 +1093,8 @@ def test_ljust(self):
10921093
res = np.array("s**", dtype="S")
10931094
assert_array_equal(np.strings.ljust(buf, 3, fill), res)
10941095

1095-
with pytest.raises(ValueError, match="'ascii' codec can't encode"):
1096+
with pytest.raises(
1097+
ValueError, match="non-ascii fill character is not allowed"):
10961098
buf = np.array("s", dtype="S")
10971099
fill = np.array("😊", dtype="U")
10981100
np.strings.ljust(buf, 3, fill)
@@ -1108,7 +1110,8 @@ def test_rjust(self):
11081110
res = np.array("**s", dtype="S")
11091111
assert_array_equal(np.strings.rjust(buf, 3, fill), res)
11101112

1111-
with pytest.raises(ValueError, match="'ascii' codec can't encode"):
1113+
with pytest.raises(
1114+
ValueError, match="non-ascii fill character is not allowed"):
11121115
buf = np.array("s", dtype="S")
11131116
fill = np.array("😊", dtype="U")
11141117
np.strings.rjust(buf, 3, fill)

0 commit comments

Comments
0 (0)
Morty Proxy: this is a proxified and sanitized view of the page; visit the original site.