From d96e933b2b4ab8a1ad0114b36e58ce9bb0f034c7 Mon Sep 17 00:00:00 2001 From: isidentical Date: Thu, 2 Jan 2020 18:34:39 +0300 Subject: [PATCH 1/2] bpo-38870: Implement round tripping support for typed AST --- Lib/ast.py | 36 ++++++++++++++++++++++++++++++------ Lib/test/test_unparse.py | 38 +++++++++++++++++++++++++++++++++++--- 2 files changed, 65 insertions(+), 9 deletions(-) diff --git a/Lib/ast.py b/Lib/ast.py index 7a43581c0e6ce69..04665f5cdb41256 100644 --- a/Lib/ast.py +++ b/Lib/ast.py @@ -648,6 +648,7 @@ def __init__(self): self._source = [] self._buffer = [] self._precedences = {} + self._type_ignores = {} self._indent = 0 def interleave(self, inter, f, seq): @@ -697,11 +698,16 @@ def buffer(self): return value @contextmanager - def block(self): + def block(self, node=None, *, append_type_comment = False): """A context manager for preparing the source for blocks. It adds the character':', increases the indentation on enter and decreases - the indentation on exit.""" + the indentation on exit. If *append_type_comment* is True, and *node* + is not None, this function will add given *node*'s type comment after + the colon character. + """ self.write(":") + if append_type_comment and node: + self.append_type_comment(node) self._indent += 1 yield self._indent -= 1 @@ -748,6 +754,18 @@ def get_raw_docstring(self, node): if isinstance(node, Constant) and isinstance(node.value, str): return node + def append_type_comment(self, node): + comment = None + ignore = self._type_ignores.get(node.lineno) + + if ignore: + comment = f"ignore{ignore.tag}" + else: + comment = node.type_comment + + if comment: + self.write(f" # type: {comment}") + def traverse(self, node): if isinstance(node, list): for item in node: @@ -770,7 +788,12 @@ def _write_docstring_and_traverse_body(self, node): self.traverse(node.body) def visit_Module(self, node): + self._type_ignores = { + ignore.lineno: ignore + for ignore in node.type_ignores + } self._write_docstring_and_traverse_body(node) + self._type_ignores.clear() def visit_FunctionType(self, node): with self.delimit("(", ")"): @@ -811,6 +834,7 @@ def visit_Assign(self, node): self.traverse(target) self.write(" = ") self.traverse(node.value) + self.append_type_comment(node) def visit_AugAssign(self, node): self.fill() @@ -966,7 +990,7 @@ def _function_helper(self, node, fill_suffix): if node.returns: self.write(" -> ") self.traverse(node.returns) - with self.block(): + with self.block(node, append_type_comment=True): self._write_docstring_and_traverse_body(node) def visit_For(self, node): @@ -980,7 +1004,7 @@ def _for_helper(self, fill, node): self.traverse(node.target) self.write(" in ") self.traverse(node.iter) - with self.block(): + with self.block(node, append_type_comment=True): self.traverse(node.body) if node.orelse: self.fill("else") @@ -1018,13 +1042,13 @@ def visit_While(self, node): def visit_With(self, node): self.fill("with ") self.interleave(lambda: self.write(", "), self.traverse, node.items) - with self.block(): + with self.block(node, append_type_comment=True): self.traverse(node.body) def visit_AsyncWith(self, node): self.fill("async with ") self.interleave(lambda: self.write(", "), self.traverse, node.items) - with self.block(): + with self.block(node, append_type_comment=True): self.traverse(node.body) def visit_JoinedStr(self, node): diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py index 2be44b246aa697b..09d0395d8b05050 100644 --- a/Lib/test/test_unparse.py +++ b/Lib/test/test_unparse.py @@ -108,20 +108,22 @@ class Foo: pass suite1 """ -docstring_prefixes = [ +docstring_prefixes = ( "", "class foo():\n ", "def foo():\n ", "async def foo():\n ", -] +) class ASTTestCase(unittest.TestCase): def assertASTEqual(self, ast1, ast2): self.assertEqual(ast.dump(ast1), ast.dump(ast2)) - def check_ast_roundtrip(self, code1, **kwargs): + def check_ast_roundtrip(self, code1, strip=True, **kwargs): ast1 = ast.parse(code1, **kwargs) code2 = ast.unparse(ast1) + if strip: + code2 = code2.strip() ast2 = ast.parse(code2, **kwargs) self.assertASTEqual(ast1, ast2) @@ -333,6 +335,36 @@ def test_function_type(self): ): self.check_ast_roundtrip(function_type, mode="func_type") + def test_type_comments(self): + for statement in ( + "a = 5 # type: int", + "a = 5 # type: int and more", + "def x(): # type: () -> None\n\tpass", + "def x(y): # type: (int) -> None and more\n\tpass", + "async def x(): # type: () -> None\n\tpass", + "async def x(y): # type: (int) -> None and more\n\tpass", + "for x in y: # type: int\n\tpass", + "async for x in y: # type: int\n\tpass", + "with x(): # type: int\n\tpass", + "async with x(): # type: int\n\tpass" + ): + self.check_ast_roundtrip(statement, type_comments=True) + + def test_type_ignore(self): + for statement in ( + "a = 5 # type: ignore", + "a = 5 # type: ignore and more", + "def x(): # type: ignore\n\tpass", + "def x(y): # type: ignore and more\n\tpass", + "async def x(): # type: ignore\n\tpass", + "async def x(y): # type: ignore and more\n\tpass", + "for x in y: # type: ignore\n\tpass", + "async for x in y: # type: ignore\n\tpass", + "with x(): # type: ignore\n\tpass", + "async with x(): # type: ignore\n\tpass" + ): + self.check_ast_roundtrip(statement, type_comments=True) + class CosmeticTestCase(ASTTestCase): """Test if there are cosmetic issues caused by unnecesary additions""" From 73d72867db0e15f305e4ba7e84cadf80a380e74a Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Tue, 17 Mar 2020 10:25:59 +0300 Subject: [PATCH 2/2] simplify block --- Lib/ast.py | 39 ++++++++++++++++----------------------- Lib/test/test_unparse.py | 5 ++--- 2 files changed, 18 insertions(+), 26 deletions(-) diff --git a/Lib/ast.py b/Lib/ast.py index 04665f5cdb41256..ad7c5389e4bdc70 100644 --- a/Lib/ast.py +++ b/Lib/ast.py @@ -698,16 +698,15 @@ def buffer(self): return value @contextmanager - def block(self, node=None, *, append_type_comment = False): + def block(self, *, extra = None): """A context manager for preparing the source for blocks. It adds the character':', increases the indentation on enter and decreases - the indentation on exit. If *append_type_comment* is True, and *node* - is not None, this function will add given *node*'s type comment after - the colon character. + the indentation on exit. If *extra* is given, it will be directly + appended after the colon character. """ self.write(":") - if append_type_comment and node: - self.append_type_comment(node) + if extra: + self.write(extra) self._indent += 1 yield self._indent -= 1 @@ -754,17 +753,10 @@ def get_raw_docstring(self, node): if isinstance(node, Constant) and isinstance(node.value, str): return node - def append_type_comment(self, node): - comment = None - ignore = self._type_ignores.get(node.lineno) - - if ignore: - comment = f"ignore{ignore.tag}" - else: - comment = node.type_comment - - if comment: - self.write(f" # type: {comment}") + def get_type_comment(self, node): + comment = self._type_ignores.get(node.lineno) or node.type_comment + if comment is not None: + return f" # type: {comment}" def traverse(self, node): if isinstance(node, list): @@ -789,7 +781,7 @@ def _write_docstring_and_traverse_body(self, node): def visit_Module(self, node): self._type_ignores = { - ignore.lineno: ignore + ignore.lineno: f"ignore{ignore.tag}" for ignore in node.type_ignores } self._write_docstring_and_traverse_body(node) @@ -834,7 +826,8 @@ def visit_Assign(self, node): self.traverse(target) self.write(" = ") self.traverse(node.value) - self.append_type_comment(node) + if type_comment := self.get_type_comment(node): + self.write(type_comment) def visit_AugAssign(self, node): self.fill() @@ -990,7 +983,7 @@ def _function_helper(self, node, fill_suffix): if node.returns: self.write(" -> ") self.traverse(node.returns) - with self.block(node, append_type_comment=True): + with self.block(extra=self.get_type_comment(node)): self._write_docstring_and_traverse_body(node) def visit_For(self, node): @@ -1004,7 +997,7 @@ def _for_helper(self, fill, node): self.traverse(node.target) self.write(" in ") self.traverse(node.iter) - with self.block(node, append_type_comment=True): + with self.block(extra=self.get_type_comment(node)): self.traverse(node.body) if node.orelse: self.fill("else") @@ -1042,13 +1035,13 @@ def visit_While(self, node): def visit_With(self, node): self.fill("with ") self.interleave(lambda: self.write(", "), self.traverse, node.items) - with self.block(node, append_type_comment=True): + with self.block(extra=self.get_type_comment(node)): self.traverse(node.body) def visit_AsyncWith(self, node): self.fill("async with ") self.interleave(lambda: self.write(", "), self.traverse, node.items) - with self.block(node, append_type_comment=True): + with self.block(extra=self.get_type_comment(node)): self.traverse(node.body) def visit_JoinedStr(self, node): diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py index 09d0395d8b05050..6d069a657d5c6af 100644 --- a/Lib/test/test_unparse.py +++ b/Lib/test/test_unparse.py @@ -119,11 +119,9 @@ class ASTTestCase(unittest.TestCase): def assertASTEqual(self, ast1, ast2): self.assertEqual(ast.dump(ast1), ast.dump(ast2)) - def check_ast_roundtrip(self, code1, strip=True, **kwargs): + def check_ast_roundtrip(self, code1, **kwargs): ast1 = ast.parse(code1, **kwargs) code2 = ast.unparse(ast1) - if strip: - code2 = code2.strip() ast2 = ast.parse(code2, **kwargs) self.assertASTEqual(ast1, ast2) @@ -337,6 +335,7 @@ def test_function_type(self): def test_type_comments(self): for statement in ( + "a = 5 # type:", "a = 5 # type: int", "a = 5 # type: int and more", "def x(): # type: () -> None\n\tpass",