Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 1e77b8c

Browse filesBrowse files
committed
Use a more tolerant aclosing() context manager
1 parent e6bb12d commit 1e77b8c
Copy full SHA for 1e77b8c

File tree

6 files changed

+241
-71
lines changed
Filter options

6 files changed

+241
-71
lines changed

‎src/graphql/execution/__init__.py

Copy file name to clipboardExpand all lines: src/graphql/execution/__init__.py
+2-1Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
FormattedIncrementalResult,
3131
Middleware,
3232
)
33-
from .iterators import map_async_iterable
33+
from .async_iterables import flatten_async_iterable, map_async_iterable
3434
from .middleware import MiddlewareManager
3535
from .values import get_argument_values, get_directive_values, get_variable_values
3636

@@ -58,6 +58,7 @@
5858
"FormattedIncrementalDeferResult",
5959
"FormattedIncrementalStreamResult",
6060
"FormattedIncrementalResult",
61+
"flatten_async_iterable",
6162
"map_async_iterable",
6263
"Middleware",
6364
"MiddlewareManager",
+27-21Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations # Python < 3.10
22

3+
from contextlib import AbstractAsyncContextManager
34
from typing import (
45
Any,
56
AsyncGenerator,
@@ -11,25 +12,34 @@
1112
)
1213

1314

14-
try:
15-
from contextlib import aclosing
16-
except ImportError: # python < 3.10
17-
from contextlib import asynccontextmanager
18-
19-
@asynccontextmanager # type: ignore
20-
async def aclosing(thing):
21-
try:
22-
yield thing
23-
finally:
24-
await thing.aclose()
25-
15+
__all__ = ["aclosing", "flatten_async_iterable", "map_async_iterable"]
2616

2717
T = TypeVar("T")
2818
V = TypeVar("V")
2919

3020
AsyncIterableOrGenerator = Union[AsyncGenerator[T, None], AsyncIterable[T]]
3121

32-
__all__ = ["flatten_async_iterable", "map_async_iterable"]
22+
23+
class aclosing(AbstractAsyncContextManager):
24+
"""Async context manager for safely finalizing an async iterator or generator.
25+
26+
Contrary to the function available via the standard library, this one silently
27+
ignores the case that custom iterators have no aclose() method.
28+
"""
29+
30+
def __init__(self, iterable: AsyncIterableOrGenerator[T]) -> None:
31+
self.iterable = iterable
32+
33+
async def __aenter__(self) -> AsyncIterableOrGenerator[T]:
34+
return self.iterable
35+
36+
async def __aexit__(self, *_exc_info: Any) -> None:
37+
try:
38+
aclose = self.iterable.aclose # type: ignore
39+
except AttributeError:
40+
pass # do not complain if the iterator has no aclose() method
41+
else:
42+
await aclose()
3343

3444

