Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 3957025

Browse filesBrowse files
Use checkmember.py to check protocol subtyping (#18943)
Fixes #18024 Fixes #18706 Fixes #17734 Fixes #15097 Fixes #14814 Fixes #14806 Fixes #14259 Fixes #13041 Fixes #11993 Fixes #9585 Fixes #9266 Fixes #9202 Fixes #5481 This is a fourth "major" PR toward #7724. This is one is watershed/crux of the whole series (but to set correct expectations, there are almost a dozen smaller follow-up/clean-up PRs in the pipeline). The core of the idea is to set current type-checker as part of the global state. There are however some details: * There are cases where we call `is_subtype()` before type-checking. For now, I fall back to old logic in this cases. In follow up PRs we may switch to using type-checker instances before type checking phase (this requires some care). * This increases typeops import cycle by a few modules, but unfortunately this is inevitable. * This PR increases potential for infinite recursion in protocols. To mitigate I add: one legitimate fix for `__call__`, and one temporary hack for `freshen_all_functions_type_vars` (to reduce performance impact). * Finally I change semantics for method access on class objects to match the one in old `find_member()`. Now we will expand type by instance, so we have something like this: ```python class B(Generic[T]): def foo(self, x: T) -> T: ... class C(B[str]): ... reveal_type(C.foo) # def (self: B[str], x: str) -> str ``` FWIW, I am not even 100% sure this is correct, it seems to me we _may_ keep the method generic. But in any case what we do currently is definitely wrong (we infer a _non-generic_ `def (x: T) -> T`). --------- Co-authored-by: hauntsaninja <hauntsaninja@gmail.com> Co-authored-by: Shantanu <12621235+hauntsaninja@users.noreply.github.com>
1 parent 7a32bc1 commit 3957025
Copy full SHA for 3957025

File tree

Expand file treeCollapse file tree

11 files changed

+213
-45
lines changed
Filter options
Expand file treeCollapse file tree

11 files changed

+213
-45
lines changed

‎.github/workflows/mypy_primer.yml

Copy file name to clipboardExpand all lines: .github/workflows/mypy_primer.yml
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ jobs:
6767
--debug \
6868
--additional-flags="--debug-serialize" \
6969
--output concise \
70+
--show-speed-regression \
7071
| tee diff_${{ matrix.shard-index }}.txt
7172
) || [ $? -eq 1 ]
7273
- if: ${{ matrix.shard-index == 0 }}

‎mypy/checker.py

