Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 22f4a1b

Browse filesBrowse files
committed
TYP: np.argmin changes
1 parent 4e628f4 commit 22f4a1b
Copy full SHA for 22f4a1b

File tree

Expand file treeCollapse file tree

6 files changed

+33
-25
lines changed
Filter options
Expand file treeCollapse file tree

6 files changed

+33
-25
lines changed

‎numpy/__init__.pyi

Copy file name to clipboardExpand all lines: numpy/__init__.pyi
+9-8Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,7 @@ _CharacterItemT_co = TypeVar("_CharacterItemT_co", bound=_CharLike_co, default=_
826826
_TD64ItemT_co = TypeVar("_TD64ItemT_co", bound=dt.timedelta | int | None, default=dt.timedelta | int | None, covariant=True)
827827
_DT64ItemT_co = TypeVar("_DT64ItemT_co", bound=dt.date | int | None, default=dt.date | int | None, covariant=True)
828828
_TD64UnitT = TypeVar("_TD64UnitT", bound=_TD64Unit, default=_TD64Unit)
829+
_BoolOrIntArrayT = TypeVar("_BoolOrIntArrayT", bound=NDArray[integer | np.bool])
829830

830831
### Type Aliases (for internal use only)
831832

@@ -1704,18 +1705,18 @@ class _ArrayOrScalarCommon:
17041705
@overload # axis=index, out=None (default)
17051706
def argmax(self, /, axis: SupportsIndex, out: None = None, *, keepdims: builtins.bool = False) -> Any: ...
17061707
@overload # axis=index, out=ndarray
1707-
def argmax(self, /, axis: SupportsIndex | None, out: _ArrayT, *, keepdims: builtins.bool = False) -> _ArrayT: ...
1708+
def argmax(self, /, axis: SupportsIndex | None, out: _BoolOrIntArrayT, *, keepdims: builtins.bool = False) -> _BoolOrIntArrayT: ...
17081709
@overload
1709-
def argmax(self, /, axis: SupportsIndex | None = None, *, out: _ArrayT, keepdims: builtins.bool = False) -> _ArrayT: ...
1710+
def argmax(self, /, axis: SupportsIndex | None = None, *, out: _BoolOrIntArrayT, keepdims: builtins.bool = False) -> _BoolOrIntArrayT: ...
17101711

17111712
@overload # axis=None (default), out=None (default), keepdims=False (default)
17121713
def argmin(self, /, axis: None = None, out: None = None, *, keepdims: L[False] = False) -> intp: ...
17131714
@overload # axis=index, out=None (default)
17141715
def argmin(self, /, axis: SupportsIndex, out: None = None, *, keepdims: builtins.bool = False) -> Any: ...
17151716
@overload # axis=index, out=ndarray
1716-
def argmin(self, /, axis: SupportsIndex | None, out: _ArrayT, *, keepdims: builtins.bool = False) -> _ArrayT: ...
1717+
def argmin(self, /, axis: SupportsIndex | None, out: _BoolOrIntArrayT, *, keepdims: builtins.bool = False) -> _BoolOrIntArrayT: ...
17171718
@overload
1718-
def argmin(self, /, axis: SupportsIndex | None = None, *, out: _ArrayT, keepdims: builtins.bool = False) -> _ArrayT: ...
1719+
def argmin(self, /, axis: SupportsIndex | None = None, *, out: _BoolOrIntArrayT, keepdims: builtins.bool = False) -> _BoolOrIntArrayT: ...
17191720

17201721
@overload # out=None (default)
17211722
def round(self, /, decimals: SupportsIndex = 0, out: None = None) -> Self: ...
@@ -5364,19 +5365,19 @@ class matrix(ndarray[_2DShapeT_co, _DTypeT_co]):
53645365
@overload
53655366
def argmax(self, axis: _ShapeLike, out: None = None) -> matrix[_2D, dtype[intp]]: ...
53665367
@overload
5367-
def argmax(self, axis: _ShapeLike | None, out: _ArrayT) -> _ArrayT: ...
5368+
def argmax(self, axis: _ShapeLike | None, out: _BoolOrIntArrayT) -> _BoolOrIntArrayT: ...
53685369
@overload
5369-
def argmax(self, axis: _ShapeLike | None = None, *, out: _ArrayT) -> _ArrayT: ... # pyright: ignore[reportIncompatibleMethodOverride]
5370+
def argmax(self, axis: _ShapeLike | None = None, *, out: _BoolOrIntArrayT) -> _BoolOrIntArrayT: ... # pyright: ignore[reportIncompatibleMethodOverride]
53705371

53715372
# keep in sync with `argmax`
53725373
@overload # type: ignore[override]
53735374
def argmin(self: NDArray[_ScalarT], axis: None = None, out: None = None) -> intp: ...
53745375
@overload
53755376
def argmin(self, axis: _ShapeLike, out: None = None) -> matrix[_2D, dtype[intp]]: ...
53765377
@overload
5377-
def argmin(self, axis: _ShapeLike | None, out: _ArrayT) -> _ArrayT: ...
5378+
def argmin(self, axis: _ShapeLike | None, out: _BoolOrIntArrayT) -> _BoolOrIntArrayT: ...
53785379
@overload
5379-
def argmin(self, axis: _ShapeLike | None = None, *, out: _ArrayT) -> _ArrayT: ... # pyright: ignore[reportIncompatibleMethodOverride]
5380+
def argmin(self, axis: _ShapeLike | None = None, *, out: _BoolOrIntArrayT) -> _BoolOrIntArrayT: ... # pyright: ignore[reportIncompatibleMethodOverride]
53805381

53815382
#the second overload handles the (rare) case that the matrix is not 2-d
53825383
@overload

‎numpy/_core/fromnumeric.pyi

Copy file name to clipboardExpand all lines: numpy/_core/fromnumeric.pyi
+9-8Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ _NumberOrObjectT = TypeVar("_NumberOrObjectT", bound=np.number | np.object_)
111111
_ArrayT = TypeVar("_ArrayT", bound=np.ndarray[Any, Any])
112112
_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...])
113113
_ShapeT_co = TypeVar("_ShapeT_co", bound=tuple[int, ...], covariant=True)
114+
_BoolOrIntArrayT = TypeVar("_BoolOrIntArrayT", bound=NDArray[np.integer | np.bool])
114115

