Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit db3e5fc

Browse filesBrowse files
committed
Extract task cancellation as utility function
1 parent e48d160 commit db3e5fc
Copy full SHA for db3e5fc

File tree

Expand file treeCollapse file tree

4 files changed

+185
-60
lines changed
Filter options
Expand file treeCollapse file tree

4 files changed

+185
-60
lines changed

‎src/graphql/execution/execute.py

Copy file name to clipboardExpand all lines: src/graphql/execution/execute.py
+11-60Lines changed: 11 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44

55
from asyncio import (
66
CancelledError,
7-
create_task,
87
ensure_future,
9-
gather,
108
shield,
119
wait_for,
1210
)
@@ -52,6 +50,7 @@
5250
RefMap,
5351
Undefined,
5452
async_reduce,
53+
gather_with_cancel,
5554
inspect,
5655
is_iterable,
5756
)
@@ -466,21 +465,9 @@ async def get_results() -> dict[str, Any]:
466465
field = awaitable_fields[0]
467466
results[field] = await results[field]
468467
else:
469-
tasks = [
470-
create_task(results[field]) # type: ignore[arg-type]
471-
for field in awaitable_fields
472-
]
473-
474-
try:
475-
awaited_results = await gather(*tasks)
476-
except Exception:
477-
# Cancel unfinished tasks before raising the exception
478-
for task in tasks:
479-
if not task.done():
480-
task.cancel()
481-
await gather(*tasks, return_exceptions=True)
482-
raise
483-
468+
awaited_results = await gather_with_cancel(
469+
*(results[field] for field in awaitable_fields)
470+
)
484471
results.update(zip(awaitable_fields, awaited_results))
485472

