Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 9f6839b

Browse filesBrowse files
committed
Add support for asynchronous iterables to copy_records_to_table()
The `Connection.copy_records_to_table()` now allows the `records` argument to be an asynchronous iterable. Fixes: #689.
1 parent a308a97 commit 9f6839b
Copy full SHA for 9f6839b

File tree

Expand file treeCollapse file tree

3 files changed

+87
-23
lines changed
Open diff view settings
Filter options
Expand file treeCollapse file tree

3 files changed

+87
-23
lines changed
Open diff view settings
Collapse file

‎asyncpg/connection.py‎

Copy file name to clipboardExpand all lines: asyncpg/connection.py
+25-6Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,8 @@ async def copy_records_to_table(self, table_name, *, records,
868868
869869
:param records:
870870
An iterable returning row tuples to copy into the table.
871+
:term:`Asynchronous iterables <python:asynchronous iterable>`
872+
are also supported.
871873
872874
:param list columns:
873875
An optional list of column names to copy.
@@ -897,7 +899,28 @@ async def copy_records_to_table(self, table_name, *, records,
897899
>>> asyncio.get_event_loop().run_until_complete(run())
898900
'COPY 2'
899901
902+
Asynchronous record iterables are also supported:
903+
904+
.. code-block:: pycon
905+
906+
>>> import asyncpg
907+
>>> import asyncio
908+
>>> async def run():
909+
... con = await asyncpg.connect(user='postgres')
910+
... async def record_gen(size):
911+
... for i in range(size):
912+
... yield (i,)
913+
... result = await con.copy_records_to_table(
914+
... 'mytable', records=record_gen(100))
915+
... print(result)
916+
...
917+
>>> asyncio.get_event_loop().run_until_complete(run())
918+
'COPY 100'
919+
900920
.. versionadded:: 0.11.0
921+
922+
.. versionchanged:: 0.23.0
923+
The ``records`` argument may be an asynchronous iterable.
901924
"""
902925
tabname = utils._quote_ident(table_name)
903926
if schema_name:
@@ -920,8 +943,8 @@ async def copy_records_to_table(self, table_name, *, records,
920943
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
921944
tab=tabname, cols=cols, opts=opts)
922945

923-
return await self._copy_in_records(
924-
copy_stmt, records, intro_ps._state, timeout)
946+
return await self._protocol.copy_in(
947+
copy_stmt, None, None, records, intro_ps._state, timeout)
925948

926949
def _format_copy_opts(self, *, format=None, oids=None, freeze=None,
927950
delimiter=None, null=None, header=None, quote=None,
@@ -1044,10 +1067,6 @@ async def __anext__(self):
10441067
if opened_by_us:
10451068
await run_in_executor(None, f.close)
10461069

1047-
async def _copy_in_records(self, copy_stmt, records, intro_stmt, timeout):
1048-
return await self._protocol.copy_in(
1049-
copy_stmt, None, None, records, intro_stmt, timeout)
1050-
10511070
async def set_type_codec(self, typename, *,
10521071
schema='public', encoder, decoder,
10531072
format='text'):
Collapse file

‎asyncpg/protocol/protocol.pyx‎

Copy file name to clipboardExpand all lines: asyncpg/protocol/protocol.pyx
+39-17Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import asyncio
1414
import builtins
1515
import codecs
1616
import collections
17+
import collections.abc
1718
import socket
1819
import time
1920
import weakref
@@ -436,23 +437,44 @@ cdef class BaseProtocol(CoreProtocol):
436437
'no binary format encoder for '
437438
'type {} (OID {})'.format(codec.name, codec.oid))
438439

439-
for row in records:
440-
# Tuple header
441-
wbuf.write_int16(<int16_t>num_cols)
442-
# Tuple data
443-
for i in range(num_cols):
444-
item = row[i]
445-
if item is None:
446-
wbuf.write_int32(-1)
447-
else:
448-
codec = <Codec>cpython.PyTuple_GET_ITEM(codecs, i)
449-
codec.encode(settings, wbuf, item)
450-
451-
if wbuf.len() >= _COPY_BUFFER_SIZE:
452-
with timer:
453-
await self.writing_allowed.wait()
454-
self._write_copy_data_msg(wbuf)
455-
wbuf = WriteBuffer.new()
440+
if isinstance(records, collections.abc.AsyncIterable):
441+
async for row in records:
442+
# Tuple header
443+
wbuf.write_int16(<int16_t>num_cols)
444+
# Tuple data
445+
for i in range(num_cols):
446+
item = row[i]
447+
if item is None:
448+
wbuf.write_int32(-1)
449+
else:
450+
codec = <Codec>cpython.PyTuple_GET_ITEM(
451+
codecs, i)
452+
codec.encode(settings, wbuf, item)
453+
454+
if wbuf.len() >= _COPY_BUFFER_SIZE:
455+
with timer:
456+
await self.writing_allowed.wait()
457+
self._write_copy_data_msg(wbuf)
458+
wbuf = WriteBuffer.new()
459+
else:
460+
for row in records:
461+
# Tuple header
462+
wbuf.write_int16(<int16_t>num_cols)
463+
# Tuple data
464+
for i in range(num_cols):
465+
item = row[i]
466+
if item is None:
467+
wbuf.write_int32(-1)
468+
else:
469+
codec = <Codec>cpython.PyTuple_GET_ITEM(
470+
codecs, i)
471+
codec.encode(settings, wbuf, item)
472+
473+
if wbuf.len() >= _COPY_BUFFER_SIZE:
474+
with timer:
475+
await self.writing_allowed.wait()
476+
self._write_copy_data_msg(wbuf)
477+
wbuf = WriteBuffer.new()
456478

457479
# End of binary copy.
458480
wbuf.write_int16(-1)
Collapse file

‎tests/test_copy.py‎

Copy file name to clipboardExpand all lines: tests/test_copy.py
+23Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,29 @@ async def test_copy_records_to_table_1(self):
649649
finally:
650650
await self.con.execute('DROP TABLE copytab')
651651

652+
async def test_copy_records_to_table_async(self):
653+
await self.con.execute('''
654+
CREATE TABLE copytab_async(a text, b int, c timestamptz);
655+
''')
656+
657+
try:
658+
date = datetime.datetime.now(tz=datetime.timezone.utc)
659+
delta = datetime.timedelta(days=1)
660+
661+
async def record_generator():
662+
for i in range(100):
663+
yield ('a-{}'.format(i), i, date + delta)
664+
665+
yield ('a-100', None, None)
666+
667+
res = await self.con.copy_records_to_table(
668+
'copytab_async', records=record_generator())
669+
670+
self.assertEqual(res, 'COPY 101')
671+
672+
finally:
673+
await self.con.execute('DROP TABLE copytab_async')
674+
652675
async def test_copy_records_to_table_no_binary_codec(self):
653676
await self.con.execute('''
654677
CREATE TABLE copytab(a uuid);

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.