Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 4444edd

Browse filesBrowse files
committed
Add middleware support
1 parent 851d586 commit 4444edd
Copy full SHA for 4444edd

File tree

Expand file treeCollapse file tree

6 files changed

+107
-12
lines changed
Filter options
Expand file treeCollapse file tree

6 files changed

+107
-12
lines changed

‎asyncpg/_testbase/__init__.py

Copy file name to clipboardExpand all lines: asyncpg/_testbase/__init__.py
+2-1Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ def create_pool(dsn=None, *,
264264
setup=None,
265265
init=None,
266266
loop=None,
267+
middlewares=None,
267268
pool_class=pg_pool.Pool,
268269
connection_class=pg_connection.Connection,
269270
**connect_kwargs):
@@ -272,7 +273,7 @@ def create_pool(dsn=None, *,
272273
min_size=min_size, max_size=max_size,
273274
max_queries=max_queries, loop=loop, setup=setup, init=init,
274275
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
275-
connection_class=connection_class,
276+
connection_class=connection_class, middlewares=middlewares,
276277
**connect_kwargs)
277278

278279

‎asyncpg/connect_utils.py

Copy file name to clipboardExpand all lines: asyncpg/connect_utils.py
+4-3Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
594594

595595

596596
async def _connect_addr(*, addr, loop, timeout, params, config,
597-
connection_class):
597+
middlewares, connection_class):
598598
assert loop is not None
599599

600600
if timeout <= 0:
@@ -633,12 +633,12 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
633633
tr.close()
634634
raise
635635

636-
con = connection_class(pr, tr, loop, addr, config, params)
636+
con = connection_class(pr, tr, loop, addr, config, params, middlewares)
637637
pr.set_connection(con)
638638
return con
639639

640640

641-
async def _connect(*, loop, timeout, connection_class, **kwargs):
641+
async def _connect(*, loop, timeout, middlewares, connection_class, **kwargs):
642642
if loop is None:
643643
loop = asyncio.get_event_loop()
644644

@@ -652,6 +652,7 @@ async def _connect(*, loop, timeout, connection_class, **kwargs):
652652
con = await _connect_addr(
653653
addr=addr, loop=loop, timeout=timeout,
654654
params=params, config=config,
655+
middlewares=middlewares,
655656
connection_class=connection_class)
656657
except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
657658
last_error = ex

‎asyncpg/connection.py

