From 5adfcb3eb0526d065e3c40e4a12d3045f5c20c9a Mon Sep 17 00:00:00 2001 From: Martin Richard Date: Thu, 24 Dec 2015 18:11:26 +0100 Subject: [PATCH] loop.run_in_executor should be a coroutine Since base_events.BaseEventLoop.run_in_executor returns a Future object, a caller can call it with yield from/await. However, the result of the call is not a coroutine since asyncio.iscoroutine(loop.run_in_executor(...)) returns False. It matters when one wants to use run_in_executor() in a task, such as: loop.create_task(loop.run_in_executor(...)) In this case, an exception is raised immediatly, while the task is effectively running in the executor. This patch propose to make loop.run_in_executor() be an actual coroutine function (and be tested). I believe that returning a Future and document the function as a coroutine used to be done quite often, maybe this should be fixed elsewhere too. An alternative solution would be to change the documentation, explain why it can not be used with loop.create_task() and should be used with loop.ensure_future() instead. --- asyncio/base_events.py | 3 ++- tests/test_events.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 4505732f..d1d9564f 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -521,6 +521,7 @@ def call_soon_threadsafe(self, callback, *args): self._write_to_self() return handle + @coroutine def run_in_executor(self, executor, func, *args): if (coroutines.iscoroutine(func) or coroutines.iscoroutinefunction(func)): @@ -539,7 +540,7 @@ def run_in_executor(self, executor, func, *args): if executor is None: executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) self._default_executor = executor - return futures.wrap_future(executor.submit(func, *args), loop=self) + return (yield from futures.wrap_future(executor.submit(func, *args), loop=self)) def set_default_executor(self, executor): self._default_executor = executor diff --git a/tests/test_events.py b/tests/test_events.py index f1746043..e99889d3 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -347,6 +347,7 @@ def run(arg): res, thread_id = self.loop.run_until_complete(f2) self.assertEqual(res, 'yo') self.assertNotEqual(thread_id, threading.get_ident()) + self.assertTrue(asyncio.iscoroutine(f2)) def test_reader_callback(self): r, w = test_utils.socketpair()