Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit c8f233c

Browse filesBrowse files
gh-132805: annotationlib: Fix handling of non-constant values in FORWARDREF (#132812)
Co-authored-by: David C Ellis <ducksual@gmail.com>
1 parent 7cb86c5 commit c8f233c
Copy full SHA for c8f233c

File tree

3 files changed

+250
-43
lines changed
Filter options

3 files changed

+250
-43
lines changed

‎Lib/annotationlib.py

Copy file name to clipboardExpand all lines: Lib/annotationlib.py
+132-43Lines changed: 132 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class Format(enum.IntEnum):
3838
"__weakref__",
3939
"__arg__",
4040
"__globals__",
41+
"__extra_names__",
4142
"__code__",
4243
"__ast_node__",
4344
"__cell__",
@@ -82,6 +83,7 @@ def __init__(
8283
# is created through __class__ assignment on a _Stringifier object.
8384
self.__globals__ = None
8485
self.__cell__ = None
86+
self.__extra_names__ = None
8587
# These are initially None but serve as a cache and may be set to a non-None
8688
# value later.
8789
self.__code__ = None
@@ -151,6 +153,8 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
151153
if not self.__forward_is_class__ or param_name not in globals:
152154
globals[param_name] = param
153155
locals.pop(param_name, None)
156+
if self.__extra_names__:
157+
locals = {**locals, **self.__extra_names__}
154158

155159
arg = self.__forward_arg__
156160
if arg.isidentifier() and not keyword.iskeyword(arg):
@@ -231,6 +235,10 @@ def __eq__(self, other):
231235
and self.__forward_is_class__ == other.__forward_is_class__
232236
and self.__cell__ == other.__cell__
233237
and self.__owner__ == other.__owner__
238+
and (
239+
(tuple(sorted(self.__extra_names__.items())) if self.__extra_names__ else None) ==
240+
(tuple(sorted(other.__extra_names__.items())) if other.__extra_names__ else None)
241+
)
234242
)
235243

236244
def __hash__(self):
@@ -241,6 +249,7 @@ def __hash__(self):
241249
self.__forward_is_class__,
242250
self.__cell__,
243251
self.__owner__,
252+
tuple(sorted(self.__extra_names__.items())) if self.__extra_names__ else None,
244253
))
245254

246255
def __or__(self, other):
@@ -274,6 +283,7 @@ def __init__(
274283
cell=None,
275284
*,
276285
stringifier_dict,
286+
extra_names=None,
277287
):
278288
# Either an AST node or a simple str (for the common case where a ForwardRef
279289
# represent a single name).
@@ -285,49 +295,91 @@ def __init__(
285295
self.__code__ = None
286296
self.__ast_node__ = node
287297
self.__globals__ = globals
298+
self.__extra_names__ = extra_names
288299
self.__cell__ = cell
289300
self.__owner__ = owner
290301
self.__stringifier_dict__ = stringifier_dict
291302

292303
def __convert_to_ast(self, other):
293304
if isinstance(other, _Stringifier):
294305
if isinstance(other.__ast_node__, str):
295-
return ast.Name(id=other.__ast_node__)
296-
return other.__ast_node__
297-
elif isinstance(other, slice):
306+
return ast.Name(id=other.__ast_node__), other.__extra_names__
307+
return other.__ast_node__, other.__extra_names__
308+
elif (
309+
# In STRING format we don't bother with the create_unique_name() dance;
310+
# it's better to emit the repr() of the object instead of an opaque name.
311+
self.__stringifier_dict__.format == Format.STRING
312+
or other is None
313+
or type(other) in (str, int, float, bool, complex)
314+
):
315+
return ast.Constant(value=other), None
316+
elif type(other) is dict:
317+
extra_names = {}
318+
keys = []
319+
values = []
320+
for key, value in other.items():
321+
new_key, new_extra_names = self.__convert_to_ast(key)
322+
if new_extra_names is not None:
323+
extra_names.update(new_extra_names)
324+
keys.append(new_key)
325+
new_value, new_extra_names = self.__convert_to_ast(value)
326+
if new_extra_names is not None:
327+
extra_names.update(new_extra_names)
328+
values.append(new_value)
329+
return ast.Dict(keys, values), extra_names
330+
elif type(other) in (list, tuple, set):
331+
extra_names = {}
332+
elts = []
333+
for elt in other:
334+
new_elt, new_extra_names = self.__convert_to_ast(elt)
335+
if new_extra_names is not None:
336+
extra_names.update(new_extra_names)
337+
elts.append(new_elt)
338+
ast_class = {list: ast.List, tuple: ast.Tuple, set: ast.Set}[type(other)]
339+
return ast_class(elts), extra_names
340+
else:
341+
name = self.__stringifier_dict__.create_unique_name()
342+
return ast.Name(id=name), {name: other}
343+
344+
def __convert_to_ast_getitem(self, other):
345+
if isinstance(other, slice):
346+
extra_names = {}
347+
348+
def conv(obj):
349+
if obj is None:
350+
return None
351+
new_obj, new_extra_names = self.__convert_to_ast(obj)
352+
if new_extra_names is not None:
353+
extra_names.update(new_extra_names)
354+
return new_obj
355+
298356
return ast.Slice(
299-
lower=(
300-
self.__convert_to_ast(other.start)
301-
if other.start is not None
302-
else None
303-
),
304-
upper=(
305-
self.__convert_to_ast(other.stop)
306-
if other.stop is not None
307-
else None
308-
),
309-
step=(
310-
self.__convert_to_ast(other.step)
311-
if other.step is not None
312-
else None
313-
),
314-
)
357+
lower=conv(other.start),
358+
upper=conv(other.stop),
359+
step=conv(other.step),
360+
), extra_names
315361
else:
316-
return ast.Constant(value=other)
362+
return self.__convert_to_ast(other)
317363

