Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 2a45cec

Browse filesBrowse files
Fix crashes with comments in parentheses (#4453)
Co-authored-by: Jelle Zijlstra <jelle.zijlstra@gmail.com>
1 parent b4d6d86 commit 2a45cec
Copy full SHA for 2a45cec

File tree

6 files changed

+185
-34
lines changed
Filter options

6 files changed

+185
-34
lines changed

‎CHANGES.md

Copy file name to clipboardExpand all lines: CHANGES.md
+3Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919

2020
<!-- Changes that affect Black's stable style -->
2121

22+
- Fix crashes involving comments in parenthesised return types or `X | Y` style unions.
23+
(#4453)
24+
2225
### Preview style
2326

2427
<!-- Changes that affect Black's preview style -->

‎src/black/linegen.py

Copy file name to clipboardExpand all lines: src/black/linegen.py
+49-33Lines changed: 49 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,6 +1079,47 @@ def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None
10791079
)
10801080

10811081

1082+
def _ensure_trailing_comma(
1083+
leaves: List[Leaf], original: Line, opening_bracket: Leaf
1084+
) -> bool:
1085+
if not leaves:
1086+
return False
1087+
# Ensure a trailing comma for imports
1088+
if original.is_import:
1089+
return True
1090+
# ...and standalone function arguments
1091+
if not original.is_def:
1092+
return False
1093+
if opening_bracket.value != "(":
1094+
return False
1095+
# Don't add commas if we already have any commas
1096+
if any(
1097+
leaf.type == token.COMMA
1098+
and (
1099+
Preview.typed_params_trailing_comma not in original.mode
1100+
or not is_part_of_annotation(leaf)
1101+
)
1102+
for leaf in leaves
1103+
):
1104+
return False
1105+
1106+
# Find a leaf with a parent (comments don't have parents)
1107+
leaf_with_parent = next((leaf for leaf in leaves if leaf.parent), None)
1108+
if leaf_with_parent is None:
1109+
return True
1110+
# Don't add commas inside parenthesized return annotations
1111+
if get_annotation_type(leaf_with_parent) == "return":
1112+
return False
1113+
# Don't add commas inside PEP 604 unions
1114+
if (
1115+
leaf_with_parent.parent
1116+
and leaf_with_parent.parent.next_sibling
1117+
and leaf_with_parent.parent.next_sibling.type == token.VBAR
1118+
):
1119+
return False
1120+
return True
1121+
1122+
10821123
def bracket_split_build_line(
10831124
leaves: List[Leaf],
10841125
original: Line,
@@ -1099,40 +1140,15 @@ def bracket_split_build_line(
10991140
if component is _BracketSplitComponent.body:
11001141
result.inside_brackets = True
11011142
result.depth += 1
1102-
if leaves:
1103-
no_commas = (
1104-
# Ensure a trailing comma for imports and standalone function arguments
1105-
original.is_def
1106-
# Don't add one after any comments or within type annotations
1107-
and opening_bracket.value == "("
1108-
# Don't add one if there's already one there
1109-
and not any(
1110-
leaf.type == token.COMMA
1111-
and (
1112-
Preview.typed_params_trailing_comma not in original.mode
1113-
or not is_part_of_annotation(leaf)
1114-
)
1115-
for leaf in leaves
1116-
)
1117-
# Don't add one inside parenthesized return annotations
1118-
and get_annotation_type(leaves[0]) != "return"
1119-
# Don't add one inside PEP 604 unions
1120-
and not (
1121-
leaves[0].parent
1122-
and leaves[0].parent.next_sibling
1123-
and leaves[0].parent.next_sibling.type == token.VBAR
1124-
)
1125-
)
1126-
1127-
if original.is_import or no_commas:
1128-
for i in range(len(leaves) - 1, -1, -1):
1129-
if leaves[i].type == STANDALONE_COMMENT:
1130-
continue
1143+
if _ensure_trailing_comma(leaves, original, opening_bracket):
1144+
for i in range(len(leaves) - 1, -1, -1):
1145+
if leaves[i].type == STANDALONE_COMMENT:
1146+
continue
11311147

1132-
if leaves[i].type != token.COMMA:
1133-
new_comma = Leaf(token.COMMA, ",")
1134-
leaves.insert(i + 1, new_comma)
1135-
break
1148+
if leaves[i].type != token.COMMA:
1149+
new_comma = Leaf(token.COMMA, ",")
1150+
leaves.insert(i + 1, new_comma)
1151+
break
11361152

11371153
leaves_to_track: Set[LeafID] = set()
11381154
if component is _BracketSplitComponent.head:

‎src/black/nodes.py

Copy file name to clipboardExpand all lines: src/black/nodes.py
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,6 +1012,7 @@ def get_annotation_type(leaf: Leaf) -> Literal["return", "param", None]:
10121012

10131013
def is_part_of_annotation(leaf: Leaf) -> bool:
10141014
"""Returns whether this leaf is part of a type annotation."""
1015+
assert leaf.parent is not None
10151016
return get_annotation_type(leaf) is not None
10161017

10171018

‎src/black/trans.py

Copy file name to clipboardExpand all lines: src/black/trans.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ def do_match(self, line: Line) -> TMatchResult:
488488
break
489489
i += 1
490490

491-
if not is_part_of_annotation(leaf) and not contains_comment:
491+
if not contains_comment and not is_part_of_annotation(leaf):
492492
string_indices.append(idx)
493493

494494
# Advance to the next non-STRING leaf.

‎tests/data/cases/funcdef_return_type_trailing_comma.py

Copy file name to clipboardExpand all lines: tests/data/cases/funcdef_return_type_trailing_comma.py
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def SimplePyFn(
142142
Buffer[UInt8, 2],
143143
Buffer[UInt8, 2],
144144
]: ...
145+
145146
# output
146147
# normal, short, function definition
147148
def foo(a, b) -> tuple[int, float]: ...

‎tests/data/cases/function_trailing_comma.py

Copy file name to clipboardExpand all lines: tests/data/cases/function_trailing_comma.py
+130Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,64 @@ def func() -> ((also_super_long_type_annotation_that_may_cause_an_AST_related_cr
6060
argument1, (one, two,), argument4, argument5, argument6
6161
)
6262

63+
def foo() -> (
64+
# comment inside parenthesised return type
65+
int
66+
):
67+
...
68+
69+
def foo() -> (
70+
# comment inside parenthesised return type
71+
# more
72+
int
73+
# another
74+
):
75+
...
76+
77+
def foo() -> (
78+
# comment inside parenthesised new union return type
79+
int | str | bytes
80+
):
81+
...
82+
83+
def foo() -> (
84+
# comment inside plain tuple
85+
):
86+
pass
87+
88+
def foo(arg: (# comment with non-return annotation
89+
int
90+
# comment with non-return annotation
91+
)):
92+
pass
93+
94+
def foo(arg: (# comment with non-return annotation
95+
int | range | memoryview
96+
# comment with non-return annotation
97+
)):
98+
pass
99+
100+
def foo(arg: (# only before
101+
int
102+
)):
103+
pass
104+
105+
def foo(arg: (
106+
int
107+
# only after
108+
)):
109+
pass
110+
111+
variable: ( # annotation
112+
because
113+
# why not
114+
)
115+
116+
variable: (
117+
because
118+
# why not
119+
)
120+
63121
# output
64122

65123
def f(
@@ -176,3 +234,75 @@ def func() -> (
176234
argument5,
177235
argument6,
178236
)
237+
238+
239+
def foo() -> (
240+
# comment inside parenthesised return type
241+
int
242+
): ...
243+
244+
245+
def foo() -> (
246+
# comment inside parenthesised return type
247+
# more
248+
int
249+
# another
250+
): ...
251+
252+
253+
def foo() -> (
254+
# comment inside parenthesised new union return type
255+
int
256+
| str
257+
| bytes
258+
): ...
259+
260+
261+
def foo() -> (
262+
# comment inside plain tuple
263+
):
264+
pass
265+
266+
267+
def foo(
268+
arg: ( # comment with non-return annotation
269+
int
270+
# comment with non-return annotation
271+
),
272+
):
273+
pass
274+
275+
276+
def foo(
277+
arg: ( # comment with non-return annotation
278+
int
279+
| range
280+
| memoryview
281+
# comment with non-return annotation
282+
),
283+
):
284+
pass
285+
286+
287+
def foo(arg: int): # only before
288+
pass
289+
290+
291+
def foo(
292+
arg: (
293+
int
294+
# only after
295+
),
296+
):
297+
pass
298+
299+
300+
variable: ( # annotation
301+
because
302+
# why not
303+
)
304+
305+
variable: (
306+
because
307+
# why not
308+
)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.