From 185df3c3258a9447b9614a466cceb2b6c3590ac3 Mon Sep 17 00:00:00 2001 From: isidentical Date: Tue, 24 Dec 2019 19:49:01 +0300 Subject: [PATCH 1/8] bpo-38870: Implement a basic precedence system --- Lib/ast.py | 139 ++++++++++++++++++++++++++++++++++----- Lib/test/test_unparse.py | 31 +++++++++ 2 files changed, 155 insertions(+), 15 deletions(-) diff --git a/Lib/ast.py b/Lib/ast.py index 62f6e075a09fdf2..3abcbecd0258ab9 100644 --- a/Lib/ast.py +++ b/Lib/ast.py @@ -27,6 +27,7 @@ import sys from _ast import * from contextlib import contextmanager, nullcontext +from enum import IntEnum, auto def parse(source, filename='', mode='exec', *, @@ -557,6 +558,36 @@ def __new__(cls, *args, **kwargs): # We unparse those infinities to INFSTR. _INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1) +class _Precedence(IntEnum): + """Precedence table that originated from python grammar.""" + + TUPLE = auto() + YIELD = auto() # 'yield', 'yield from' + TEST = auto() # 'if'-'else', 'lambda' + OR = auto() # 'or' + AND = auto() # 'and' + NOT = auto() # 'not' + CMP = auto() # '<', '>', '==', '>=', '<=', '!=', + # 'in', 'not in', 'is', 'is not' + EXPR = auto() + BOR = EXPR # '|' + BXOR = auto() # '^' + BAND = auto() # '&' + SHIFT = auto() # '<<', '>>' + ARITH = auto() # '+', '-' + TERM = auto() # '*', '@', '/', '%', '//' + FACTOR = auto() # unary '+', '-', '~' + POWER = auto() # '**' + AWAIT = auto() # 'await' + ATOM = auto() + + def next(self, by=1): + values = tuple(self.__class__.__members__.values()) + if values[-1] == self: + return self # max precedence + else: + return self.__class__(self + by) + class _Unparser(NodeVisitor): """Methods in this class recursively traverse an AST and output source code for the abstract syntax; original formatting @@ -565,6 +596,7 @@ class _Unparser(NodeVisitor): def __init__(self): self._source = [] self._buffer = [] + self._precedences = {} self._indent = 0 def interleave(self, inter, f, seq): @@ -622,6 +654,17 @@ def delimit_if(self, start, end, 
condition): else: return nullcontext() + def require_parens(self, precedence, node): + """Shortcut to adding precedence related parens""" + return self.delimit_if("(", ")", self.get_precedence(node) > precedence) + + def get_precedence(self, node): + return self._precedences.get(node, _Precedence.TEST) + + def set_precedence(self, precedence, *nodes): + for node in nodes: + self._precedences[node] = precedence + def traverse(self, node): if isinstance(node, list): for item in node: @@ -642,10 +685,12 @@ def visit_Module(self, node): def visit_Expr(self, node): self.fill() + self.set_precedence(_Precedence.TEST, node.value) self.traverse(node.value) def visit_NamedExpr(self, node): - with self.delimit("(", ")"): + with self.require_parens(_Precedence.TUPLE, node): + self.set_precedence(_Precedence.ATOM, node.target, node.value) self.traverse(node.target) self.write(" := ") self.traverse(node.value) @@ -720,24 +765,27 @@ def visit_Nonlocal(self, node): self.interleave(lambda: self.write(", "), self.write, node.names) def visit_Await(self, node): - with self.delimit("(", ")"): + with self.require_parens(_Precedence.AWAIT, node): self.write("await") if node.value: self.write(" ") + self.set_precedence(_Precedence.ATOM, node.value) self.traverse(node.value) def visit_Yield(self, node): - with self.delimit("(", ")"): + with self.require_parens(_Precedence.YIELD, node): self.write("yield") if node.value: self.write(" ") + self.set_precedence(_Precedence.ATOM, node.value) self.traverse(node.value) def visit_YieldFrom(self, node): - with self.delimit("(", ")"): + with self.require_parens(_Precedence.YIELD, node): self.write("yield from") if node.value: self.write(" ") + self.set_precedence(_Precedence.ATOM, node.value) self.traverse(node.value) def visit_Raise(self, node): @@ -904,7 +952,9 @@ def _fstring_Constant(self, node, write): def _fstring_FormattedValue(self, node, write): write("{") - expr = type(self)().visit(node.value).rstrip("\n") + unparser = type(self)() + 
unparser.set_precedence(_Precedence.TEST.next(), node.value) + expr = unparser.visit(node.value).rstrip("\n") if expr.startswith("{"): write(" ") # Separate pair of opening brackets as "{ {" write(expr) @@ -980,19 +1030,23 @@ def visit_comprehension(self, node): self.write(" async for ") else: self.write(" for ") + self.set_precedence(_Precedence.TUPLE, node.target) self.traverse(node.target) self.write(" in ") + self.set_precedence(_Precedence.TEST.next(), node.iter, *node.ifs) self.traverse(node.iter) for if_clause in node.ifs: self.write(" if ") self.traverse(if_clause) def visit_IfExp(self, node): - with self.delimit("(", ")"): + with self.require_parens(_Precedence.TEST, node): + self.set_precedence(_Precedence.TEST.next(), node.body, node.test) self.traverse(node.body) self.write(" if ") self.traverse(node.test) self.write(" else ") + self.set_precedence(_Precedence.TEST, node.orelse) self.traverse(node.orelse) def visit_Set(self, node): @@ -1013,6 +1067,7 @@ def write_item(item): # for dictionary unpacking operator in dicts {**{'y': 2}} # see PEP 448 for details self.write("**") + self.set_precedence(_Precedence.EXPR, v) self.traverse(v) else: write_key_value_pair(k, v) @@ -1032,11 +1087,20 @@ def visit_Tuple(self, node): self.interleave(lambda: self.write(", "), self.traverse, node.elts) unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"} + unop_precedence = { + "~": _Precedence.FACTOR, + "not": _Precedence.NOT, + "+": _Precedence.FACTOR, + "-": _Precedence.FACTOR + } def visit_UnaryOp(self, node): - with self.delimit("(", ")"): - self.write(self.unop[node.op.__class__.__name__]) + operator = self.unop[node.op.__class__.__name__] + operator_precedence = self.unop_precedence[operator] + with self.require_parens(operator_precedence, node): + self.write(operator) self.write(" ") + self.set_precedence(operator_precedence, node.operand) self.traverse(node.operand) binop = { @@ -1055,10 +1119,38 @@ def visit_UnaryOp(self, node): "Pow": "**", } + 
binop_precedence = { + "+": _Precedence.ARITH, + "-": _Precedence.ARITH, + "*": _Precedence.TERM, + "@": _Precedence.TERM, + "/": _Precedence.TERM, + "%": _Precedence.TERM, + "<<": _Precedence.SHIFT, + ">>": _Precedence.SHIFT, + "|": _Precedence.BOR, + "^": _Precedence.BXOR, + "&": _Precedence.BAND, + "//": _Precedence.TERM, + "**": _Precedence.POWER, + } + + binop_rassoc = frozenset(("**",)) def visit_BinOp(self, node): - with self.delimit("(", ")"): + operator = self.binop[node.op.__class__.__name__] + operator_precedence = self.binop_precedence[operator] + with self.require_parens(operator_precedence, node): + if operator in self.binop_rassoc: + left_precedence = operator_precedence.next() + right_precedence = operator_precedence + else: + left_precedence = operator_precedence + right_precedence = operator_precedence.next() + + self.set_precedence(left_precedence, node.left) self.traverse(node.left) - self.write(" " + self.binop[node.op.__class__.__name__] + " ") + self.write(f" {operator} ") + self.set_precedence(right_precedence, node.right) self.traverse(node.right) cmpops = { @@ -1075,20 +1167,32 @@ def visit_BinOp(self, node): } def visit_Compare(self, node): - with self.delimit("(", ")"): + with self.require_parens(_Precedence.CMP, node): + self.set_precedence(_Precedence.CMP.next(), node.left, *node.comparators) self.traverse(node.left) for o, e in zip(node.ops, node.comparators): self.write(" " + self.cmpops[o.__class__.__name__] + " ") self.traverse(e) boolops = {"And": "and", "Or": "or"} + boolop_precedence = {"and": _Precedence.AND, "or": _Precedence.OR} def visit_BoolOp(self, node): - with self.delimit("(", ")"): - s = " %s " % self.boolops[node.op.__class__.__name__] - self.interleave(lambda: self.write(s), self.traverse, node.values) + operator = self.boolops[node.op.__class__.__name__] + operator_precedence = self.boolop_precedence[operator] + + def increasing_level_traverse(node): + nonlocal operator_precedence + operator_precedence = 
operator_precedence.next() + self.set_precedence(operator_precedence, node) + self.traverse(node) + + with self.require_parens(operator_precedence, node): + s = f" {operator} " + self.interleave(lambda: self.write(s), increasing_level_traverse, node.values) def visit_Attribute(self, node): + self.set_precedence(_Precedence.ATOM, node.value) self.traverse(node.value) # Special case: 3.__abs__() is a syntax error, so if node.value # is an integer literal then we need to either parenthesize @@ -1099,6 +1203,7 @@ def visit_Attribute(self, node): self.write(node.attr) def visit_Call(self, node): + self.set_precedence(_Precedence.ATOM, node.func) self.traverse(node.func) with self.delimit("(", ")"): comma = False @@ -1116,18 +1221,21 @@ def visit_Call(self, node): self.traverse(e) def visit_Subscript(self, node): + self.set_precedence(_Precedence.ATOM, node.value) self.traverse(node.value) with self.delimit("[", "]"): self.traverse(node.slice) def visit_Starred(self, node): self.write("*") + self.set_precedence(_Precedence.EXPR, node.value) self.traverse(node.value) def visit_Ellipsis(self, node): self.write("...") def visit_Index(self, node): + self.set_precedence(_Precedence.TUPLE, node.value) self.traverse(node.value) def visit_Slice(self, node): @@ -1209,10 +1317,11 @@ def visit_keyword(self, node): self.traverse(node.value) def visit_Lambda(self, node): - with self.delimit("(", ")"): + with self.require_parens(_Precedence.TEST, node): self.write("lambda ") self.traverse(node.args) self.write(": ") + self.set_precedence(_Precedence.TEST, node.body) self.traverse(node.body) def visit_alias(self, node): diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py index 45d819f175bb933..7e142dce699b8ed 100644 --- a/Lib/test/test_unparse.py +++ b/Lib/test/test_unparse.py @@ -125,6 +125,13 @@ def check_roundtrip(self, code1): def check_invalid(self, node, raises=ValueError): self.assertRaises(raises, ast.unparse, node) + def check_src_roundtrip(self, code1, 
code2=None, strip=True): + code2 = code2 or code1 + code1 = ast.unparse(ast.parse(code1)) + if strip: + code1 = code1.strip() + self.assertEqual(code2, code1) + class UnparseTestCase(ASTTestCase): # Tests for specific bugs found in earlier versions of unparse @@ -279,6 +286,30 @@ def test_invalid_set(self): self.check_invalid(ast.Set(elts=[])) +class CosmeticTestCase(ASTTestCase): + """Test if there are cosmetic issues caused by unnecesary additions""" + + def test_simple_expressions_parens(self): + self.check_src_roundtrip("(a := b)") + self.check_src_roundtrip("await x") + self.check_src_roundtrip("x if x else y") + self.check_src_roundtrip("lambda x: x") + self.check_src_roundtrip("1 + 1") + self.check_src_roundtrip("~ x") + self.check_src_roundtrip("x and y") + self.check_src_roundtrip("x and y and z") + self.check_src_roundtrip("x and (y and x)") + self.check_src_roundtrip("(x and y) and z") + self.check_src_roundtrip("(x ** y) ** z ** q") + self.check_src_roundtrip("x >> y") + self.check_src_roundtrip("x << y") + self.check_src_roundtrip("x >> y and x >> z") + self.check_src_roundtrip("x + y - z * q ^ t ** k") + self.check_src_roundtrip("P * V if P and V else n * R * T") + self.check_src_roundtrip("lambda P, V, n: P * V == n * R * T") + self.check_src_roundtrip("flag & (other | foo)") + + class DirectoryTestCase(ASTTestCase): """Test roundtrip behaviour on all files in Lib and Lib/test.""" From c98db1f7e4c66c2e01244961507ff1689cb9641b Mon Sep 17 00:00:00 2001 From: isidentical Date: Tue, 24 Dec 2019 20:00:17 +0300 Subject: [PATCH 2/8] ensure testing item is in ast module --- Lib/test/test_ast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index 51a7c1af1ffe702..1552d9893308ea5 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -331,7 +331,7 @@ def test_base_classes(self): def test_field_attr_existence(self): for name, item in ast.__dict__.items(): - if isinstance(item, type) and 
name != 'AST' and name[0].isupper(): + if isinstance(item, type) and name != 'AST' and "ast" in item.__module__ and name[0].isupper(): x = item() if isinstance(x, ast.AST): self.assertEqual(type(x._fields), tuple) From 73050b0363e2a8476e3d71b7c8438304b45db5fd Mon Sep 17 00:00:00 2001 From: isidentical Date: Sun, 29 Dec 2019 22:58:50 +0300 Subject: [PATCH 3/8] refactor field attr existence test, add more tests about arithmetic, align comment --- Lib/ast.py | 2 +- Lib/test/test_ast.py | 8 +++++++- Lib/test/test_unparse.py | 4 ++++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/Lib/ast.py b/Lib/ast.py index 3abcbecd0258ab9..91bde956dbd747b 100644 --- a/Lib/ast.py +++ b/Lib/ast.py @@ -568,7 +568,7 @@ class _Precedence(IntEnum): AND = auto() # 'and' NOT = auto() # 'not' CMP = auto() # '<', '>', '==', '>=', '<=', '!=', - # 'in', 'not in', 'is', 'is not' + # 'in', 'not in', 'is', 'is not' EXPR = auto() BOR = EXPR # '|' BXOR = auto() # '^' diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index 1552d9893308ea5..1e206e2fc324b92 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -243,6 +243,12 @@ def to_tuple(t): class AST_Tests(unittest.TestCase): + def _is_ast_node(self, name, node): + if not (name[0].isupper() and isinstance(node, type)): + return False + if node is not ast.AST and issubclass(node, ast.AST): + return True + def _assertTrueorder(self, ast_node, parent_pos): if not isinstance(ast_node, ast.AST) or ast_node._fields is None: return @@ -331,7 +337,7 @@ def test_base_classes(self): def test_field_attr_existence(self): for name, item in ast.__dict__.items(): - if isinstance(item, type) and name != 'AST' and "ast" in item.__module__ and name[0].isupper(): + if self._is_ast_node(name, item): x = item() if isinstance(x, ast.AST): self.assertEqual(type(x._fields), tuple) diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py index 7e142dce699b8ed..55ef4d9e4d0f5c1 100644 --- a/Lib/test/test_unparse.py +++ 
b/Lib/test/test_unparse.py @@ -295,6 +295,10 @@ def test_simple_expressions_parens(self): self.check_src_roundtrip("x if x else y") self.check_src_roundtrip("lambda x: x") self.check_src_roundtrip("1 + 1") + self.check_src_roundtrip("1 + 2 / 3") + self.check_src_roundtrip("(1 + 2) / 3") + self.check_src_roundtrip("(1 + 2) * 3 + 4 * (5 + 2)") + self.check_src_roundtrip("(1 + 2) * 3 + 4 * (5 + 2) ** 2") self.check_src_roundtrip("~ x") self.check_src_roundtrip("x and y") self.check_src_roundtrip("x and y and z") From 4576f36413ff5dd53b54674c68212cc3f95ff2be Mon Sep 17 00:00:00 2001 From: isidentical Date: Sun, 29 Dec 2019 23:12:13 +0300 Subject: [PATCH 4/8] try to tweak _is_ast_node --- Lib/test/test_ast.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index 1e206e2fc324b92..e131c35c7fede42 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -244,10 +244,11 @@ def to_tuple(t): class AST_Tests(unittest.TestCase): def _is_ast_node(self, name, node): - if not (name[0].isupper() and isinstance(node, type)): + if not isinstance(node, type): return False - if node is not ast.AST and issubclass(node, ast.AST): - return True + if "ast" not in node.__module__: + return False + return isinstance(node, type) and name != 'AST' and name[0].isupper() def _assertTrueorder(self, ast_node, parent_pos): if not isinstance(ast_node, ast.AST) or ast_node._fields is None: From 8a570593a9f6611bac118a0d4ee55132cc135be4 Mon Sep 17 00:00:00 2001 From: isidentical Date: Sun, 29 Dec 2019 23:28:52 +0300 Subject: [PATCH 5/8] shorten check, remove by= parameter --- Lib/ast.py | 4 ++-- Lib/test/test_ast.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Lib/ast.py b/Lib/ast.py index 91bde956dbd747b..32304edce118edc 100644 --- a/Lib/ast.py +++ b/Lib/ast.py @@ -581,12 +581,12 @@ class _Precedence(IntEnum): AWAIT = auto() # 'await' ATOM = auto() - def next(self, by=1): + def next(self): values = 
tuple(self.__class__.__members__.values()) if values[-1] == self: return self # max precedence else: - return self.__class__(self + by) + return self.__class__(self + 1) class _Unparser(NodeVisitor): """Methods in this class recursively traverse an AST and diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index e131c35c7fede42..6a3aa5ae6320628 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -248,7 +248,7 @@ def _is_ast_node(self, name, node): return False if "ast" not in node.__module__: return False - return isinstance(node, type) and name != 'AST' and name[0].isupper() + return name != 'AST' and name[0].isupper() def _assertTrueorder(self, ast_node, parent_pos): if not isinstance(ast_node, ast.AST) or ast_node._fields is None: From a9908eca54bf45c5596311e97e8ba98d98c89350 Mon Sep 17 00:00:00 2001 From: isidentical Date: Sat, 4 Jan 2020 13:49:46 +0300 Subject: [PATCH 6/8] unary op in compare --- Lib/test/test_unparse.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py index 55ef4d9e4d0f5c1..1336526a44d1c07 100644 --- a/Lib/test/test_unparse.py +++ b/Lib/test/test_unparse.py @@ -312,6 +312,8 @@ def test_simple_expressions_parens(self): self.check_src_roundtrip("P * V if P and V else n * R * T") self.check_src_roundtrip("lambda P, V, n: P * V == n * R * T") self.check_src_roundtrip("flag & (other | foo)") + self.check_src_roundtrip("not x == y") + self.check_src_roundtrip("x == (not y)") class DirectoryTestCase(ASTTestCase): From f7aca507e6488fcdda51c97d2a829c04682ddddf Mon Sep 17 00:00:00 2001 From: isidentical Date: Tue, 7 Jan 2020 15:56:59 +0300 Subject: [PATCH 7/8] simplify next() function for _Precedence --- Lib/ast.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/Lib/ast.py b/Lib/ast.py index 2c9bf5a88c9d563..d1efd0b40757b44 100644 --- a/Lib/ast.py +++ b/Lib/ast.py @@ -585,11 +585,10 @@ class _Precedence(IntEnum): ATOM = auto() def next(self): - values = 
tuple(self.__class__.__members__.values()) - if values[-1] == self: - return self # max precedence - else: + try: return self.__class__(self + 1) + except ValueError: + return self class _Unparser(NodeVisitor): """Methods in this class recursively traverse an AST and From 1b07eb0a22d50017c4a0701a22eb97b49ba8ffc0 Mon Sep 17 00:00:00 2001 From: isidentical Date: Sat, 11 Jan 2020 16:17:36 +0300 Subject: [PATCH 8/8] test & support top-level yield --- Lib/ast.py | 2 +- Lib/test/test_unparse.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/Lib/ast.py b/Lib/ast.py index d1efd0b40757b44..49d7e4a6281aac1 100644 --- a/Lib/ast.py +++ b/Lib/ast.py @@ -687,7 +687,7 @@ def visit_Module(self, node): def visit_Expr(self, node): self.fill() - self.set_precedence(_Precedence.TEST, node.value) + self.set_precedence(_Precedence.YIELD, node.value) self.traverse(node.value) def visit_NamedExpr(self, node): diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py index f48a9e859bc4ff6..f7fcb2bffe89199 100644 --- a/Lib/test/test_unparse.py +++ b/Lib/test/test_unparse.py @@ -316,6 +316,10 @@ def test_simple_expressions_parens(self): self.check_src_roundtrip("flag & (other | foo)") self.check_src_roundtrip("not x == y") self.check_src_roundtrip("x == (not y)") + self.check_src_roundtrip("yield x") + self.check_src_roundtrip("yield from x") + self.check_src_roundtrip("call((yield x))") + self.check_src_roundtrip("return x + (yield x)") class DirectoryTestCase(ASTTestCase):