318364
def __get_ast(self):
319365
node = self.__ast_node__
320366
if isinstance(node, str):
321367
return ast.Name(id=node)
322368
return node
323369

324-
def __make_new(self, node):
370+
def __make_new(self, node, extra_names=None):
371+
new_extra_names = {}
372+
if self.__extra_names__ is not None:
373+
new_extra_names.update(self.__extra_names__)
374+
if extra_names is not None:
375+
new_extra_names.update(extra_names)
325376
stringifier = _Stringifier(
326377
node,
327378
self.__globals__,
328379
self.__owner__,
329380
self.__forward_is_class__,
330381
stringifier_dict=self.__stringifier_dict__,
382+
extra_names=new_extra_names or None,
331383
)
332384
self.__stringifier_dict__.stringifiers.append(stringifier)
333385
return stringifier
@@ -343,27 +395,37 @@ def __getitem__(self, other):
343395
if self.__ast_node__ == "__classdict__":
344396
raise KeyError
345397
if isinstance(other, tuple):
346-
elts = [self.__convert_to_ast(elt) for elt in other]
398+
extra_names = {}
399+
elts = []
400+
for elt in other:
401+
new_elt, new_extra_names = self.__convert_to_ast_getitem(elt)
402+
if new_extra_names is not None:
403+
extra_names.update(new_extra_names)
404+
elts.append(new_elt)
347405
other = ast.Tuple(elts)
348406
else:
349-
other = self.__convert_to_ast(other)
407+
other, extra_names = self.__convert_to_ast_getitem(other)
350408
assert isinstance(other, ast.AST), repr(other)
351-
return self.__make_new(ast.Subscript(self.__get_ast(), other))
409+
return self.__make_new(ast.Subscript(self.__get_ast(), other), extra_names)
352410

353411
def __getattr__(self, attr):
354412
return self.__make_new(ast.Attribute(self.__get_ast(), attr))
355413

356414
def __call__(self, *args, **kwargs):
357-
return self.__make_new(
358-
ast.Call(
359-
self.__get_ast(),
360-
[self.__convert_to_ast(arg) for arg in args],
361-
[
362-
ast.keyword(key, self.__convert_to_ast(value))
363-
for key, value in kwargs.items()
364-
],
365-
)
366-
)
415+
extra_names = {}
416+
ast_args = []
417+
for arg in args:
418+
new_arg, new_extra_names = self.__convert_to_ast(arg)
419+
if new_extra_names is not None:
420+
extra_names.update(new_extra_names)
421+
ast_args.append(new_arg)
422+
ast_kwargs = []
423+
for key, value in kwargs.items():
424+
new_value, new_extra_names = self.__convert_to_ast(value)
425+
if new_extra_names is not None:
426+
extra_names.update(new_extra_names)
427+
ast_kwargs.append(ast.keyword(key, new_value))
428+
return self.__make_new(ast.Call(self.__get_ast(), ast_args, ast_kwargs), extra_names)
367429

368430
def __iter__(self):
369431
yield self.__make_new(ast.Starred(self.__get_ast()))
@@ -378,8 +440,9 @@ def __format__(self, format_spec):
378440

