Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit ef27ad5

Browse filesBrowse files
committed
Add support for asynchronous iterables to copy_records_to_table()
The `Connection.copy_records_to_table()` now allows the `records` argument to be an asynchronous iterable. Fixes: #689.
1 parent 075114c commit ef27ad5
Copy full SHA for ef27ad5

File tree

Expand file treeCollapse file tree

3 files changed

+90
-24
lines changed
Open diff view settings
Filter options
Expand file treeCollapse file tree

3 files changed

+90
-24
lines changed
Open diff view settings
Collapse file

‎asyncpg/connection.py‎

Copy file name to clipboardExpand all lines: asyncpg/connection.py
+25-6Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,8 @@ async def copy_records_to_table(self, table_name, *, records,
868868
869869
:param records:
870870
An iterable returning row tuples to copy into the table.
871+
:term:`Asynchronous iterables <python:asynchronous iterable>`
872+
are also supported.
871873
872874
:param list columns:
873875
An optional list of column names to copy.
@@ -897,7 +899,28 @@ async def copy_records_to_table(self, table_name, *, records,
897899
>>> asyncio.get_event_loop().run_until_complete(run())
898900
'COPY 2'
899901
902+
Asynchronous record iterables are also supported:
903+
904+
.. code-block:: pycon
905+
906+
>>> import asyncpg
907+
>>> import asyncio
908+
>>> async def run():
909+
... con = await asyncpg.connect(user='postgres')
910+
... async def record_gen(size):
911+
... for i in range(size):
912+
... yield (i,)
913+
... result = await con.copy_records_to_table(
914+
... 'mytable', records=record_gen(100))
915+
... print(result)
916+
...
917+
>>> asyncio.get_event_loop().run_until_complete(run())
918+
'COPY 100'
919+
900920
.. versionadded:: 0.11.0
921+
922+
.. versionchanged:: 0.23.0
923+
The ``records`` argument may be an asynchronous iterable.
901924
"""
902925
tabname = utils._quote_ident(table_name)
903926
if schema_name:
@@ -920,8 +943,8 @@ async def copy_records_to_table(self, table_name, *, records,
920943
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
921944
tab=tabname, cols=cols, opts=opts)
922945

923-
return await self._copy_in_records(
924-
copy_stmt, records, intro_ps._state, timeout)
946+
return await self._protocol.copy_in(
947+
copy_stmt, None, None, records, intro_ps._state, timeout)
925948

926949
def _format_copy_opts(self, *, format=None, oids=None, freeze=None,
927950
delimiter=None, null=None, header=None, quote=None,
@@ -1044,10 +1067,6 @@ async def __anext__(self):
10441067
if opened_by_us:
10451068
await run_in_executor(None, f.close)
10461069

1047-
async def _copy_in_records(self, copy_stmt, records, intro_stmt, timeout):
1048-
return await self._protocol.copy_in(
1049-
copy_stmt, None, None, records, intro_stmt, timeout)
1050-
10511070
async def set_type_codec(self, typename, *,
10521071
schema='public', encoder, decoder,
10531072
format='text'):
Collapse file

‎asyncpg/protocol/protocol.pyx‎

Copy file name to clipboardExpand all lines: asyncpg/protocol/protocol.pyx
+39-18Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ cimport cpython
1313
import asyncio
1414
import builtins
1515
import codecs
16-
import collections
16+
import collections.abc
1717
import socket
1818
import time
1919
import weakref
@@ -438,23 +438,44 @@ cdef class BaseProtocol(CoreProtocol):
438438
'no binary format encoder for '
439439
'type {} (OID {})'.format(codec.name, codec.oid))
440440

441-
for row in records:
442-
# Tuple header
443-
wbuf.write_int16(<int16_t>num_cols)
444-
# Tuple data
445-
for i in range(num_cols):
446-
item = row[i]
447-
if item is None:
448-
wbuf.write_int32(-1)
449-
else:
450-
codec = <Codec>cpython.PyTuple_GET_ITEM(codecs, i)
451-
codec.encode(settings, wbuf, item)
452-
453-
if wbuf.len() >= _COPY_BUFFER_SIZE:
454-
with timer:
455-
await self.writing_allowed.wait()
456-
self._write_copy_data_msg(wbuf)
457-
wbuf = WriteBuffer.new()
441+
if isinstance(records, collections.abc.AsyncIterable):
442+
async for row in records:
443+
# Tuple header
444+
wbuf.write_int16(<int16_t>num_cols)
445+
# Tuple data
446+
for i in range(num_cols):
447+
item = row[i]
448+
if item is None:
449+
wbuf.write_int32(-1)
450+
else:
451+
codec = <Codec>cpython.PyTuple_GET_ITEM(
452+
codecs, i)
453+
codec.encode(settings, wbuf, item)
454+
455+
if wbuf.len() >= _COPY_BUFFER_SIZE:
456+
with timer:
457+
await self.writing_allowed.wait()
458+
self._write_copy_data_msg(wbuf)
459+
wbuf = WriteBuffer.new()
460+
else:
461+
for row in records:
462+
# Tuple header
463+
wbuf.write_int16(<int16_t>num_cols)
464+
# Tuple data
465+
for i in range(num_cols):
466+
item = row[i]
467+
if item is None:
468+
wbuf.write_int32(-1)
469+
else:
470+
codec = <Codec>cpython.PyTuple_GET_ITEM(
471+
codecs, i)
472+
codec.encode(settings, wbuf, item)
473+
474+
if wbuf.len() >= _COPY_BUFFER_SIZE:
475+
with timer:
476+
await self.writing_allowed.wait()
477+
self._write_copy_data_msg(wbuf)
478+
wbuf = WriteBuffer.new()
458479

459480
# End of binary copy.
460481
wbuf.write_int16(-1)
Collapse file

‎tests/test_copy.py‎

Copy file name to clipboardExpand all lines: tests/test_copy.py
+26Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
import datetime
1010
import io
1111
import os
12+
import sys
1213
import tempfile
14+
import unittest
1315

1416
import asyncpg
1517
from asyncpg import _testbase as tb
@@ -649,6 +651,30 @@ async def test_copy_records_to_table_1(self):
649651
finally:
650652
await self.con.execute('DROP TABLE copytab')
651653

654+
@unittest.skipIf(sys.version_info[:2] < (3, 6), 'no asyncgen support')
655+
async def test_copy_records_to_table_async(self):
656+
await self.con.execute('''
657+
CREATE TABLE copytab_async(a text, b int, c timestamptz);
658+
''')
659+
660+
try:
661+
date = datetime.datetime.now(tz=datetime.timezone.utc)
662+
delta = datetime.timedelta(days=1)
663+
664+
async def record_generator():
665+
for i in range(100):
666+
yield ('a-{}'.format(i), i, date + delta)
667+
668+
yield ('a-100', None, None)
669+
670+
res = await self.con.copy_records_to_table(
671+
'copytab_async', records=record_generator())
672+
673+
self.assertEqual(res, 'COPY 101')
674+
675+
finally:
676+
await self.con.execute('DROP TABLE copytab_async')
677+
652678
async def test_copy_records_to_table_no_binary_codec(self):
653679
await self.con.execute('''
654680
CREATE TABLE copytab(a uuid);

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.