Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit c8f936b

Browse files
authored
fix(sessions): resolve async deadlock in multiplexed session manager (#1520)
This PR resolves a critical deadlock that occurs when acquiring or maintaining a multiplexed session asynchronously. The bug occurred because DatabaseSessionsManager previously guarded the bodies of self._get_multiplexed_session() and _maintain_multiplexed_session() with a synchronous threading.Lock. When a coroutine awaits the multiplexed session creation (return await ...) while holding that synchronous thread lock, it yields control to the event loop with the lock still held; any other coroutine that then tries to acquire the lock blocks the event-loop thread itself, so the coroutine holding the lock can never resume to release it and the entire asyncio event loop becomes blocked.
1 parent f822fd7 commit c8f936b
Copy full SHA for c8f936b

12 files changed

+113 −48 — Lines changed: 113 additions & 48 deletions

File tree

Expand file treeCollapse file tree
Open diff view settings
Filter options
Expand file treeCollapse file tree
Open diff view settings
Collapse file

‎packages/google-cloud-spanner/google/cloud/spanner_v1/_async/database_sessions_manager.py‎

Copy file name to clipboardExpand all lines: packages/google-cloud-spanner/google/cloud/spanner_v1/_async/database_sessions_manager.py
+13-6Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,9 @@ def __init__(self, database, pool):
7171
self._pool = pool
7272
self._multiplexed_session: Optional[Session] = None
7373
self._multiplexed_session_thread: Optional[CrossSync.Task] = None
74-
# Use threading.Lock because this is accessed in a synchronous maintenance thread
75-
self._multiplexed_session_lock: threading.Lock = threading.Lock()
76-
self._multiplexed_session_terminate_event: CrossSync.Event = CrossSync.Event()
74+
self._init_lock = threading.Lock()
75+
self._multiplexed_session_lock: Optional[CrossSync.Lock] = None
76+
self._multiplexed_session_terminate_event: Optional[CrossSync.Event] = None
7777

7878
@CrossSync.convert
7979
async def get_session(self, transaction_type: TransactionType) -> Session:
@@ -119,7 +119,13 @@ async def _get_multiplexed_session(self) -> Session:
119119
120120
:rtype: :class:`~google.cloud.spanner_v1.session.Session`
121121
:returns: a multiplexed session."""
122-
with CrossSync.rm_aio(self._multiplexed_session_lock):
122+
with self._init_lock:
123+
if self._multiplexed_session_lock is None:
124+
self._multiplexed_session_lock = CrossSync.Lock()
125+
if self._multiplexed_session_terminate_event is None:
126+
self._multiplexed_session_terminate_event = CrossSync.Event()
127+
128+
async with self._multiplexed_session_lock:
123129
if self._multiplexed_session is None:
124130
self._multiplexed_session = await self._build_multiplexed_session()
125131
self._multiplexed_session_thread = self._build_maintenance_thread()
@@ -193,7 +199,7 @@ async def _maintain_multiplexed_session(session_manager_ref) -> None:
193199
if time() - session_created_time < refresh_interval_seconds:
194200
await CrossSync.sleep(polling_interval_seconds)
195201
continue
196-
with manager._multiplexed_session_lock:
202+
async with manager._multiplexed_session_lock:
197203
await CrossSync.run_if_async(manager._multiplexed_session.delete)
198204
manager._multiplexed_session = (
199205
await manager._build_multiplexed_session()
@@ -220,7 +226,8 @@ def _getenv(cls, env_var_name: str) -> bool:
220226
@CrossSync.convert
221227
async def close(self) -> None:
222228
"""Closes the database session manager and stops all background tasks."""
223-
self._multiplexed_session_terminate_event.set()
229+
if self._multiplexed_session_terminate_event is not None:
230+
self._multiplexed_session_terminate_event.set()
224231
if self._multiplexed_session_thread is not None:
225232
if CrossSync.is_async:
226233
self._multiplexed_session_thread.cancel()
Collapse file

‎packages/google-cloud-spanner/google/cloud/spanner_v1/batch.py‎

Copy file name to clipboardExpand all lines: packages/google-cloud-spanner/google/cloud/spanner_v1/batch.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def wrapped_method():
243243
max_commit_delay=max_commit_delay,
244244
request_options=request_options,
245245
)
246-
(call_metadata, error_augmenter) = database.with_error_augmentation(
246+
call_metadata, error_augmenter = database.with_error_augmentation(
247247
getattr(database, "_next_nth_request", 0), 1, metadata, span
248248
)
249249
commit_method = functools.partial(
Collapse file

‎packages/google-cloud-spanner/google/cloud/spanner_v1/database.py‎

Copy file name to clipboardExpand all lines: packages/google-cloud-spanner/google/cloud/spanner_v1/database.py
+5-8Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@
8282
trace_call,
8383
)
8484
from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture
85-
8685
from google.cloud.spanner_v1.table import Table
8786

8887
SPANNER_DATA_SCOPE = "https://www.googleapis.com/auth/spanner.data"
@@ -211,11 +210,9 @@ def __init__(
211210
def _resource_info(self):
212211
"""Resource information for metrics labels."""
213212
return {
214-
"project": (
215-
self._instance._client.project
216-
if self._instance and self._instance._client
217-
else None
218-
),
213+
"project": self._instance._client.project
214+
if self._instance and self._instance._client
215+
else None,
219216
"instance": self._instance.instance_id if self._instance else None,
220217
"database": self.database_id,
221218
}
@@ -533,7 +530,7 @@ def with_error_augmentation(
533530
tuple: (metadata_list, context_manager)"""
534531
if span is None:
535532
span = get_current_span()
536-
(metadata, request_id) = _metadata_with_request_id_and_req_id(
533+
metadata, request_id = _metadata_with_request_id_and_req_id(
537534
self._nth_client_id,
538535
self._channel_id,
539536
nth_request,
@@ -810,7 +807,7 @@ def execute_pdml():
810807
session = self._sessions_manager.get_session(transaction_type)
811808
try:
812809
add_span_event(span, "Starting BeginTransaction")
813-
(call_metadata, error_augmenter) = self.with_error_augmentation(
810+
call_metadata, error_augmenter = self.with_error_augmentation(
814811
self._next_nth_request, 1, metadata, span
815812
)
816813
with error_augmenter:
Collapse file

‎packages/google-cloud-spanner/google/cloud/spanner_v1/database_sessions_manager.py‎

Copy file name to clipboardExpand all lines: packages/google-cloud-spanner/google/cloud/spanner_v1/database_sessions_manager.py
+12-5Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,11 @@ def __init__(self, database, pool):
6969
self._pool = pool
7070
self._multiplexed_session: Optional[Session] = None
7171
self._multiplexed_session_thread: Optional[CrossSync._Sync_Impl.Task] = None
72-
self._multiplexed_session_lock: threading.Lock = threading.Lock()
73-
self._multiplexed_session_terminate_event: CrossSync._Sync_Impl.Event = (
74-
CrossSync._Sync_Impl.Event()
75-
)
72+
self._init_lock = threading.Lock()
73+
self._multiplexed_session_lock: Optional[CrossSync._Sync_Impl.Lock] = None
74+
self._multiplexed_session_terminate_event: Optional[
75+
CrossSync._Sync_Impl.Event
76+
] = None
7677

7778
def get_session(self, transaction_type: TransactionType) -> Session:
7879
"""Returns a session for the given transaction type from the database session manager.
@@ -115,6 +116,11 @@ def _get_multiplexed_session(self) -> Session:
115116
116117
:rtype: :class:`~google.cloud.spanner_v1.session.Session`
117118
:returns: a multiplexed session."""
119+
with self._init_lock:
120+
if self._multiplexed_session_lock is None:
121+
self._multiplexed_session_lock = CrossSync._Sync_Impl.Lock()
122+
if self._multiplexed_session_terminate_event is None:
123+
self._multiplexed_session_terminate_event = CrossSync._Sync_Impl.Event()
118124
with self._multiplexed_session_lock:
119125
if self._multiplexed_session is None:
120126
self._multiplexed_session = self._build_multiplexed_session()
@@ -205,7 +211,8 @@ def _getenv(cls, env_var_name: str) -> bool:
205211

206212
def close(self) -> None:
207213
"""Closes the database session manager and stops all background tasks."""
208-
self._multiplexed_session_terminate_event.set()
214+
if self._multiplexed_session_terminate_event is not None:
215+
self._multiplexed_session_terminate_event.set()
209216
if self._multiplexed_session_thread is not None:
210217
self._multiplexed_session_thread.join()
211218
if self._multiplexed_session is not None:
Collapse file

‎packages/google-cloud-spanner/google/cloud/spanner_v1/instance.py‎

Copy file name to clipboardExpand all lines: packages/google-cloud-spanner/google/cloud/spanner_v1/instance.py
+3-1Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,9 @@ def database(
479479
database_role=database_role,
480480
enable_drop_protection=enable_drop_protection,
481481
)
482-
db._pool.bind(db)
482+
res = db._pool.bind(db)
483+
if res is not None:
484+
res
483485
return db
484486

485487
def list_databases(self, page_size=None):
Collapse file

‎packages/google-cloud-spanner/google/cloud/spanner_v1/pool.py‎

Copy file name to clipboardExpand all lines: packages/google-cloud-spanner/google/cloud/spanner_v1/pool.py
+5-7Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def _fill_pool(self):
304304
f"Creating {request.session_count} sessions",
305305
span_event_attributes,
306306
)
307-
(call_metadata, error_augmenter) = database.with_error_augmentation(
307+
call_metadata, error_augmenter = database.with_error_augmentation(
308308
database._next_nth_request, 1, metadata, span
309309
)
310310
with error_augmenter:
@@ -612,7 +612,7 @@ def bind(self, database):
612612
) as span, MetricsCapture(self._resource_info):
613613
returned_session_count = 0
614614
while returned_session_count < self.size:
615-
(call_metadata, error_augmenter) = database.with_error_augmentation(
615+
call_metadata, error_augmenter = database.with_error_augmentation(
616616
database._next_nth_request, 1, metadata, span
617617
)
618618
with error_augmenter:
@@ -654,7 +654,7 @@ def get(self, timeout=None):
654654
ping_after = None
655655
session = None
656656
try:
657-
(ping_after, session) = CrossSync._Sync_Impl.queue_get(
657+
ping_after, session = CrossSync._Sync_Impl.queue_get(
658658
self._sessions, block=True, timeout=timeout
659659
)
660660
except CrossSync._Sync_Impl.QueueEmpty as e:
@@ -698,9 +698,7 @@ def clear(self):
698698
"""Delete all sessions in the pool."""
699699
while True:
700700
try:
701-
(_, session) = CrossSync._Sync_Impl.queue_get(
702-
self._sessions, block=False
703-
)
701+
_, session = CrossSync._Sync_Impl.queue_get(self._sessions, block=False)
704702
except CrossSync._Sync_Impl.QueueEmpty:
705703
break
706704
else:
@@ -713,7 +711,7 @@ def ping(self):
713711
or during the "idle" phase of an event loop."""
714712
while True:
715713
try:
716-
(ping_after, session) = CrossSync._Sync_Impl.queue_get(
714+
ping_after, session = CrossSync._Sync_Impl.queue_get(
717715
self._sessions, block=False
718716
)
719717
except CrossSync._Sync_Impl.QueueEmpty:
Collapse file

‎packages/google-cloud-spanner/google/cloud/spanner_v1/session.py‎

Copy file name to clipboardExpand all lines: packages/google-cloud-spanner/google/cloud/spanner_v1/session.py
+4-4Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def create(self):
188188
observability_options=observability_options,
189189
metadata=metadata,
190190
) as span, MetricsCapture(self._resource_info):
191-
(call_metadata, error_augmenter) = database.with_error_augmentation(
191+
call_metadata, error_augmenter = database.with_error_augmentation(
192192
nth_request, 1, metadata, span
193193
)
194194
with error_augmenter:
@@ -232,7 +232,7 @@ def exists(self):
232232
observability_options=observability_options,
233233
metadata=metadata,
234234
) as span, MetricsCapture(self._resource_info):
235-
(call_metadata, error_augmenter) = database.with_error_augmentation(
235+
call_metadata, error_augmenter = database.with_error_augmentation(
236236
nth_request, 1, metadata, span
237237
)
238238
with error_augmenter:
@@ -283,7 +283,7 @@ def delete(self):
283283
observability_options=observability_options,
284284
metadata=metadata,
285285
) as span, MetricsCapture(self._resource_info):
286-
(call_metadata, error_augmenter) = database.with_error_augmentation(
286+
call_metadata, error_augmenter = database.with_error_augmentation(
287287
nth_request, 1, metadata, span
288288
)
289289
with error_augmenter:
@@ -300,7 +300,7 @@ def ping(self):
300300
metadata = _metadata_with_prefix(database.name)
301301
nth_request = database._next_nth_request
302302
with trace_call("CloudSpanner.Session.ping", self) as span:
303-
(call_metadata, error_augmenter) = database.with_error_augmentation(
303+
call_metadata, error_augmenter = database.with_error_augmentation(
304304
nth_request, 1, metadata, span
305305
)
306306
with error_augmenter:
Collapse file

‎packages/google-cloud-spanner/google/cloud/spanner_v1/snapshot.py‎

Copy file name to clipboardExpand all lines: packages/google-cloud-spanner/google/cloud/spanner_v1/snapshot.py
+3-3Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def execute_sql(
322322
raise ValueError("Transaction has not begun.")
323323
if params is not None:
324324
params_pb = Struct(
325-
fields={key: _make_value_pb(value) for (key, value) in params.items()}
325+
fields={key: _make_value_pb(value) for key, value in params.items()}
326326
)
327327
else:
328328
params_pb = {}
@@ -513,7 +513,7 @@ def partition_query(
513513
raise ValueError("Cannot partition a single-use transaction.")
514514
if params is not None:
515515
params_pb = Struct(
516-
fields={key: _make_value_pb(value) for (key, value) in params.items()}
516+
fields={key: _make_value_pb(value) for key, value in params.items()}
517517
)
518518
else:
519519
params_pb = Struct()
@@ -614,7 +614,7 @@ def wrapped_method():
614614
begin_transaction_request = BeginTransactionRequest(
615615
**begin_request_kwargs
616616
)
617-
(call_metadata, error_augmenter) = database.with_error_augmentation(
617+
call_metadata, error_augmenter = database.with_error_augmentation(
618618
nth_request, attempt.increment(), metadata, span
619619
)
620620
begin_transaction_method = functools.partial(
Collapse file

‎packages/google-cloud-spanner/google/cloud/spanner_v1/streamed.py‎

Copy file name to clipboardExpand all lines: packages/google-cloud-spanner/google/cloud/spanner_v1/streamed.py
+4-4Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def _consume_next(self):
147147

148148
def __iter__(self):
149149
while True:
150-
(iter_rows, self._rows[:]) = (self._rows[:], ())
150+
iter_rows, self._rows[:] = (self._rows[:], ())
151151
while iter_rows:
152152
yield iter_rows.pop(0)
153153
if self._done:
@@ -230,7 +230,7 @@ def to_dict_list(self):
230230
rows.append(
231231
{
232232
column: value
233-
for (column, value) in zip(
233+
for column, value in zip(
234234
[column.name for column in self._metadata.row_type.fields], row
235235
)
236236
}
@@ -291,7 +291,7 @@ def _merge_array(lhs, rhs, type_):
291291
if element_type.code in _UNMERGEABLE_TYPES:
292292
lhs.list_value.values.extend(rhs.list_value.values)
293293
return lhs
294-
(lhs, rhs) = (list(lhs.list_value.values), list(rhs.list_value.values))
294+
lhs, rhs = (list(lhs.list_value.values), list(rhs.list_value.values))
295295
if not len(lhs) or not len(rhs):
296296
return Value(list_value=ListValue(values=lhs + rhs))
297297
first = rhs.pop(0)
@@ -316,7 +316,7 @@ def _merge_array(lhs, rhs, type_):
316316
def _merge_struct(lhs, rhs, type_):
317317
"""Helper for '_merge_by_type'."""
318318
fields = type_.struct_type.fields
319-
(lhs, rhs) = (list(lhs.list_value.values), list(rhs.list_value.values))
319+
lhs, rhs = (list(lhs.list_value.values), list(rhs.list_value.values))
320320
if not len(lhs) or not len(rhs):
321321
return Value(list_value=ListValue(values=lhs + rhs))
322322
candidate_type = fields[len(lhs) - 1].type_
Collapse file

‎packages/google-cloud-spanner/google/cloud/spanner_v1/transaction.py‎

Copy file name to clipboardExpand all lines: packages/google-cloud-spanner/google/cloud/spanner_v1/transaction.py
+9-9Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def rollback(self) -> None:
162162

163163
def wrapped_method(*args, **kwargs):
164164
attempt.increment()
165-
(call_metadata, error_augmenter) = database.with_error_augmentation(
165+
call_metadata, error_augmenter = database.with_error_augmentation(
166166
nth_request, attempt.value, metadata, span
167167
)
168168
rollback_method = functools.partial(
@@ -269,7 +269,7 @@ def wrapped_method(*args, **kwargs):
269269
is_multiplexed = getattr(self._session, "is_multiplexed", False)
270270
if is_multiplexed and self._precommit_token is not None:
271271
commit_request_args["precommit_token"] = self._precommit_token
272-
(call_metadata, error_augmenter) = database.with_error_augmentation(
272+
call_metadata, error_augmenter = database.with_error_augmentation(
273273
nth_request, attempt.value, metadata, span
274274
)
275275
commit_method = functools.partial(
@@ -300,7 +300,7 @@ def before_next_retry(nth_retry, delay_in_seconds):
300300
if commit_response_pb._pb.HasField("precommit_token"):
301301
add_span_event(span, commit_retry_event_name)
302302
nth_request = database._next_nth_request
303-
(call_metadata, error_augmenter) = database.with_error_augmentation(
303+
call_metadata, error_augmenter = database.with_error_augmentation(
304304
nth_request, 1, metadata, span
305305
)
306306
with error_augmenter:
@@ -338,7 +338,7 @@ def _make_params_pb(params, param_types):
338338
If ``params`` is None but ``param_types`` is not None."""
339339
if params:
340340
return Struct(
341-
fields={key: _make_value_pb(value) for (key, value) in params.items()}
341+
fields={key: _make_value_pb(value) for key, value in params.items()}
342342
)
343343
return {}
344344

@@ -417,7 +417,7 @@ def execute_update(
417417
metadata.append(
418418
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
419419
)
420-
(seqno, self._execute_sql_request_count) = (
420+
seqno, self._execute_sql_request_count = (
421421
self._execute_sql_request_count,
422422
self._execute_sql_request_count + 1,
423423
)
@@ -454,7 +454,7 @@ def execute_update(
454454

455455
def wrapped_method(*args, **kwargs):
456456
attempt.increment()
457-
(call_metadata, error_augmenter) = database.with_error_augmentation(
457+
call_metadata, error_augmenter = database.with_error_augmentation(
458458
nth_request, attempt.value, metadata
459459
)
460460
execute_sql_method = functools.partial(
@@ -544,7 +544,7 @@ def batch_update(
544544
if isinstance(statement, str):
545545
parsed.append(ExecuteBatchDmlRequest.Statement(sql=statement))
546546
else:
547-
(dml, params, param_types) = statement
547+
dml, params, param_types = statement
548548
params_pb = self._make_params_pb(params, param_types)
549549
parsed.append(
550550
ExecuteBatchDmlRequest.Statement(
@@ -556,7 +556,7 @@ def batch_update(
556556
metadata.append(
557557
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
558558
)
559-
(seqno, self._execute_sql_request_count) = (
559+
seqno, self._execute_sql_request_count = (
560560
self._execute_sql_request_count,
561561
self._execute_sql_request_count + 1,
562562
)
@@ -590,7 +590,7 @@ def batch_update(
590590

591591
def wrapped_method(*args, **kwargs):
592592
attempt.increment()
593-
(call_metadata, error_augmenter) = database.with_error_augmentation(
593+
call_metadata, error_augmenter = database.with_error_augmentation(
594594
nth_request, attempt.value, metadata
595595
)
596596
execute_batch_dml_method = functools.partial(

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.