Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit c94d6fc

Browse filesBrowse files
committed
Ignore custom data codec for internal introspection
Fixes: #617
1 parent 68b40cb commit c94d6fc
Copy full SHA for c94d6fc

File tree

Expand file treeCollapse file tree

9 files changed

+82
-32
lines changed
Open diff view settings
Filter options
Expand file treeCollapse file tree

9 files changed

+82
-32
lines changed
Open diff view settings
Collapse file

‎asyncpg/connection.py‎

Copy file name to clipboardExpand all lines: asyncpg/connection.py
+33-12Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -342,13 +342,16 @@ async def _get_statement(
342342
*,
343343
named: bool=False,
344344
use_cache: bool=True,
345+
ignore_custom_codec=False,
345346
record_class=None
346347
):
347348
if record_class is None:
348349
record_class = self._protocol.get_record_class()
349350

350351
if use_cache:
351-
statement = self._stmt_cache.get((query, record_class))
352+
statement = self._stmt_cache.get(
353+
(query, record_class, ignore_custom_codec)
354+
)
352355
if statement is not None:
353356
return statement
354357

@@ -371,6 +374,7 @@ async def _get_statement(
371374
query,
372375
timeout,
373376
record_class=record_class,
377+
ignore_custom_codec=ignore_custom_codec,
374378
)
375379
need_reprepare = False
376380
types_with_missing_codecs = statement._init_types()
@@ -415,7 +419,8 @@ async def _get_statement(
415419
)
416420

417421
if use_cache:
418-
self._stmt_cache.put((query, record_class), statement)
422+
self._stmt_cache.put(
423+
(query, record_class, ignore_custom_codec), statement)
419424

420425
# If we've just created a new statement object, check if there
421426
# are any statements for GC.
@@ -426,7 +431,12 @@ async def _get_statement(
426431

427432
async def _introspect_types(self, typeoids, timeout):
428433
return await self.__execute(
429-
self._intro_query, (list(typeoids),), 0, timeout)
434+
self._intro_query,
435+
(list(typeoids),),
436+
0,
437+
timeout,
438+
ignore_custom_codec=True,
439+
)
430440

431441
async def _introspect_type(self, typename, schema):
432442
if (
@@ -439,20 +449,22 @@ async def _introspect_type(self, typename, schema):
439449
[typeoid],
440450
limit=0,
441451
timeout=None,
452+
ignore_custom_codec=True,
442453
)
443-
if rows:
444-
typeinfo = rows[0]
445-
else:
446-
typeinfo = None
447454
else:
448-
typeinfo = await self.fetchrow(
449-
introspection.TYPE_BY_NAME, typename, schema)
455+
rows = await self._execute(
456+
introspection.TYPE_BY_NAME,
457+
[typename, schema],
458+
limit=1,
459+
timeout=None,
460+
ignore_custom_codec=True,
461+
)
450462

451-
if not typeinfo:
463+
if not rows:
452464
raise ValueError(
453465
'unknown type: {}.{}'.format(schema, typename))
454466

455-
return typeinfo
467+
return rows[0]
456468