Copy file name to clipboardExpand all lines: asyncpg/connection.py
+16-5Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class Connection(metaclass=ConnectionMeta):
4141
"""
4242

4343
__slots__ = ('_protocol', '_transport', '_loop',
44-
'_top_xact', '_aborted',
44+
'_top_xact', '_aborted', '_middlewares',
4545
'_pool_release_ctr', '_stmt_cache', '_stmts_to_close',
4646
'_listeners', '_server_version', '_server_caps',
4747
'_intro_query', '_reset_query', '_proxy',
@@ -52,7 +52,8 @@ class Connection(metaclass=ConnectionMeta):
5252
def __init__(self, protocol, transport, loop,
5353
addr: (str, int) or str,
5454
config: connect_utils._ClientConfiguration,
55-
params: connect_utils._ConnectionParameters):
55+
params: connect_utils._ConnectionParameters,
56+
_middlewares=None):
5657
self._protocol = protocol
5758
self._transport = transport
5859
self._loop = loop
@@ -91,7 +92,7 @@ def __init__(self, protocol, transport, loop,
9192

9293
self._reset_query = None
9394
self._proxy = None
94-
95+
self._middlewares = _middlewares
9596
# Used to serialize operations that might involve anonymous
9697
# statements. Specifically, we want to make the following
9798
# operation atomic:
@@ -1399,8 +1400,12 @@ async def reload_schema_state(self):
13991400

14001401
async def _execute(self, query, args, limit, timeout, return_status=False):
14011402
with self._stmt_exclusive_section:
1402-
result, _ = await self.__execute(
1403-
query, args, limit, timeout, return_status=return_status)
1403+
wrapped = self.__execute
1404+
if self._middlewares:
1405+
for m in reversed(self._middlewares):
1406+
wrapped = await m(connection=self, handler=wrapped)
1407+
1408+
result, _ = await wrapped(query, args, limit, timeout, return_status=return_status)
14041409
return result
14051410

14061411
async def __execute(self, query, args, limit, timeout,
@@ -1491,6 +1496,7 @@ async def connect(dsn=None, *,
14911496
max_cacheable_statement_size=1024 * 15,
14921497
command_timeout=None,
14931498
ssl=None,
1499+
middlewares=None,
14941500
connection_class=Connection,
14951501
server_settings=None):
14961502
r"""A coroutine to establish a connection to a PostgreSQL server.
@@ -1607,6 +1613,10 @@ async def connect(dsn=None, *,
16071613
PostgreSQL documentation for
16081614
a `list of supported options <server settings>`_.
16091615
1616+
:param middlewares:
1617+
An optional list of middleware functions. Refer to documentation
1618+
on create_pool.
1619+
16101620
:param Connection connection_class:
16111621
Class of the returned connection object. Must be a subclass of
16121622
:class:`~asyncpg.connection.Connection`.
@@ -1672,6 +1682,7 @@ async def connect(dsn=None, *,
16721682
ssl=ssl, database=database,
16731683
server_settings=server_settings,
16741684
command_timeout=command_timeout,
1685+
middlewares=middlewares,
16751686
statement_cache_size=statement_cache_size,
16761687
max_cached_statement_lifetime=max_cached_statement_lifetime,
16771688
max_cacheable_statement_size=max_cacheable_statement_size)

‎asyncpg/pool.py

Copy file name to clipboardExpand all lines: asyncpg/pool.py
+43-2Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ class Pool:
305305
"""
306306

307307
__slots__ = (
308-
'_queue', '_loop', '_minsize', '_maxsize',
308+
'_queue', '_loop', '_minsize', '_maxsize', '_middlewares',
309309
'_init', '_connect_args', '_connect_kwargs',
310310
'_working_addr', '_working_config', '_working_params',
311311
'_holders', '_initialized', '_initializing', '_closing',
@@ -320,6 +320,7 @@ def __init__(self, *connect_args,
320320
max_inactive_connection_lifetime,
321321
setup,
322322
init,
323+
middlewares,
323324
loop,
324325
connection_class,
325326
**connect_kwargs):
@@ -377,6 +378,7 @@ def __init__(self, *connect_args,
377378
self._closed = False
378379
self._generation = 0
379380
self._init = init
381+
self._middlewares = middlewares
380382
self._connect_args = connect_args
381383
self._connect_kwargs = connect_kwargs
382384

@@ -469,6 +471,7 @@ async def _get_new_connection(self):
469471
*self._connect_args,
470472
loop=self._loop,
471473
connection_class=self._connection_class,
474+
middlewares=self._middlewares,
472475
**self._connect_kwargs)
473476

474477
self._working_addr = con._addr
@@ -483,6 +486,7 @@ async def _get_new_connection(self):
483486
addr=self._working_addr,
484487
timeout=self._working_params.connect_timeout,
485488
config=self._working_config,
489+
middlewares=self._middlewares,
486490
params=self._working_params,
487491
connection_class=self._connection_class)
488492

@@ -784,13 +788,35 @@ def __await__(self):
784788
return self.pool._acquire(self.timeout).__await__()
785789

786790

791+
def middleware(f):
792+
"""Decorator for adding a middleware
793+
794+
Can be used like such
795+
796+
.. code-block:: python
797+
798+
@pool.middleware
799+
async def my_middleware(query, args, limit, timeout, return_status, *, handler, conn):
800+
print('do something before')
801+
result, stmt = await handler(query, args, limit, timeout, return_status)
802+
print('do something after')
803+
return result, stmt
804+
805+
my_pool = await pool.create_pool(middlewares=[my_middleware])
806+
"""
807+
async def middleware_factory(connection, handler):
808+
return functools.partial(f, connection=connection, handler=handler)
809+
return middleware_factory
810+
811+
787812
def create_pool(dsn=None, *,
788813
min_size=10,
789814
max_size=10,
790815
max_queries=50000,
791816
max_inactive_connection_lifetime=300.0,
792817
setup=None,
793818
init=None,
819+
middlewares=None,
794820
loop=None,
795821
connection_class=connection.Connection,
796822
**connect_kwargs):
@@ -866,6 +892,19 @@ def create_pool(dsn=None, *,
866892
or :meth:`Connection.set_type_codec() <\
867893
asyncpg.connection.Connection.set_type_codec>`.
868894
895+
:param middlewares:
896+
A list of middleware functions to be middleware just
897+
before a connection excecutes a statement.
898+
Syntax of a middleware is as follows:
899+
async def middleware_factory(connection, handler):
900+
async def middleware(query, args, limit, timeout, return_status):
901+
print('do something before')
902+
result, stmt = await handler(query, args, limit,
903+
timeout, return_status)
904+
print('do something after')
905+
return result, stmt
906+
return middleware
907+
869908
:param loop:
870909
An asyncio event loop instance. If ``None``, the default
871910
event loop will be used.
@@ -893,6 +932,8 @@ def create_pool(dsn=None, *,
893932
dsn,
894933
connection_class=connection_class,
895934
min_size=min_size, max_size=max_size,
896-
max_queries=max_queries, loop=loop, setup=setup, init=init,
935+
max_queries=max_queries, loop=loop, setup=setup,
936+
middlewares=middlewares, init=init,
897937
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
898938
**connect_kwargs)
939+

‎docs/installation.rst

Copy file name to clipboardExpand all lines: docs/installation.rst
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ If you want to build **asyncpg** from a Git checkout you will need:
3030
* CPython header files. These can usually be obtained by installing
3131
the relevant Python development package: **python3-dev** on Debian/Ubuntu,
3232
**python3-devel** on RHEL/Fedora.
33-
33+
* Clone the repo with submodules (`git clone --recursive`, or `git submodules init; git submodules update`)
3434
Once the above requirements are satisfied, run the following command
3535
in the root of the source checkout:
3636

‎tests/test_pool.py

Copy file name to clipboardExpand all lines: tests/test_pool.py
+41Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,47 @@ async def worker():
7676
tasks = [worker() for _ in range(n)]
7777
await asyncio.gather(*tasks)
7878

79+
async def test_pool_with_middleware(self):
80+
called = False
81+
82+
async def my_middleware_factory(connection, handler):
83+
async def middleware(query, args, limit, timeout, return_status):
84+
nonlocal called
85+
called = True
86+
return await handler(query, args, limit,
87+
timeout, return_status)
88+
return middleware
89+
90+
pool = await self.create_pool(database='postgres',
91+
min_size=1, max_size=1,
92+
middlewares=[my_middleware_factory])
93+
94+
con = await pool.acquire(timeout=5)
95+
await con.fetchval('SELECT 1')
96+
assert called
97+
98+
pool.terminate()
99+
del con
100+
101+
async def test_pool_with_middleware_decorator(self):
102+
called = False
103+
104+
@pg_pool.middleware
105+
async def my_middleware(query, args, limit, timeout, return_status,
106+
*, connection, handler):
107+
nonlocal called
108+
called = True
109+
return await handler(query, args, limit,
110+
timeout, return_status)
111+
112+
pool = await self.create_pool(database='postgres', min_size=1, max_size=1, middlewares=[my_middleware])
113+
con = await pool.acquire(timeout=5)
114+
await con.fetchval('SELECT 1')
115+
assert called
116+
117+
pool.terminate()
118+
del con
119+
79120
async def test_pool_03(self):
80121
pool = await self.create_pool(database='postgres',
81122
min_size=1, max_size=1)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.