115116
@type_check_only
116117
class _SupportsShape(Protocol[_ShapeT_co]):
@@ -418,18 +419,18 @@ def argmax(
418419
def argmax(
419420
a: ArrayLike,
420421
axis: SupportsIndex | None,
421-
out: _ArrayT,
422+
out: _BoolOrIntArrayT,
422423
*,
423424
keepdims: bool = ...,
424-
) -> _ArrayT: ...
425+
) -> _BoolOrIntArrayT: ...
425426
@overload
426427
def argmax(
427428
a: ArrayLike,
428429
axis: SupportsIndex | None = ...,
429430
*,
430-
out: _ArrayT,
431+
out: _BoolOrIntArrayT,
431432
keepdims: bool = ...,
432-
) -> _ArrayT: ...
433+
) -> _BoolOrIntArrayT: ...
433434

434435
@overload
435436
def argmin(
@@ -451,18 +452,18 @@ def argmin(
451452
def argmin(
452453
a: ArrayLike,
453454
axis: SupportsIndex | None,
454-
out: _ArrayT,
455+
out: _BoolOrIntArrayT,
455456
*,
456457
keepdims: bool = ...,
457-
) -> _ArrayT: ...
458+
) -> _BoolOrIntArrayT: ...
458459
@overload
459460
def argmin(
460461
a: ArrayLike,
461462
axis: SupportsIndex | None = ...,
462463
*,
463-
out: _ArrayT,
464+
out: _BoolOrIntArrayT,
464465
keepdims: bool = ...,
465-
) -> _ArrayT: ...
466+
) -> _BoolOrIntArrayT: ...
466467

