Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit cc2a3cf

Browse filesBrowse files
jorenhamArvidJB
authored andcommitted
TYP: Improved numpy.generic rich comparison operator type annotations.
1 parent 5ae655d commit cc2a3cf
Copy full SHA for cc2a3cf

File tree

Expand file treeCollapse file tree

3 files changed

+78
-30
lines changed
Filter options
Expand file treeCollapse file tree

3 files changed

+78
-30
lines changed

‎numpy/__init__.pyi

Copy file name to clipboardExpand all lines: numpy/__init__.pyi
+20-17Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,10 @@ from numpy._typing._callable import (
151151
_FloatDivMod,
152152
_ComplexOp,
153153
_NumberOp,
154-
_ComparisonOp,
154+
_ComparisonOpLT,
155+
_ComparisonOpLE,
156+
_ComparisonOpGT,
157+
_ComparisonOpGE,
155158
)
156159

157160
# NOTE: Numpy's mypy plugin is used for removing the types unavailable
@@ -2800,10 +2803,10 @@ class number(generic, Generic[_NBit1]): # type: ignore
28002803
__rpow__: _NumberOp
28012804
__truediv__: _NumberOp
28022805
__rtruediv__: _NumberOp
2803-
__lt__: _ComparisonOp[_NumberLike_co, _ArrayLikeNumber_co]
2804-
__le__: _ComparisonOp[_NumberLike_co, _ArrayLikeNumber_co]
2805-
__gt__: _ComparisonOp[_NumberLike_co, _ArrayLikeNumber_co]
2806-
__ge__: _ComparisonOp[_NumberLike_co, _ArrayLikeNumber_co]
2806+
__lt__: _ComparisonOpLT[_NumberLike_co, _ArrayLikeNumber_co]
2807+
__le__: _ComparisonOpLE[_NumberLike_co, _ArrayLikeNumber_co]
2808+
__gt__: _ComparisonOpGT[_NumberLike_co, _ArrayLikeNumber_co]
2809+
__ge__: _ComparisonOpGE[_NumberLike_co, _ArrayLikeNumber_co]
28072810

28082811
class bool(generic):
28092812
def __init__(self, value: object = ..., /) -> None: ...
@@ -2846,10 +2849,10 @@ class bool(generic):
28462849
__rmod__: _BoolMod
28472850
__divmod__: _BoolDivMod
28482851
__rdivmod__: _BoolDivMod
2849-
__lt__: _ComparisonOp[_NumberLike_co, _ArrayLikeNumber_co]
2850-
__le__: _ComparisonOp[_NumberLike_co, _ArrayLikeNumber_co]
2851-
__gt__: _ComparisonOp[_NumberLike_co, _ArrayLikeNumber_co]
2852-
__ge__: _ComparisonOp[_NumberLike_co, _ArrayLikeNumber_co]
2852+
__lt__: _ComparisonOpLT[_NumberLike_co, _ArrayLikeNumber_co]
2853+
__le__: _ComparisonOpLE[_NumberLike_co, _ArrayLikeNumber_co]
2854+
__gt__: _ComparisonOpGT[_NumberLike_co, _ArrayLikeNumber_co]
2855+
__ge__: _ComparisonOpGE[_NumberLike_co, _ArrayLikeNumber_co]
28532856

28542857
bool_: TypeAlias = bool
28552858

@@ -2902,10 +2905,10 @@ class datetime64(generic):
29022905
@overload
29032906
def __sub__(self, other: _TD64Like_co, /) -> datetime64: ...
29042907
def __rsub__(self, other: datetime64, /) -> timedelta64: ...
2905-
__lt__: _ComparisonOp[datetime64, _ArrayLikeDT64_co]
2906-
__le__: _ComparisonOp[datetime64, _ArrayLikeDT64_co]
2907-
__gt__: _ComparisonOp[datetime64, _ArrayLikeDT64_co]
2908-
__ge__: _ComparisonOp[datetime64, _ArrayLikeDT64_co]
2908+
__lt__: _ComparisonOpLT[datetime64, _ArrayLikeDT64_co]
2909+
__le__: _ComparisonOpLE[datetime64, _ArrayLikeDT64_co]
2910+
__gt__: _ComparisonOpGT[datetime64, _ArrayLikeDT64_co]
2911+
__ge__: _ComparisonOpGE[datetime64, _ArrayLikeDT64_co]
29092912

29102913
_IntValue: TypeAlias = SupportsInt | _CharLike_co | SupportsIndex
29112914
_FloatValue: TypeAlias = None | _CharLike_co | SupportsFloat | SupportsIndex
@@ -3030,10 +3033,10 @@ class timedelta64(generic):
30303033
def __rmod__(self, other: timedelta64, /) -> timedelta64: ...
30313034
def __divmod__(self, other: timedelta64, /) -> tuple[int64, timedelta64]: ...
30323035
def __rdivmod__(self, other: timedelta64, /) -> tuple[int64, timedelta64]: ...
3033-
__lt__: _ComparisonOp[_TD64Like_co, _ArrayLikeTD64_co]
3034-
__le__: _ComparisonOp[_TD64Like_co, _ArrayLikeTD64_co]
3035-
__gt__: _ComparisonOp[_TD64Like_co, _ArrayLikeTD64_co]
3036-
__ge__: _ComparisonOp[_TD64Like_co, _ArrayLikeTD64_co]
3036+
__lt__: _ComparisonOpLT[_TD64Like_co, _ArrayLikeTD64_co]
3037+
__le__: _ComparisonOpLE[_TD64Like_co, _ArrayLikeTD64_co]
3038+
__gt__: _ComparisonOpGT[_TD64Like_co, _ArrayLikeTD64_co]
3039+
__ge__: _ComparisonOpGE[_TD64Like_co, _ArrayLikeTD64_co]
30373040