379441
def _make_binop(op: ast.AST):
380442
def binop(self, other):
443+
rhs, extra_names = self.__convert_to_ast(other)
381444
return self.__make_new(
382-
ast.BinOp(self.__get_ast(), op, self.__convert_to_ast(other))
445+
ast.BinOp(self.__get_ast(), op, rhs), extra_names
383446
)
384447

385448
return binop
@@ -402,8 +465,9 @@ def binop(self, other):
402465

403466
def _make_rbinop(op: ast.AST):
404467
def rbinop(self, other):
468+
new_other, extra_names = self.__convert_to_ast(other)
405469
return self.__make_new(
406-
ast.BinOp(self.__convert_to_ast(other), op, self.__get_ast())
470+
ast.BinOp(new_other, op, self.__get_ast()), extra_names
407471
)
408472

409473
return rbinop
@@ -426,12 +490,14 @@ def rbinop(self, other):
426490

427491
def _make_compare(op):
428492
def compare(self, other):
493+
rhs, extra_names = self.__convert_to_ast(other)
429494
return self.__make_new(
430495
ast.Compare(
431496
left=self.__get_ast(),
432497
ops=[op],
433-
comparators=[self.__convert_to_ast(other)],
434-
)
498+
comparators=[rhs],
499+
),
500+
extra_names,
435501
)
436502

437503
return compare
@@ -459,13 +525,15 @@ def unary_op(self):
459525

460526

461527
class _StringifierDict(dict):
462-
def __init__(self, namespace, globals=None, owner=None, is_class=False):
528+
def __init__(self, namespace, *, globals=None, owner=None, is_class=False, format):
463529
super().__init__(namespace)
464530
self.namespace = namespace
465531
self.globals = globals
466532
self.owner = owner
467533
self.is_class = is_class
468534
self.stringifiers = []
535+
self.next_id = 1
536+
self.format = format
469537

470538
def __missing__(self, key):
471539
fwdref = _Stringifier(
@@ -478,6 +546,11 @@ def __missing__(self, key):
478546
self.stringifiers.append(fwdref)
479547
return fwdref
480548

549+
def create_unique_name(self):
550+
name = f"__annotationlib_name_{self.next_id}__"
551+
self.next_id += 1
552+
return name
553+
481554

482555
def call_evaluate_function(evaluate, format, *, owner=None):
483556
"""Call an evaluate function. Evaluate functions are normally generated for
@@ -521,7 +594,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
521594
# possibly constants if the annotate function uses them directly). We then
522595
# convert each of those into a string to get an approximation of the
523596
# original source.
524-
globals = _StringifierDict({})
597+
globals = _StringifierDict({}, format=format)
525598
if annotate.__closure__:
526599
freevars = annotate.__code__.co_freevars
527600
new_closure = []
@@ -544,9 +617,9 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
544617
)
545618
annos = func(Format.VALUE_WITH_FAKE_GLOBALS)
546619
if _is_evaluate:
547-
return annos if isinstance(annos, str) else repr(annos)
620+
return _stringify_single(annos)
548621
return {
549-
key: val if isinstance(val, str) else repr(val)
622+
key: _stringify_single(val)
550623
for key, val in annos.items()
551624
}
552625
elif format == Format.FORWARDREF:
@@ -569,7 +642,13 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
569642
# that returns a bool and an defined set of attributes.
570643
namespace = {**annotate.__builtins__, **annotate.__globals__}
571644
is_class = isinstance(owner, type)
572-
globals = _StringifierDict(namespace, annotate.__globals__, owner, is_class)
645+
globals = _StringifierDict(
646+
namespace,
647+
globals=annotate.__globals__,
648+
owner=owner,
649+
is_class=is_class,
650+
format=format,
651+
)
573652
if annotate.__closure__:
574653
freevars = annotate.__code__.co_freevars
575654
new_closure = []
@@ -619,6 +698,16 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
619698
raise ValueError(f"Invalid format: {format!r}")
620699

621700

701+
def _stringify_single(anno):
702+
if anno is ...:
703+
return "..."
704+
# We have to handle str specially to support PEP 563 stringified annotations.
705+
elif isinstance(anno, str):
706+
return anno
707+
else:
708+
return repr(anno)
709+
710+
622711
def get_annotate_from_class_namespace(obj):
623712
"""Retrieve the annotate function from a class namespace dictionary.
624713

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.