Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 1d33ff6

Browse filesBrowse files
authored
Add support for asynchronous iterables to copy_records_to_table() (#713)
The `Connection.copy_records_to_table()` now allows the `records` argument to be an asynchronous iterable. Fixes: #689.
1 parent a6b0f28 commit 1d33ff6
Copy full SHA for 1d33ff6

File tree

Expand file treeCollapse file tree

3 files changed

+87
-24
lines changed
Open diff view settings
Filter options
Expand file treeCollapse file tree

3 files changed

+87
-24
lines changed
Open diff view settings
Collapse file

‎asyncpg/connection.py‎

Copy file name to clipboardExpand all lines: asyncpg/connection.py
+25-6Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,8 @@ async def copy_records_to_table(self, table_name, *, records,
872872
873873
:param records:
874874
An iterable returning row tuples to copy into the table.
875+
:term:`Asynchronous iterables <python:asynchronous iterable>`
876+
are also supported.
875877
876878
:param list columns:
877879
An optional list of column names to copy.
@@ -901,7 +903,28 @@ async def copy_records_to_table(self, table_name, *, records,
901903
>>> asyncio.get_event_loop().run_until_complete(run())
902904
'COPY 2'
903905
906+
Asynchronous record iterables are also supported:
907+
908+
.. code-block:: pycon
909+
910+
>>> import asyncpg
911+
>>> import asyncio
912+
>>> async def run():
913+
... con = await asyncpg.connect(user='postgres')
914+
... async def record_gen(size):
915+
... for i in range(size):
916+
... yield (i,)
917+
... result = await con.copy_records_to_table(
918+
... 'mytable', records=record_gen(100))
919+
... print(result)
920+
...
921+
>>> asyncio.get_event_loop().run_until_complete(run())
922+
'COPY 100'
923+
904924
.. versionadded:: 0.11.0
925+
926+
.. versionchanged:: 0.24.0
927+
The ``records`` argument may be an asynchronous iterable.
905928
"""
906929
tabname = utils._quote_ident(table_name)
907930
if schema_name:
@@ -924,8 +947,8 @@ async def copy_records_to_table(self, table_name, *, records,
924947
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
925948
tab=tabname, cols=cols, opts=opts)
926949

927-
return await self._copy_in_records(
928-
copy_stmt, records, intro_ps._state, timeout)
950+
return await self._protocol.copy_in(
951+
copy_stmt, None, None, records, intro_ps._state, timeout)
929952

930953
def _format_copy_opts(self, *, format=None, oids=None, freeze=None,
931954
delimiter=None, null=None, header=None, quote=None,
@@ -1047,10 +1070,6 @@ async def __anext__(self):
10471070
if opened_by_us:
10481071
await run_in_executor(None, f.close)
10491072

1050-
async def _copy_in_records(self, copy_stmt, records, intro_stmt, timeout):
1051-
return await self._protocol.copy_in(
1052-
copy_stmt, None, None, records, intro_stmt, timeout)
1053-
10541073
async def set_type_codec(self, typename, *,
10551074
schema='public', encoder, decoder,
10561075
format='text'):
Collapse file

‎asyncpg/protocol/protocol.pyx‎

Copy file name to clipboardExpand all lines: asyncpg/protocol/protocol.pyx
+39-18Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ cimport cpython
1313
import asyncio
1414
import builtins
1515
import codecs
16-
import collections
16+
import collections.abc
1717
import socket
1818
import time
1919
import weakref
@@ -438,23 +438,44 @@ cdef class BaseProtocol(CoreProtocol):
438438
'no binary format encoder for '
439439
'type {} (OID {})'.format(codec.name, codec.oid))
440440

441-
for row in records:
442-
# Tuple header
443-
wbuf.write_int16(<int16_t>num_cols)
444-
# Tuple data
445-
for i in range(num_cols):
446-
item = row[i]
447-
if item is None:
448-
wbuf.write_int32(-1)
449-
else:
450-
codec = <Codec>cpython.PyTuple_GET_ITEM(codecs, i)
451-
codec.encode(settings, wbuf, item)
452-
453-
if wbuf.len() >= _COPY_BUFFER_SIZE:
454-
with timer:
455-
await self.writing_allowed.wait()
456-
self._write_copy_data_msg(wbuf)
457-
wbuf = WriteBuffer.new()
441+
if isinstance(records, collections.abc.AsyncIterable):
442+
async for row in records:
443+
# Tuple header
444+
wbuf.write_int16(<int16_t>num_cols)
445+
# Tuple data
446+
for i in range(num_cols):
447+
item = row[i]
448+
if item is None:
449+
wbuf.write_int32(-1)
450+
else:
451+
codec = <Codec>cpython.PyTuple_GET_ITEM(
452+
codecs, i)
453+
codec.encode(settings, wbuf, item)
454+
455+
if wbuf.len() >= _COPY_BUFFER_SIZE:
456+
with timer:
457+
await self.writing_allowed.wait()
458+
self._write_copy_data_msg(wbuf)
459+
wbuf = WriteBuffer.new()
460+
else:
461+
for row in records:
462+
# Tuple header
463+
wbuf.write_int16(<int16_t>num_cols)
464+
# Tuple data
465+
for i in range(num_cols):
466+
item = row[i]
467+
if item is None:
468+
wbuf.write_int32(-1)
469+
else:
470+
codec = <Codec>cpython.PyTuple_GET_ITEM(
471+
codecs, i)
472+
codec.encode(settings, wbuf, item)
473+
474+
if wbuf.len() >= _COPY_BUFFER_SIZE:
475+
with timer:
476+
await self.writing_allowed.wait()
477+
self._write_copy_data_msg(wbuf)
478+
wbuf = WriteBuffer.new()
458479

459480
# End of binary copy.
460481
wbuf.write_int16(-1)
Collapse file

‎tests/test_copy.py‎

Copy file name to clipboardExpand all lines: tests/test_copy.py
+23Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,29 @@ async def test_copy_records_to_table_1(self):
644644
finally:
645645
await self.con.execute('DROP TABLE copytab')
646646

647+
async def test_copy_records_to_table_async(self):
648+
await self.con.execute('''
649+
CREATE TABLE copytab_async(a text, b int, c timestamptz);
650+
''')
651+
652+
try:
653+
date = datetime.datetime.now(tz=datetime.timezone.utc)
654+
delta = datetime.timedelta(days=1)
655+
656+
async def record_generator():
657+
for i in range(100):
658+
yield ('a-{}'.format(i), i, date + delta)
659+
660+
yield ('a-100', None, None)
661+
662+
res = await self.con.copy_records_to_table(
663+
'copytab_async', records=record_generator())
664+
665+
self.assertEqual(res, 'COPY 101')
666+
667+
finally:
668+
await self.con.execute('DROP TABLE copytab_async')
669+
647670
async def test_copy_records_to_table_no_binary_codec(self):
648671
await self.con.execute('''
649672
CREATE TABLE copytab(a uuid);

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.