3545
async def flatten_async_iterable(
@@ -48,7 +58,7 @@ async def flatten_async_iterable(
4858

4959

5060
async def map_async_iterable(
51-
iterable: AsyncIterable[T], callback: Callable[[T], Awaitable[V]]
61+
iterable: AsyncIterableOrGenerator[T], callback: Callable[[T], Awaitable[V]]
5262
) -> AsyncGenerator[V, None]:
5363
"""Map an AsyncIterable over a callback function.
5464
@@ -58,10 +68,6 @@ async def map_async_iterable(
5868
the generator finishes or closes.
5969
"""
6070

61-
aiter = iterable.__aiter__()
62-
try:
63-
async for element in aiter:
64-
yield await callback(element)
65-
finally:
66-
if hasattr(aiter, "aclose"):
67-
await aiter.aclose()
71+
async with aclosing(iterable) as items: # type: ignore
72+
async for item in items:
73+
yield await callback(item)

‎src/graphql/execution/execute.py

Copy file name to clipboardExpand all lines: src/graphql/execution/execute.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@
7070
is_non_null_type,
7171
is_object_type,
7272
)
73+
from .async_iterables import flatten_async_iterable, map_async_iterable
7374
from .collect_fields import FieldsAndPatches, collect_fields, collect_subfields
74-
from .iterators import flatten_async_iterable, map_async_iterable
7575
from .middleware import MiddlewareManager
7676
from .values import get_argument_values, get_directive_values, get_variable_values
7777

‎tests/execution/test_flatten_async_iterable.py

Copy file name to clipboardExpand all lines: tests/execution/test_flatten_async_iterable.py
+83-1Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from pytest import mark, raises
44

5-
from graphql.execution.iterators import flatten_async_iterable
5+
from graphql.execution import flatten_async_iterable
66

77

88
try: # pragma: no cover
@@ -129,3 +129,85 @@ async def nested2() -> AsyncGenerator[float, None]:
129129
assert await anext(doubles) == 2.2
130130
with raises(StopAsyncIteration):
131131
assert await anext(doubles)
132+
133+
@mark.asyncio
134+
async def closes_nested_async_iterators():
135+
closed = []
136+
137+
class Source:
138+
def __init__(self):
139+
self.counter = 0
140+
141+
def __aiter__(self):
142+
return self
143+
144+
async def __anext__(self):
145+
if self.counter == 2:
146+
raise StopAsyncIteration
147+
self.counter += 1
148+
return Nested(self.counter)
149+
150+
async def aclose(self):
151+
nonlocal closed
152+
closed.append(self.counter)
153+
154+
class Nested:
155+
def __init__(self, value):
156+
self.value = value
157+
self.counter = 0
158+
159+
def __aiter__(self):
160+
return self
161+
162+
async def __anext__(self):
163+
if self.counter == 2:
164+
raise StopAsyncIteration
165+
self.counter += 1
166+
return self.value + self.counter / 10
167+
168+
async def aclose(self):
169+
nonlocal closed
170+
closed.append(self.value + self.counter / 10)
171+
172+
doubles = flatten_async_iterable(Source())
173+
174+
result = [x async for x in doubles]
175+
176+
assert result == [1.1, 1.2, 2.1, 2.2]
177+
178+
assert closed == [1.2, 2.2, 2]
179+
180+
@mark.asyncio
181+
async def works_with_nested_async_iterators_that_have_no_close_method():
182+
class Source:
183+
def __init__(self):
184+
self.counter = 0
185+
186+
def __aiter__(self):
187+
return self
188+
189+
async def __anext__(self):
190+
if self.counter == 2:
191+
raise StopAsyncIteration
192+
self.counter += 1
193+
return Nested(self.counter)
194+
195+
class Nested:
196+
def __init__(self, value):
197+
self.value = value
198+
self.counter = 0
199+
200+
def __aiter__(self):
201+
return self
202+
203+
async def __anext__(self):
204+
if self.counter == 2:
205+
raise StopAsyncIteration
206+
self.counter += 1
207+
return self.value + self.counter / 10
208+
209+
doubles = flatten_async_iterable(Source())
210+
211+
result = [x async for x in doubles]
212+
213+
assert result == [1.1, 1.2, 2.1, 2.2]

‎tests/execution/test_map_async_iterable.py

Copy file name to clipboardExpand all lines: tests/execution/test_map_async_iterable.py
+47-25Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,22 @@
33
from graphql.execution import map_async_iterable
44

55

6-
async def map_doubles(x):
6+
try: # pragma: no cover
7+
anext
8+
except NameError: # pragma: no cover (Python < 3.10)
9+
# noinspection PyShadowingBuiltins
10+
async def anext(iterator):
11+
"""Return the next item from an async iterator."""
12+
return await iterator.__anext__()
13+
14+
15+
async def map_doubles(x: int) -> int:
716
return x + x
817

918

1019
def describe_map_async_iterable():
1120
@mark.asyncio
12-
async def test_inner_close_called():
13-
"""
14-
Test that a custom iterator with aclose() gets an aclose() call
15-
when outer is closed
16-
"""
17-
21+
async def inner_is_closed_when_outer_is_closed():
1822
class Inner:
1923
def __init__(self):
2024
self.closed = False
@@ -30,19 +34,14 @@ async def __anext__(self):
3034

3135
inner = Inner()
3236
outer = map_async_iterable(inner, map_doubles)
33-
it = outer.__aiter__()
34-
assert await it.__anext__() == 2
37+
iterator = outer.__aiter__()
38+
assert await anext(iterator) == 2
3539
assert not inner.closed
3640
await outer.aclose()
3741
assert inner.closed
3842

3943
@mark.asyncio
40-
async def test_inner_close_called_on_callback_err():
41-
"""
42-
Test that a custom iterator with aclose() gets an aclose() call
43-
when the callback errors and the outer iterator aborts.
44-
"""
45-
44+
async def inner_is_closed_on_callback_error():
4645
class Inner:
4746
def __init__(self):
4847
self.closed = False
@@ -62,17 +61,11 @@ async def callback(v):
6261
inner = Inner()
6362
outer = map_async_iterable(inner, callback)
6463
with raises(RuntimeError):
65-
async for _ in outer:
66-
pass
64+
await anext(outer)
6765
assert inner.closed
6866

6967
@mark.asyncio
70-
async def test_inner_exit_on_callback_err():
71-
"""
72-
Test that a custom iterator with aclose() gets an aclose() call
73-
when the callback errors and the outer iterator aborts.
74-
"""
75-
68+
async def test_inner_exits_on_callback_error():
7669
inner_exit = False
7770

7871
async def inner():
@@ -88,6 +81,35 @@ async def callback(v):
8881

8982
outer = map_async_iterable(inner(), callback)
9083
with raises(RuntimeError):
91-
async for _ in outer:
92-
pass
84+
await anext(outer)
9385
assert inner_exit
86+
87+
@mark.asyncio
88+
async def inner_has_no_close_method_when_outer_is_closed():
89+
class Inner:
90+
def __aiter__(self):
91+
return self
92+
93+
async def __anext__(self):
94+
return 1
95+
96+
outer = map_async_iterable(Inner(), map_doubles)
97+
iterator = outer.__aiter__()
98+
assert await anext(iterator) == 2
99+
await outer.aclose()
100+
101+
@mark.asyncio
102+
async def inner_has_no_close_method_on_callback_error():
103+
class Inner:
104+
def __aiter__(self):
105+
return self
106+
107+
async def __anext__(self):
108+
return 1
109+
110+
async def callback(v):
111+
raise RuntimeError()
112+
113+
outer = map_async_iterable(Inner(), callback)
114+
with raises(RuntimeError):
115+
await anext(outer)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.