Copy file name to clipboardExpand all lines: mypy/checker.py
+3-2Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from mypy import errorcodes as codes, join, message_registry, nodes, operators
1414
from mypy.binder import ConditionalTypeBinder, Frame, get_declaration
1515
from mypy.checker_shared import CheckerScope, TypeCheckerSharedApi, TypeRange
16+
from mypy.checker_state import checker_state
1617
from mypy.checkmember import (
1718
MemberContext,
1819
analyze_class_attribute_access,
@@ -453,7 +454,7 @@ def check_first_pass(self) -> None:
453454
Deferred functions will be processed by check_second_pass().
454455
"""
455456
self.recurse_into_functions = True
456-
with state.strict_optional_set(self.options.strict_optional):
457+
with state.strict_optional_set(self.options.strict_optional), checker_state.set(self):
457458
self.errors.set_file(
458459
self.path, self.tree.fullname, scope=self.tscope, options=self.options
459460
)
@@ -494,7 +495,7 @@ def check_second_pass(
494495
This goes through deferred nodes, returning True if there were any.
495496
"""
496497
self.recurse_into_functions = True
497-
with state.strict_optional_set(self.options.strict_optional):
498+
with state.strict_optional_set(self.options.strict_optional), checker_state.set(self):
498499
if not todo and not self.deferred_nodes:
499500
return False
500501
self.errors.set_file(

‎mypy/checker_state.py

Copy file name to clipboard
+30Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Iterator
4+
from contextlib import contextmanager
5+
from typing import Final
6+
7+
from mypy.checker_shared import TypeCheckerSharedApi
8+
9+
# This is global mutable state. Don't add anything here unless there's a very
10+
# good reason.
11+
12+
13+
class TypeCheckerState:
14+
# Wrap this in a class since it's faster that using a module-level attribute.
15+
16+
def __init__(self, type_checker: TypeCheckerSharedApi | None) -> None:
17+
# Value varies by file being processed
18+
self.type_checker = type_checker
19+
20+
@contextmanager
21+
def set(self, value: TypeCheckerSharedApi) -> Iterator[None]:
22+
saved = self.type_checker
23+
self.type_checker = value
24+
try:
25+
yield
26+
finally:
27+
self.type_checker = saved
28+
29+
30+
checker_state: Final = TypeCheckerState(type_checker=None)

‎mypy/checkmember.py

Copy file name to clipboardExpand all lines: mypy/checkmember.py
+25-31Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def __init__(
9797
is_self: bool = False,
9898
rvalue: Expression | None = None,
9999
suppress_errors: bool = False,
100+
preserve_type_var_ids: bool = False,
100101
) -> None:
101102
self.is_lvalue = is_lvalue
102103
self.is_super = is_super
@@ -113,6 +114,10 @@ def __init__(
113114
assert is_lvalue
114115
self.rvalue = rvalue
115116
self.suppress_errors = suppress_errors
117+
# This attribute is only used to preserve old protocol member access logic.
118+
# It is needed to avoid infinite recursion in cases involving self-referential
119+
# generic methods, see find_member() for details. Do not use for other purposes!
120+
self.preserve_type_var_ids = preserve_type_var_ids
116121

117122
def named_type(self, name: str) -> Instance:
118123
return self.chk.named_type(name)
@@ -143,6 +148,7 @@ def copy_modified(
143148
no_deferral=self.no_deferral,
144149
rvalue=self.rvalue,
145150
suppress_errors=self.suppress_errors,
151+
preserve_type_var_ids=self.preserve_type_var_ids,
146152
)
147153
if self_type is not None:
148154
mx.self_type = self_type
@@ -232,8 +238,6 @@ def analyze_member_access(
232238
def _analyze_member_access(
233239
name: str, typ: Type, mx: MemberContext, override_info: TypeInfo | None = None
234240
) -> Type:
235-
# TODO: This and following functions share some logic with subtypes.find_member;
236-
# consider refactoring.
237241
typ = get_proper_type(typ)
238242
if isinstance(typ, Instance):
239243
return analyze_instance_member_access(name, typ, mx, override_info)
@@ -358,7 +362,8 @@ def analyze_instance_member_access(
358362
return AnyType(TypeOfAny.special_form)
359363
assert isinstance(method.type, Overloaded)
360364
signature = method.type
361-
signature = freshen_all_functions_type_vars(signature)
365+
if not mx.preserve_type_var_ids:
366+
signature = freshen_all_functions_type_vars(signature)
362367
if not method.is_static:
363368
if isinstance(method, (FuncDef, OverloadedFuncDef)) and method.is_trivial_self:
364369
signature = bind_self_fast(signature, mx.self_type)
@@ -943,7 +948,8 @@ def analyze_var(
943948
def expand_without_binding(
944949
typ: Type, var: Var, itype: Instance, original_itype: Instance, mx: MemberContext
945950
) -> Type:
946-
typ = freshen_all_functions_type_vars(typ)
951+
if not mx.preserve_type_var_ids:
952+
typ = freshen_all_functions_type_vars(typ)
947953
typ = expand_self_type_if_needed(typ, mx, var, original_itype)
948954
expanded = expand_type_by_instance(typ, itype)
949955
freeze_all_type_vars(expanded)
@@ -958,7 +964,8 @@ def expand_and_bind_callable(
958964
mx: MemberContext,
959965
is_trivial_self: bool,
960966
) -> Type:
961-
functype = freshen_all_functions_type_vars(functype)
967+
if not mx.preserve_type_var_ids:
968+
functype = freshen_all_functions_type_vars(functype)
962969
typ = get_proper_type(expand_self_type(var, functype, mx.original_type))
963970
assert isinstance(typ, FunctionLike)
964971
if is_trivial_self:
@@ -1056,10 +1063,12 @@ def f(self: S) -> T: ...
10561063
return functype
10571064
else:
10581065
selfarg = get_proper_type(item.arg_types[0])
1059-
# This level of erasure matches the one in checker.check_func_def(),
1060-
# better keep these two checks consistent.
1061-
if subtypes.is_subtype(
1066+
# This matches similar special-casing in bind_self(), see more details there.
1067+
self_callable = name == "__call__" and isinstance(selfarg, CallableType)
1068+
if self_callable or subtypes.is_subtype(
10621069
dispatched_arg_type,
1070+
# This level of erasure matches the one in checker.check_func_def(),
1071+
# better keep these two checks consistent.
10631072
erase_typevars(erase_to_bound(selfarg)),
10641073
# This is to work around the fact that erased ParamSpec and TypeVarTuple
10651074
# callables are not always compatible with non-erased ones both ways.
@@ -1220,9 +1229,6 @@ def analyze_class_attribute_access(
12201229
is_classmethod = (is_decorated and cast(Decorator, node.node).func.is_class) or (
12211230
isinstance(node.node, SYMBOL_FUNCBASE_TYPES) and node.node.is_class
12221231
)
1223-
is_staticmethod = (is_decorated and cast(Decorator, node.node).func.is_static) or (
1224-
isinstance(node.node, SYMBOL_FUNCBASE_TYPES) and node.node.is_static
1225-
)
12261232
t = get_proper_type(t)
12271233
is_trivial_self = False
12281234
if isinstance(node.node, Decorator):
@@ -1236,8 +1242,7 @@ def analyze_class_attribute_access(
12361242
t,
12371243
isuper,
12381244
is_classmethod,
1239-
is_staticmethod,
1240-
mx.self_type,
1245+
mx,
12411246
original_vars=original_vars,
12421247
is_trivial_self=is_trivial_self,
12431248
)
@@ -1372,8 +1377,7 @@ def add_class_tvars(
13721377
t: ProperType,
13731378
isuper: Instance | None,
13741379
is_classmethod: bool,
1375-
is_staticmethod: bool,
1376-
original_type: Type,
1380+
mx: MemberContext,
13771381
original_vars: Sequence[TypeVarLikeType] | None = None,
13781382
is_trivial_self: bool = False,
13791383
) -> Type:
@@ -1392,9 +1396,6 @@ class B(A[str]): pass
13921396
isuper: Current instance mapped to the superclass where method was defined, this
13931397
is usually done by map_instance_to_supertype()
13941398
is_classmethod: True if this method is decorated with @classmethod
1395-
is_staticmethod: True if this method is decorated with @staticmethod
1396-
original_type: The value of the type B in the expression B.foo() or the corresponding
1397-
component in case of a union (this is used to bind the self-types)
13981399
original_vars: Type variables of the class callable on which the method was accessed
13991400
is_trivial_self: if True, we can use fast path for bind_self().
14001401
Returns:
@@ -1416,14 +1417,14 @@ class B(A[str]): pass
14161417
# (i.e. appear in the return type of the class object on which the method was accessed).
14171418
if isinstance(t, CallableType):
14181419
tvars = original_vars if original_vars is not None else []
1419-
t = freshen_all_functions_type_vars(t)
1420+
if not mx.preserve_type_var_ids:
1421+
t = freshen_all_functions_type_vars(t)
14201422
if is_classmethod:
14211423
if is_trivial_self:
1422-
t = bind_self_fast(t, original_type)
1424+
t = bind_self_fast(t, mx.self_type)
14231425
else:
1424-
t = bind_self(t, original_type, is_classmethod=True)
1425-
if is_classmethod or is_staticmethod:
1426-
assert isuper is not None
1426+
t = bind_self(t, mx.self_type, is_classmethod=True)
1427+
if isuper is not None:
14271428
t = expand_type_by_instance(t, isuper)
14281429
freeze_all_type_vars(t)
14291430
return t.copy_modified(variables=list(tvars) + list(t.variables))
@@ -1432,14 +1433,7 @@ class B(A[str]): pass
14321433
[
14331434
cast(
14341435
CallableType,
1435-
add_class_tvars(
1436-
item,
1437-
isuper,
1438-
is_classmethod,
1439-
is_staticmethod,
1440-
original_type,
1441-
original_vars=original_vars,
1442-
),
1436+
add_class_tvars(item, isuper, is_classmethod, mx, original_vars=original_vars),
14431437
)
14441438
for item in t.items
14451439
]

‎mypy/messages.py

Copy file name to clipboardExpand all lines: mypy/messages.py
+7-2Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2220,8 +2220,13 @@ def report_protocol_problems(
22202220
exp = get_proper_type(exp)
22212221
got = get_proper_type(got)
22222222
setter_suffix = " setter type" if is_lvalue else ""
2223-
if not isinstance(exp, (CallableType, Overloaded)) or not isinstance(
2224-
got, (CallableType, Overloaded)
2223+
if (
2224+
not isinstance(exp, (CallableType, Overloaded))
2225+
or not isinstance(got, (CallableType, Overloaded))
2226+
# If expected type is a type object, it means it is a nested class.
2227+
# Showing constructor signature in errors would be confusing in this case,
2228+
# since we don't check the signature, only subclassing of type objects.
2229+
or exp.is_type_obj()
22252230
):
22262231
self.note(
22272232
"{}: expected{} {}, got {}".format(

‎mypy/plugin.py

Copy file name to clipboardExpand all lines: mypy/plugin.py
+5-3Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,14 +119,13 @@ class C: pass
119119
from __future__ import annotations
120120

121121
from abc import abstractmethod
122-
from typing import Any, Callable, NamedTuple, TypeVar
122+
from typing import TYPE_CHECKING, Any, Callable, NamedTuple, TypeVar
123123

124124
from mypy_extensions import mypyc_attr, trait
125125

126126
from mypy.errorcodes import ErrorCode
127127
from mypy.lookup import lookup_fully_qualified
128128
from mypy.message_registry import ErrorMessage
129-
from mypy.messages import MessageBuilder
130129
from mypy.nodes import (
131130
ArgKind,
132131
CallExpr,
@@ -138,7 +137,6 @@ class C: pass
138137
TypeInfo,
139138
)
140139
from mypy.options import Options
141-
from mypy.tvar_scope import TypeVarLikeScope
142140
from mypy.types import (
143141
CallableType,
144142
FunctionLike,
@@ -149,6 +147,10 @@ class C: pass
149147
UnboundType,
150148
)
151149

150+
if TYPE_CHECKING:
151+
from mypy.messages import MessageBuilder
152+
from mypy.tvar_scope import TypeVarLikeScope
153+
152154

153155
@trait
154156
class TypeAnalyzerPluginInterface:

‎mypy/subtypes.py

Copy file name to clipboardExpand all lines: mypy/subtypes.py
+79-5Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import mypy.applytype
99
import mypy.constraints
1010
import mypy.typeops
11+
from mypy.checker_state import checker_state
1112
from mypy.erasetype import erase_type
1213
from mypy.expandtype import (
1314
expand_self_type,
@@ -26,6 +27,7 @@
2627
COVARIANT,
2728
INVARIANT,
2829
VARIANCE_NOT_READY,
30+
Context,
2931
Decorator,
3032
FuncBase,
3133
OverloadedFuncDef,
@@ -717,8 +719,7 @@ def visit_callable_type(self, left: CallableType) -> bool:
717719
elif isinstance(right, Instance):
718720
if right.type.is_protocol and "__call__" in right.type.protocol_members:
719721
# OK, a callable can implement a protocol with a `__call__` member.
720-
# TODO: we should probably explicitly exclude self-types in this case.
721-
call = find_member("__call__", right, left, is_operator=True)
722+
call = find_member("__call__", right, right, is_operator=True)
722723
assert call is not None
723724
if self._is_subtype(left, call):
724725
if len(right.type.protocol_members) == 1:
@@ -954,7 +955,7 @@ def visit_overloaded(self, left: Overloaded) -> bool:
954955
if isinstance(right, Instance):
955956
if right.type.is_protocol and "__call__" in right.type.protocol_members:
956957
# same as for CallableType
957-
call = find_member("__call__", right, left, is_operator=True)
958+
call = find_member("__call__", right, right, is_operator=True)
958959
assert call is not None
959960
if self._is_subtype(left, call):
960961
if len(right.type.protocol_members) == 1:
@@ -1266,14 +1267,87 @@ def find_member(
12661267
is_operator: bool = False,
12671268
class_obj: bool = False,
12681269
is_lvalue: bool = False,
1270+
) -> Type | None:
1271+
type_checker = checker_state.type_checker
1272+
if type_checker is None:
1273+
# Unfortunately, there are many scenarios where someone calls is_subtype() before
1274+
# type checking phase. In this case we fallback to old (incomplete) logic.
1275+
# TODO: reduce number of such cases (e.g. semanal_typeargs, post-semanal plugins).
1276+
return find_member_simple(
1277+
name, itype, subtype, is_operator=is_operator, class_obj=class_obj, is_lvalue=is_lvalue
1278+
)
1279+
1280+
# We don't use ATTR_DEFINED error code below (since missing attributes can cause various
1281+
# other error codes), instead we perform quick node lookup with all the fallbacks.
1282+
info = itype.type
1283+
sym = info.get(name)
1284+
node = sym.node if sym else None
1285+
if not node:
1286+
name_not_found = True
1287+
if (
1288+
name not in ["__getattr__", "__setattr__", "__getattribute__"]
1289+
and not is_operator
1290+
and not class_obj
1291+
and itype.extra_attrs is None # skip ModuleType.__getattr__
1292+
):
1293+
for method_name in ("__getattribute__", "__getattr__"):
1294+
method = info.get_method(method_name)
1295+
if method and method.info.fullname != "builtins.object":
1296+
name_not_found = False
1297+
break
1298+
if name_not_found:
1299+
if info.fallback_to_any or class_obj and info.meta_fallback_to_any:
1300+
return AnyType(TypeOfAny.special_form)
1301+
if itype.extra_attrs and name in itype.extra_attrs.attrs:
1302+
return itype.extra_attrs.attrs[name]
1303+
return None
1304+
1305+
from mypy.checkmember import (
1306+
MemberContext,
1307+
analyze_class_attribute_access,
1308+
analyze_instance_member_access,
1309+
)
1310+
1311+
mx = MemberContext(
1312+
is_lvalue=is_lvalue,
1313+
is_super=False,
1314+
is_operator=is_operator,
1315+
original_type=itype,
1316+
self_type=subtype,
1317+
context=Context(), # all errors are filtered, but this is a required argument
1318+
chk=type_checker,
1319+
suppress_errors=True,
1320+
# This is needed to avoid infinite recursion in situations involving protocols like
1321+
# class P(Protocol[T]):
1322+
# def combine(self, other: P[S]) -> P[Tuple[T, S]]: ...
1323+
# Normally we call freshen_all_functions_type_vars() during attribute access,
1324+
# to avoid type variable id collisions, but for protocols this means we can't
1325+
# use the assumption stack, that will grow indefinitely.
1326+
# TODO: find a cleaner solution that doesn't involve massive perf impact.
1327+
preserve_type_var_ids=True,
1328+
)
1329+
with type_checker.msg.filter_errors(filter_deprecated=True):
1330+
if class_obj:
1331+
fallback = itype.type.metaclass_type or mx.named_type("builtins.type")
1332+
return analyze_class_attribute_access(itype, name, mx, mcs_fallback=fallback)
1333+
else:
1334+
return analyze_instance_member_access(name, itype, mx, info)
1335+
1336+
1337+
def find_member_simple(
1338+
name: str,
1339+
itype: Instance,
1340+
subtype: Type,
1341+
*,
1342+
is_operator: bool = False,
1343+
class_obj: bool = False,
1344+
is_lvalue: bool = False,
12691345
) -> Type | None:
12701346
"""Find the type of member by 'name' in 'itype's TypeInfo.
12711347
12721348
Find the member type after applying type arguments from 'itype', and binding
12731349
'self' to 'subtype'. Return None if member was not found.
12741350
"""
1275-
# TODO: this code shares some logic with checkmember.analyze_member_access,
1276-
# consider refactoring.
12771351
info = itype.type
12781352
method = info.get_method(name)
12791353
if method:

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.