457469
def cursor(
458470
self,
@@ -1325,7 +1337,9 @@ def _mark_stmts_as_closed(self):
13251337
def _maybe_gc_stmt(self, stmt):
13261338
if (
13271339
stmt.refs == 0
1328-
and not self._stmt_cache.has((stmt.query, stmt.record_class))
1340+
and not self._stmt_cache.has(
1341+
(stmt.query, stmt.record_class, stmt.ignore_custom_codec)
1342+
)
13291343
):
13301344
# If low-level `stmt` isn't referenced from any high-level
13311345
# `PreparedStatement` object and is not in the `_stmt_cache`:
@@ -1589,6 +1603,7 @@ async def _execute(
15891603
timeout,
15901604
*,
15911605
return_status=False,
1606+
ignore_custom_codec=False,
15921607
record_class=None
15931608
):
15941609
with self._stmt_exclusive_section:
@@ -1599,6 +1614,7 @@ async def _execute(
15991614
timeout,
16001615
return_status=return_status,
16011616
record_class=record_class,
1617+
ignore_custom_codec=ignore_custom_codec,
16021618
)
16031619
return result
16041620

@@ -1610,6 +1626,7 @@ async def __execute(
16101626
timeout,
16111627
*,
16121628
return_status=False,
1629+
ignore_custom_codec=False,
16131630
record_class=None
16141631
):
16151632
executor = lambda stmt, timeout: self._protocol.bind_execute(
@@ -1620,6 +1637,7 @@ async def __execute(
16201637
executor,
16211638
timeout,
16221639
record_class=record_class,
1640+
ignore_custom_codec=ignore_custom_codec,
16231641
)
16241642

16251643
async def _executemany(self, query, args, timeout):
@@ -1637,20 +1655,23 @@ async def _do_execute(
16371655
timeout,
16381656
retry=True,
16391657
*,
1658+
ignore_custom_codec=False,
16401659
record_class=None
16411660
):
16421661
if timeout is None:
16431662
stmt = await self._get_statement(
16441663
query,
16451664
None,
16461665
record_class=record_class,
1666+
ignore_custom_codec=ignore_custom_codec,
16471667
)
16481668
else:
16491669
before = time.monotonic()
16501670
stmt = await self._get_statement(
16511671
query,
16521672
timeout,
16531673
record_class=record_class,
1674+
ignore_custom_codec=ignore_custom_codec,
16541675
)
16551676
after = time.monotonic()
16561677
timeout -= after - before
Collapse file

‎asyncpg/protocol/codecs/base.pxd‎

Copy file name to clipboardExpand all lines: asyncpg/protocol/codecs/base.pxd
+2-1Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,5 +166,6 @@ cdef class DataCodecConfig:
166166
dict _derived_type_codecs
167167
dict _custom_type_codecs
168168

169-
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format)
169+
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
170+
bint ignore_custom_codec=*)
170171
cdef inline Codec get_any_local_codec(self, uint32_t oid)
Collapse file

‎asyncpg/protocol/codecs/base.pyx‎

Copy file name to clipboardExpand all lines: asyncpg/protocol/codecs/base.pyx
+12-10Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -692,18 +692,20 @@ cdef class DataCodecConfig:
692692

693693
return codec
694694

695-
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format):
695+
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
696+
bint ignore_custom_codec=False):
696697
cdef Codec codec
697698

698-
codec = self.get_any_local_codec(oid)
699-
if codec is not None:
700-
if codec.format != format:
701-
# The codec for this OID has been overridden by
702-
# set_{builtin}_type_codec with a different format.
703-
# We must respect that and not return a core codec.
704-
return None
705-
else:
706-
return codec
699+
if not ignore_custom_codec:
700+
codec = self.get_any_local_codec(oid)
701+
if codec is not None:
702+
if codec.format != format:
703+
# The codec for this OID has been overridden by
704+
# set_{builtin}_type_codec with a different format.
705+
# We must respect that and not return a core codec.
706+
return None
707+
else:
708+
return codec
707709

708710
codec = get_core_codec(oid, format)
709711
if codec is not None:
Collapse file

‎asyncpg/protocol/prepared_stmt.pxd‎

Copy file name to clipboardExpand all lines: asyncpg/protocol/prepared_stmt.pxd
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ cdef class PreparedStatementState:
1212
readonly bint closed
1313
readonly int refs
1414
readonly type record_class
15+
readonly bint ignore_custom_codec
1516

1617

1718
list row_desc
Collapse file

‎asyncpg/protocol/prepared_stmt.pyx‎

Copy file name to clipboardExpand all lines: asyncpg/protocol/prepared_stmt.pyx
+7-3Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ cdef class PreparedStatementState:
1616
str name,
1717
str query,
1818
BaseProtocol protocol,
19-
type record_class
19+
type record_class,
20+
bint ignore_custom_codec
2021
):
2122
self.name = name
2223
self.query = query
@@ -28,6 +29,7 @@ cdef class PreparedStatementState:
2829
self.closed = False
2930
self.refs = 0
3031
self.record_class = record_class
32+
self.ignore_custom_codec = ignore_custom_codec
3133

3234
def _get_parameters(self):
3335
cdef Codec codec
@@ -205,7 +207,8 @@ cdef class PreparedStatementState:
205207
cols_mapping[col_name] = i
206208
cols_names.append(col_name)
207209
oid = row[3]
208-
codec = self.settings.get_data_codec(oid)
210+
codec = self.settings.get_data_codec(
211+
oid, ignore_custom_codec=self.ignore_custom_codec)
209212
if codec is None or not codec.has_decoder():
210213
raise exceptions.InternalClientError(
211214
'no decoder for OID {}'.format(oid))
@@ -230,7 +233,8 @@ cdef class PreparedStatementState:
230233