467468
@overload
468469
def searchsorted(

‎numpy/typing/tests/data/pass/ndarray_misc.py

Copy file name to clipboardExpand all lines: numpy/typing/tests/data/pass/ndarray_misc.py
+4-3Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
import numpy.typing as npt
1616

1717
class SubClass(npt.NDArray[np.float64]): ...
18-
18+
class IntSubClass(npt.NDArray[np.intp]): ...
1919

2020
i4 = np.int32(1)
2121
A: np.ndarray[Any, np.dtype[np.int32]] = np.array([[1]], dtype=np.int32)
2222
B0 = np.empty((), dtype=np.int32).view(SubClass)
2323
B1 = np.empty((1,), dtype=np.int32).view(SubClass)
2424
B2 = np.empty((1, 1), dtype=np.int32).view(SubClass)
25+
B_int0: IntSubClass = np.empty((), dtype=np.intp).view(IntSubClass)
2526
C: np.ndarray[Any, np.dtype[np.int32]] = np.array([0, 1, 2], dtype=np.int32)
2627
D = np.ones(3).view(SubClass)
2728

@@ -42,12 +43,12 @@ class SubClass(npt.NDArray[np.float64]): ...
4243
i4.argmax()
4344
A.argmax()
4445
A.argmax(axis=0)
45-
A.argmax(out=B0)
46+
A.argmax(out=B_int0)
4647

4748
i4.argmin()
4849
A.argmin()
4950
A.argmin(axis=0)
50-
A.argmin(out=B0)
51+
A.argmin(out=B_int0)
5152

5253
i4.argsort()
5354
A.argsort()

‎numpy/typing/tests/data/reveal/fromnumeric.pyi

Copy file name to clipboardExpand all lines: numpy/typing/tests/data/reveal/fromnumeric.pyi
+6-2Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ f4: np.float32
2525
i8: np.int64
2626
f: float
2727

28+
# integer‑dtype subclass for argmin/argmax
29+
class NDArrayIntSubclass(npt.NDArray[np.intp]): ...
30+
AR_sub_i: NDArrayIntSubclass
31+
2832
assert_type(np.take(b, 0), np.bool)
2933
assert_type(np.take(f4, 0), np.float32)
3034
assert_type(np.take(f, 0), Any)
@@ -89,13 +93,13 @@ assert_type(np.argmax(AR_b), np.intp)
8993
assert_type(np.argmax(AR_f4), np.intp)
9094
assert_type(np.argmax(AR_b, axis=0), Any)
9195
assert_type(np.argmax(AR_f4, axis=0), Any)
92-
assert_type(np.argmax(AR_f4, out=AR_subclass), NDArraySubclass)
96+
assert_type(np.argmax(AR_f4, out=AR_sub_i), NDArrayIntSubclass)
9397

9498
assert_type(np.argmin(AR_b), np.intp)
9599
assert_type(np.argmin(AR_f4), np.intp)
96100
assert_type(np.argmin(AR_b, axis=0), Any)
97101
assert_type(np.argmin(AR_f4, axis=0), Any)
98-
assert_type(np.argmin(AR_f4, out=AR_subclass), NDArraySubclass)
102+
assert_type(np.argmin(AR_f4, out=AR_sub_i), NDArrayIntSubclass)
99103

100104
assert_type(np.searchsorted(AR_b[0], 0), np.intp)
101105
assert_type(np.searchsorted(AR_f4[0], 0), np.intp)

‎numpy/typing/tests/data/reveal/matrix.pyi

Copy file name to clipboardExpand all lines: numpy/typing/tests/data/reveal/matrix.pyi
+3-2Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ _Shape2D: TypeAlias = tuple[int, int]
77

88
mat: np.matrix[_Shape2D, np.dtype[np.int64]]
99
ar_f8: npt.NDArray[np.float64]
10+
ar_ip: npt.NDArray[np.intp]
1011

1112
assert_type(mat * 5, np.matrix[_Shape2D, Any])
1213
assert_type(5 * mat, np.matrix[_Shape2D, Any])
@@ -50,8 +51,8 @@ assert_type(mat.any(out=ar_f8), npt.NDArray[np.float64])
5051
assert_type(mat.all(out=ar_f8), npt.NDArray[np.float64])
5152
assert_type(mat.max(out=ar_f8), npt.NDArray[np.float64])
5253
assert_type(mat.min(out=ar_f8), npt.NDArray[np.float64])
53-
assert_type(mat.argmax(out=ar_f8), npt.NDArray[np.float64])
54-
assert_type(mat.argmin(out=ar_f8), npt.NDArray[np.float64])
54+
assert_type(mat.argmax(out=ar_ip), npt.NDArray[np.intp])
55+
assert_type(mat.argmin(out=ar_ip), npt.NDArray[np.intp])
5556
assert_type(mat.ptp(out=ar_f8), npt.NDArray[np.float64])
5657

5758
assert_type(mat.T, np.matrix[_Shape2D, np.dtype[np.int64]])

‎numpy/typing/tests/data/reveal/ndarray_misc.pyi

Copy file name to clipboardExpand all lines: numpy/typing/tests/data/reveal/ndarray_misc.pyi
+2-2Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,12 @@ assert_type(AR_f8.any(out=B), SubClass)
5858
assert_type(f8.argmax(), np.intp)
5959
assert_type(AR_f8.argmax(), np.intp)
6060
assert_type(AR_f8.argmax(axis=0), Any)
61-
assert_type(AR_f8.argmax(out=B), SubClass)
61+
assert_type(AR_f8.argmax(out=AR_i8), npt.NDArray[np.intp])
6262

6363
assert_type(f8.argmin(), np.intp)
6464
assert_type(AR_f8.argmin(), np.intp)
6565
assert_type(AR_f8.argmin(axis=0), Any)
66-
assert_type(AR_f8.argmin(out=B), SubClass)
66+
assert_type(AR_f8.argmin(out=AR_i8), npt.NDArray[np.intp])
6767

6868
assert_type(f8.argsort(), npt.NDArray[Any])
6969
assert_type(AR_f8.argsort(), npt.NDArray[Any])

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.