486473
return results
@@ -911,20 +898,9 @@ async def complete_async_iterator_value(
911898
index = awaitable_indices[0]
912899
completed_results[index] = await completed_results[index]
913900
else:
914-
tasks = [
915-
create_task(completed_results[index]) for index in awaitable_indices
916-
]
917-
918-
try:
919-
awaited_results = await gather(*tasks)
920-
except Exception:
921-
# Cancel unfinished tasks before raising the exception
922-
for task in tasks:
923-
if not task.done():
924-
task.cancel()
925-
await gather(*tasks, return_exceptions=True)
926-
raise
927-
901+
awaited_results = await gather_with_cancel(
902+
*(completed_results[index] for index in awaitable_indices)
903+
)
928904
for index, sub_result in zip(awaitable_indices, awaited_results):
929905
completed_results[index] = sub_result
930906
return completed_results
@@ -1023,20 +999,9 @@ async def get_completed_results() -> list[Any]:
1023999
index = awaitable_indices[0]
10241000
completed_results[index] = await completed_results[index]
10251001
else:
1026-
tasks = [
1027-
create_task(completed_results[index]) for index in awaitable_indices
1028-
]
1029-
1030-
try:
1031-
awaited_results = await gather(*tasks)
1032-
except Exception:
1033-
# Cancel unfinished tasks before raising the exception
1034-
for task in tasks:
1035-
if not task.done():
1036-
task.cancel()
1037-
await gather(*tasks, return_exceptions=True)
1038-
raise
1039-
1002+
awaited_results = await gather_with_cancel(
1003+
*(completed_results[index] for index in awaitable_indices)
1004+
)
10401005
for index, sub_result in zip(awaitable_indices, awaited_results):
10411006
completed_results[index] = sub_result
10421007
return completed_results
@@ -2123,21 +2088,7 @@ def default_type_resolver(
21232088
if awaitable_is_type_of_results:
21242089
# noinspection PyShadowingNames
21252090
async def get_type() -> str | None:
2126-
tasks = [
2127-
create_task(result) # type: ignore[arg-type]
2128-
for result in awaitable_is_type_of_results
2129-
]
2130-
2131-
try:
2132-
is_type_of_results = await gather(*tasks)
2133-
except Exception:
2134-
# Cancel unfinished tasks before raising the exception
2135-
for task in tasks:
2136-
if not task.done():
2137-
task.cancel()
2138-
await gather(*tasks, return_exceptions=True)
2139-
raise
2140-
2091+
is_type_of_results = await gather_with_cancel(*awaitable_is_type_of_results)
21412092
for is_type_of_result, type_ in zip(is_type_of_results, awaitable_types):
21422093
if is_type_of_result:
21432094
return type_.name

‎src/graphql/pyutils/__init__.py

Copy file name to clipboardExpand all lines: src/graphql/pyutils/__init__.py
+2Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"""
1010

1111
from .async_reduce import async_reduce
12+
from .gather_with_cancel import gather_with_cancel
1213
from .convert_case import camel_to_snake, snake_to_camel
1314
from .cached_property import cached_property
1415
from .description import (
@@ -52,6 +53,7 @@
5253
"cached_property",
5354
"camel_to_snake",
5455
"did_you_mean",
56+
"gather_with_cancel",
5557
"group_by",
5658
"identity_func",
5759
"inspect",
+39Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""Run awaitables concurrently with cancellation support."""
2+
3+
from __future__ import annotations
4+
5+
from asyncio import Task, create_task, gather
6+
from typing import Any, Awaitable
7+
8+
__all__ = ["gather_with_cancel"]
9+
10+
11+
async def gather_with_cancel(*awaitables: Awaitable[Any]) -> list[Any]:
12+
"""Run awaitable objects in the sequence concurrently.
13+
14+
The first raised exception is immediately propagated to the task that awaits
15+
on this function and all pending awaitables in the sequence will be cancelled.
16+
17+
This is different from the default behavior or `asyncio.gather` which waits
18+
for all tasks to complete even if one of them raises an exception. It is also
19+
different from `asyncio.gather` with `return_exceptions` set, which does not
20+
cancel the other tasks when one of them raises an exception.
21+
"""
22+
try:
23+
tasks: list[Task[Any]] = [
24+
aw if isinstance(aw, Task) else create_task(aw) # type: ignore[arg-type]
25+
for aw in awaitables
26+
]
27+
except TypeError:
28+
return await gather(*awaitables) # type: ignore[arg-type]
29+
try:
30+
return await gather(*tasks)
31+
except Exception:
32+
print("HEINIMANN")
33+
for task in tasks:
34+
print("TASK CANCEL", task)
35+
if not task.done():
36+
print("NOT DONE, CANCELING", task)
37+
task.cancel()
38+
await gather(*tasks, return_exceptions=True)
39+
raise
+133Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
from asyncio import Event, create_task, gather, sleep, wait_for
2+
from typing import Callable
3+
4+
import pytest
5+
6+
from graphql.pyutils import gather_with_cancel, is_awaitable
7+
8+
9+
class Controller:
10+
def reset(self, wait=False):
11+
self.event = Event()
12+
if not wait:
13+
self.event.set()
14+
self.returned = []
15+
16+
17+
controller = Controller()
18+
19+
20+
async def coroutine(value: int) -> int:
21+
"""Simple coroutine that returns a value."""
22+
if value > 2:
23+
raise RuntimeError("Oops")
24+
await controller.event.wait()
25+
controller.returned.append(value)
26+
return value
27+
28+
29+
class CustomAwaitable:
30+
"""Custom awaitable that return a value."""
31+
32+
def __init__(self, value: int):
33+
self.value = value
34+
self.coroutine = coroutine(value)
35+
36+
def __await__(self):
37+
return self.coroutine.__await__()
38+
39+
40+
awaitable_factories: dict[str, Callable] = {
41+
"coroutine": coroutine,
42+
"task": lambda value: create_task(coroutine(value)),
43+
"custom": lambda value: CustomAwaitable(value),
44+
}
45+
46+
with_all_types_of_awaitables = pytest.mark.parametrize(
47+
"type_of_awaitable", awaitable_factories
48+
)
49+
50+
51+
def describe_gather_with_cancel():
52+
@with_all_types_of_awaitables
53+
@pytest.mark.asyncio
54+
async def gathers_all_values(type_of_awaitable: str):
55+
return # !!!s
56+
factory = awaitable_factories[type_of_awaitable]
57+
values = list(range(3))
58+
59+
controller.reset()
60+
aws = [factory(i) for i in values]
61+
62+
assert await gather(*aws) == values
63+
assert controller.returned == values
64+
65+
controller.reset()
66+
aws = [factory(i) for i in values]
67+
68+
result = gather_with_cancel(*aws)
69+
assert is_awaitable(result)
70+
71+
awaited = await wait_for(result, 1)
72+
assert awaited == values
73+
74+
@with_all_types_of_awaitables
75+
@pytest.mark.asyncio
76+
async def raises_on_exception(type_of_awaitable: str):
77+
return # !!!
78+
factory = awaitable_factories[type_of_awaitable]
79+
values = list(range(4))
80+
81+
controller.reset()
82+
aws = [factory(i) for i in values]
83+
84+
with pytest.raises(RuntimeError, match="Oops"):
85+
await gather(*aws)
86+
assert controller.returned == values[:-1]
87+
88+
controller.reset()
89+
aws = [factory(i) for i in values]
90+
91+
result = gather_with_cancel(*aws)
92+
assert is_awaitable(result)
93+
94+
with pytest.raises(RuntimeError, match="Oops"):
95+
await wait_for(result, 1)
96+
assert controller.returned == values[:-1]
97+
98+
@with_all_types_of_awaitables
99+
@pytest.mark.asyncio
100+
async def cancels_on_exception(type_of_awaitable: str):
101+
factory = awaitable_factories[type_of_awaitable]
102+
values = list(range(4))
103+
104+
controller.reset(wait=True)
105+
aws = [factory(i) for i in values]
106+
107+
with pytest.raises(RuntimeError, match="Oops"):
108+
await gather(*aws)
109+
assert not controller.returned
110+
111+
# check that the standard gather continues to produce results
112+
controller.event.set()
113+
await sleep(0)
114+
assert controller.returned == values[:-1]
115+
116+
controller.reset(wait=True)
117+
aws = [factory(i) for i in values]
118+
119+
result = gather_with_cancel(*aws)
120+
assert is_awaitable(result)
121+
122+
with pytest.raises(RuntimeError, match="Oops"):
123+
await wait_for(result, 1)
124+
assert not controller.returned
125+
126+
# check that gather_with_cancel stops producing results
127+
controller.event.set()
128+
await sleep(0)
129+
if type_of_awaitable == "custom":
130+
# Cancellation of custom awaitables is not supported
131+
assert controller.returned == values[:-1]
132+
else:
133+
assert not controller.returned

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.