30383041
class unsignedinteger(integer[_NBit1]):
30393042
# NOTE: `uint64 + signedinteger -> float64`

‎numpy/_typing/_callable.pyi

Copy file name to clipboardExpand all lines: numpy/_typing/_callable.pyi
+54-9Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ See the `Mypy documentation`_ on protocols for more details.
1111
from __future__ import annotations
1212

1313
from typing import (
14+
TypeAlias,
1415
TypeVar,
16+
final,
1517
overload,
1618
Any,
1719
NoReturn,
@@ -48,7 +50,8 @@ _T1 = TypeVar("_T1")
4850
_T2 = TypeVar("_T2")
4951
_T1_contra = TypeVar("_T1_contra", contravariant=True)
5052
_T2_contra = TypeVar("_T2_contra", contravariant=True)
51-
_2Tuple = tuple[_T1, _T1]
53+
54+
_2Tuple: TypeAlias = tuple[_T1, _T1]
5255

5356
_NBit1 = TypeVar("_NBit1", bound=NBitBase)
5457
_NBit2 = TypeVar("_NBit2", bound=NBitBase)
@@ -317,20 +320,62 @@ class _ComplexOp(Protocol[_NBit1]):
317320
class _NumberOp(Protocol):
318321
def __call__(self, other: _NumberLike_co, /) -> Any: ...
319322

323+
@final
320324
class _SupportsLT(Protocol):
321-
def __lt__(self, other: Any, /) -> object: ...
325+
def __lt__(self, other: Any, /) -> Any: ...
326+
327+
@final
328+
class _SupportsLE(Protocol):
329+
def __le__(self, other: Any, /) -> Any: ...
322330

331+
@final
323332
class _SupportsGT(Protocol):
324-
def __gt__(self, other: Any, /) -> object: ...
333+
def __gt__(self, other: Any, /) -> Any: ...
325334

326-
class _ComparisonOp(Protocol[_T1_contra, _T2_contra]):
335+
@final
336+
class _SupportsGE(Protocol):
337+
def __ge__(self, other: Any, /) -> Any: ...
338+
339+
@final
340+
class _ComparisonOpLT(Protocol[_T1_contra, _T2_contra]):
327341
@overload
328342
def __call__(self, other: _T1_contra, /) -> np.bool: ...
329343
@overload
330344
def __call__(self, other: _T2_contra, /) -> NDArray[np.bool]: ...
331345
@overload
332-
def __call__(
333-
self,
334-
other: _SupportsLT | _SupportsGT | _NestedSequence[_SupportsLT | _SupportsGT],
335-
/,
336-
) -> Any: ...
346+
def __call__(self, other: _NestedSequence[_SupportsGT], /) -> NDArray[np.bool]: ...
347+
@overload
348+
def __call__(self, other: _SupportsGT, /) -> np.bool: ...
349+
350+
@final
351+
class _ComparisonOpLE(Protocol[_T1_contra, _T2_contra]):
352+
@overload
353+
def __call__(self, other: _T1_contra, /) -> np.bool: ...
354+
@overload
355+
def __call__(self, other: _T2_contra, /) -> NDArray[np.bool]: ...
356+
@overload
357+
def __call__(self, other: _NestedSequence[_SupportsGE], /) -> NDArray[np.bool]: ...
358+
@overload
359+
def __call__(self, other: _SupportsGE, /) -> np.bool: ...
360+
361+
@final
362+
class _ComparisonOpGT(Protocol[_T1_contra, _T2_contra]):
363+
@overload
364+
def __call__(self, other: _T1_contra, /) -> np.bool: ...
365+
@overload
366+
def __call__(self, other: _T2_contra, /) -> NDArray[np.bool]: ...
367+
@overload
368+
def __call__(self, other: _NestedSequence[_SupportsLT], /) -> NDArray[np.bool]: ...
369+
@overload
370+
def __call__(self, other: _SupportsLT, /) -> np.bool: ...
371+
372+
@final
373+
class _ComparisonOpGE(Protocol[_T1_contra, _T2_contra]):
374+
@overload
375+
def __call__(self, other: _T1_contra, /) -> np.bool: ...
376+
@overload
377+
def __call__(self, other: _T2_contra, /) -> NDArray[np.bool]: ...
378+
@overload
379+
def __call__(self, other: _NestedSequence[_SupportsGT], /) -> NDArray[np.bool]: ...
380+
@overload
381+
def __call__(self, other: _SupportsGT, /) -> np.bool: ...

‎numpy/typing/tests/data/reveal/comparisons.pyi

Copy file name to clipboardExpand all lines: numpy/typing/tests/data/reveal/comparisons.pyi
+4-4Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ SEQ = (0, 1, 2, 3, 4)
3838

3939
# object-like comparisons
4040

41-
assert_type(i8 > fractions.Fraction(1, 5), Any)
42-
assert_type(i8 > [fractions.Fraction(1, 5)], Any)
43-
assert_type(i8 > decimal.Decimal("1.5"), Any)
44-
assert_type(i8 > [decimal.Decimal("1.5")], Any)
41+
assert_type(i8 > fractions.Fraction(1, 5), np.bool)
42+
assert_type(i8 > [fractions.Fraction(1, 5)], npt.NDArray[np.bool])
43+
assert_type(i8 > decimal.Decimal("1.5"), np.bool)
44+
assert_type(i8 > [decimal.Decimal("1.5")], npt.NDArray[np.bool])
4545

4646
# Time structures
4747

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.