231234
for i from 0 <= i < self.args_num:
232235
p_oid = self.parameters_desc[i]
233-
codec = self.settings.get_data_codec(p_oid)
236+
codec = self.settings.get_data_codec(
237+
p_oid, ignore_custom_codec=self.ignore_custom_codec)
234238
if codec is None or not codec.has_encoder():
235239
raise exceptions.InternalClientError(
236240
'no encoder for OID {}'.format(p_oid))
Collapse file

‎asyncpg/protocol/protocol.pyx‎

Copy file name to clipboardExpand all lines: asyncpg/protocol/protocol.pyx
+2-1Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ cdef class BaseProtocol(CoreProtocol):
145145
async def prepare(self, stmt_name, query, timeout,
146146
*,
147147
PreparedStatementState state=None,
148+
ignore_custom_codec=False,
148149
record_class):
149150
if self.cancel_waiter is not None:
150151
await self.cancel_waiter
@@ -161,7 +162,7 @@ cdef class BaseProtocol(CoreProtocol):
161162
self.last_query = query
162163
if state is None:
163164
state = PreparedStatementState(
164-
stmt_name, query, self, record_class)
165+
stmt_name, query, self, record_class, ignore_custom_codec)
165166
self.statement = state
166167
except Exception as ex:
167168
waiter.set_exception(ex)
Collapse file

‎asyncpg/protocol/settings.pxd‎

Copy file name to clipboardExpand all lines: asyncpg/protocol/settings.pxd
+2-1Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,5 @@ cdef class ConnectionSettings(pgproto.CodecContext):
2626
cpdef inline set_builtin_type_codec(
2727
self, typeoid, typename, typeschema, typekind, alias_to, format)
2828
cpdef inline Codec get_data_codec(
29-
self, uint32_t oid, ServerDataFormat format=*)
29+
self, uint32_t oid, ServerDataFormat format=*,
30+
bint ignore_custom_codec=*)
Collapse file

‎asyncpg/protocol/settings.pyx‎

Copy file name to clipboardExpand all lines: asyncpg/protocol/settings.pyx
+8-4Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,18 @@ cdef class ConnectionSettings(pgproto.CodecContext):
8787
typekind, alias_to, _format)
8888

8989
cpdef inline Codec get_data_codec(self, uint32_t oid,
90-
ServerDataFormat format=PG_FORMAT_ANY):
90+
ServerDataFormat format=PG_FORMAT_ANY,
91+
bint ignore_custom_codec=False):
9192
if format == PG_FORMAT_ANY:
92-
codec = self._data_codecs.get_codec(oid, PG_FORMAT_BINARY)
93+
codec = self._data_codecs.get_codec(
94+
oid, PG_FORMAT_BINARY, ignore_custom_codec)
9395
if codec is None:
94-
codec = self._data_codecs.get_codec(oid, PG_FORMAT_TEXT)
96+
codec = self._data_codecs.get_codec(
97+
oid, PG_FORMAT_TEXT, ignore_custom_codec)
9598
return codec
9699
else:
97-
return self._data_codecs.get_codec(oid, format)
100+
return self._data_codecs.get_codec(
101+
oid, format, ignore_custom_codec)
98102

99103
def __getattr__(self, name):
100104
if not name.startswith('_'):
Collapse file

‎tests/test_introspection.py‎

Copy file name to clipboardExpand all lines: tests/test_introspection.py
+15Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,20 @@ def tearDownClass(cls):
4343

4444
super().tearDownClass()
4545

46+
def setUp(self):
47+
super().setUp()
48+
self.loop.run_until_complete(self._add_custom_codec(self.con))
49+
50+
async def _add_custom_codec(self, conn):
51+
# mess up with the codec - builtin introspection shouldn't be affected
52+
await conn.set_type_codec(
53+
"oid",
54+
schema="pg_catalog",
55+
encoder=lambda value: None,
56+
decoder=lambda value: None,
57+
format="text",
58+
)
59+
4660
@tb.with_connection_options(database='asyncpg_intro_test')
4761
async def test_introspection_on_large_db(self):
4862
await self.con.execute(
@@ -142,6 +156,7 @@ async def test_introspection_retries_after_cache_bust(self):
142156
# query would cause introspection to retry.
143157
slow_intro_conn = await self.connect(
144158
connection_class=SlowIntrospectionConnection)
159+
await self._add_custom_codec(slow_intro_conn)
145160
try:
146161
await self.con.execute('''
147162
CREATE DOMAIN intro_1_t AS int;

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.