Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 56cbea8

Browse filesBrowse files
authored
fix(rab): run async background boundary refresh on detached session (#17441)
When AuthorizedSession.request() makes an API call, it runs inside a temporary aiohttp ClientSession block. If our background Regional Access Boundary (RAB) refresh worker naively shares this exact same session, a fast primary call (like an instant 401/403 or a quick CRM check) will exit its block and close the active socket mid-flight. This causes the background worker to silently fail with "RuntimeError: Session is closed" and forces the RAB manager into a 15-minute cooldown. This commit resolves the race condition and ensures safe connection lifecycle management: - Shifted the cloning block to run synchronously inside start_refresh, capturing a fresh, independent ClientSession before the foreground thread can close the source transport. - Added a _clone() method to async Request adapters (both modern and legacy) to copy proxy settings and trace configurations while enforcing connector limits. - Prevented resource leaks on task creation failures by capturing exceptions in start_refresh and closing the cloned session synchronously. - Refactored the close wrapper to inspect and await generic awaitables (such as asyncio.Future) returned by custom or third-party transports. - Aligned exception behaviors by raising a wrapped TransportError directly when calling a closed instance of the legacy aiohttp_requests adapter. - Ensured the cloned transport is cleanly closed in a finally block after the background lookup settles.
1 parent b50cf1a commit 56cbea8
Copy full SHA for 56cbea8

9 files changed

+1,025-4Lines changed: 1025 additions & 4 deletions

File tree

Expand file treeCollapse file tree
Open diff view settings
Filter options
Expand file treeCollapse file tree
Open diff view settings
Collapse file

‎packages/google-auth/google/auth/_regional_access_boundary_utils.py‎

Copy file name to clipboardExpand all lines: packages/google-auth/google/auth/_regional_access_boundary_utils.py
+85-3Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from google.auth import _helpers
2828
from google.auth import environment_vars
2929

30-
if TYPE_CHECKING:
30+
if TYPE_CHECKING: # pragma: NO COVER
3131
import google.auth.credentials
3232
import google.auth.transport
3333

@@ -455,6 +455,61 @@ def start_refresh(self, credentials, request, rab_manager):
455455
self._worker.start()
456456

457457

458+
def _prepare_async_lookup_callable(request):
459+
"""Unwraps a request callable, clones the transport, and returns the new callable.
460+
461+
Args:
462+
request: The original request callable (e.g. functools.partial or raw Request).
463+
464+
Returns:
465+
Tuple[Callable, Any, bool]: A tuple containing the new lookup callable, the
466+
underlying request object, and a boolean indicating if it was cloned.
467+
"""
468+
is_partial = isinstance(request, functools.partial)
469+
base_callable = request.func if is_partial else request
470+
471+
if not hasattr(base_callable, "_clone"):
472+
return request, base_callable, False
473+
474+
cloned_callable = base_callable._clone()
475+
is_cloned = cloned_callable is not base_callable
476+
477+
if is_partial:
478+
new_request = functools.partial(
479+
cloned_callable, *request.args, **request.keywords
480+
)
481+
else:
482+
new_request = cloned_callable
483+
484+
return new_request, cloned_callable, is_cloned
485+
486+
487+
async def _close_cloned_request(lookup_request, is_cloned):
488+
"""Safely closes the underlying cloned request transport, if applicable.
489+
490+
Args:
491+
lookup_request (Any): The request object/transport to close.
492+
is_cloned (bool): Whether the request was actually cloned.
493+
"""
494+
if not is_cloned or not hasattr(lookup_request, "close"):
495+
return
496+
497+
is_async = False
498+
try:
499+
maybe_coro = lookup_request.close()
500+
if is_async := inspect.isawaitable(maybe_coro):
501+
await maybe_coro
502+
except Exception as e:
503+
if _helpers.is_logging_enabled(_LOGGER):
504+
adapter_type = " asynchronous " if is_async else " "
505+
_LOGGER.warning(
506+
"Failed to cleanly close cloned%srequest transport: %s",
507+
adapter_type,
508+
e,
509+
exc_info=True,
510+
)
511+
512+
458513
class _AsyncRegionalAccessBoundaryRefreshManager(object):
459514
"""Manages a task for background refreshing of the Regional Access Boundary in async flows."""
460515

@@ -491,11 +546,28 @@ def start_refresh(self, credentials, request, rab_manager):
491546
# A refresh is already in progress.
492547
return
493548

549+
try:
550+
(
551+
lookup_callable,
552+
lookup_request,
553+
is_cloned,
554+
) = _prepare_async_lookup_callable(request)
555+
except Exception as e:
556+
if _helpers.is_logging_enabled(_LOGGER):
557+
_LOGGER.warning(
558+
"Synchronous cloning of request for Regional Access Boundary lookup failed: %s",
559+
e,
560+
exc_info=True,
561+
)
562+
rab_manager.process_regional_access_boundary_info(None)
563+
return
564+
494565
async def _worker():
495566
try:
496-
# credentials._lookup_regional_access_boundary should be async in the async creds class
497567
regional_access_boundary_info = (
498-
await credentials._lookup_regional_access_boundary(request)
568+
await credentials._lookup_regional_access_boundary(
569+
lookup_callable
570+
)
499571
)
500572
except Exception as e:
501573
if _helpers.is_logging_enabled(_LOGGER):
@@ -505,6 +577,8 @@ async def _worker():
505577
exc_info=True,
506578
)
507579
regional_access_boundary_info = None
580+
finally:
581+
await _close_cloned_request(lookup_request, is_cloned)
508582

509583
rab_manager.process_regional_access_boundary_info(
510584
regional_access_boundary_info
@@ -514,7 +588,15 @@ async def _worker():
514588
try:
515589
self._worker_task = asyncio.create_task(coro)
516590
except Exception:
591+
# Clean up cloned request if task creation fails
517592
coro.close()
593+
try:
594+
asyncio.get_running_loop().create_task(
595+
_close_cloned_request(lookup_request, is_cloned)
596+
)
597+
except RuntimeError:
598+
pass
599+
rab_manager.process_regional_access_boundary_info(None)
518600
raise
519601

520602

Collapse file

‎packages/google-auth/google/auth/aio/transport/__init__.py‎

Copy file name to clipboardExpand all lines: packages/google-auth/google/auth/aio/transport/__init__.py
+10Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,13 @@ async def close(self) -> None:
142142
Close the underlying session.
143143
"""
144144
raise NotImplementedError("close must be implemented.")
145+
146+
def _clone(self) -> "Request":
147+
"""Creates a copy of this request adapter.
148+
149+
The base implementation returns `self` (an identical shared instance).
150+
Transport adapters that maintain internal connection pools or stateful
151+
sessions must override this method to return an independent, detached
152+
adapter instance.
153+
"""
154+
return self
Collapse file

‎packages/google-auth/google/auth/aio/transport/aiohttp.py‎

Copy file name to clipboardExpand all lines: packages/google-auth/google/auth/aio/transport/aiohttp.py
+81-1Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
else:
3737
try:
3838
from aiohttp import ClientTimeout
39-
except (ImportError, AttributeError):
39+
except (ImportError, AttributeError): # pragma: NO COVER
4040
ClientTimeout = None
4141

4242
_LOGGER = logging.getLogger(__name__)
@@ -203,3 +203,83 @@ async def close(self) -> None:
203203
if not self._closed and self._session:
204204
await self._session.close()
205205
self._closed = True
206+
207+
def _clone(self) -> "Request":
208+
"""Creates an independent copy of this request adapter.
209+
210+
Clones the connection settings, trace configurations, and session defaults
211+
(headers, cookies, basic auth, and timeouts).
212+
213+
Only standard `aiohttp.TCPConnector` and `aiohttp.UnixConnector` connectors
214+
are supported. The DNS resolver is not copied to avoid closing shared resolver
215+
resources.
216+
217+
Returns:
218+
google.auth.aio.transport.aiohttp.Request: A new request adapter.
219+
220+
Raises:
221+
google.auth.exceptions.TransportError: If the transport is closed, or if the
222+
session uses an unsupported connector.
223+
"""
224+
if self._closed:
225+
raise exceptions.TransportError("Cannot clone a closed transport.")
226+
227+
if not self._session:
228+
new_session = aiohttp.ClientSession(
229+
auto_decompress=False,
230+
trust_env=True,
231+
)
232+
return Request(session=new_session)
233+
234+
session_kwargs: dict = {
235+
"auto_decompress": False,
236+
"trust_env": getattr(self._session, "_trust_env", True),
237+
}
238+
239+
# Copy underlying connection pool settings (SSL context, IP bindings, limits).
240+
orig_connector = getattr(self._session, "_connector", None)
241+
if orig_connector and not orig_connector.closed:
242+
if isinstance(orig_connector, aiohttp.TCPConnector):
243+
# We explicitly do not copy the resolver. The connector
244+
# owns the resolver, and closing the cloned session would
245+
# close the shared resolver, breaking the original session.
246+
session_kwargs["connector"] = aiohttp.TCPConnector(
247+
ssl=getattr(orig_connector, "_ssl", None), # type: ignore
248+
limit=getattr(orig_connector, "_limit", 100),
249+
limit_per_host=getattr(orig_connector, "_limit_per_host", 0),
250+
force_close=getattr(orig_connector, "_force_close", False),
251+
local_addr=getattr(orig_connector, "_local_addr", None),
252+
)
253+
elif getattr(aiohttp, "UnixConnector", None) and isinstance(
254+
orig_connector, getattr(aiohttp, "UnixConnector")
255+
):
256+
path = getattr(orig_connector, "_path", None)
257+
if path:
258+
session_kwargs["connector"] = aiohttp.UnixConnector(
259+
path=path,
260+
limit=getattr(orig_connector, "_limit", 100),
261+
force_close=getattr(orig_connector, "_force_close", False),
262+
)
263+
else:
264+
raise exceptions.TransportError(
265+
f"Unsupported connector type for cloning: {type(orig_connector)}"
266+
)
267+
268+
# Preserve distributed tracing configurations.
269+
trace_configs = getattr(self._session, "_trace_configs", None)
270+
if trace_configs:
271+
session_kwargs["trace_configs"] = list(trace_configs)
272+
273+
# Copy session-level defaults (headers, cookies, auth, timeout).
274+
for attr_name, kwarg_name in [
275+
("_default_headers", "headers"),
276+
("_cookie_jar", "cookie_jar"),
277+
("_default_auth", "auth"),
278+
("_timeout", "timeout"),
279+
("_json_serialize", "json_serialize"),
280+
]:
281+
val = getattr(self._session, attr_name, None)
282+
if val is not None:
283+
session_kwargs[kwarg_name] = val
284+
285+
return Request(session=aiohttp.ClientSession(**session_kwargs)) # type: ignore
Collapse file

‎packages/google-auth/google/auth/transport/_aiohttp_requests.py‎

Copy file name to clipboardExpand all lines: packages/google-auth/google/auth/transport/_aiohttp_requests.py
+90Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def __init__(self, session=None):
148148
"Client sessions with auto_decompress=True are not supported."
149149
)
150150
self.session = session
151+
self._closed = False
151152

152153
async def __call__(
153154
self,
@@ -183,6 +184,9 @@ async def __call__(
183184
"""
184185

185186
try:
187+
if getattr(self, "_closed", False):
188+
raise exceptions.TransportError("session is closed.")
189+
186190
if self.session is None: # pragma: NO COVER
187191
self.session = aiohttp.ClientSession(
188192
auto_decompress=False
@@ -202,6 +206,92 @@ async def __call__(
202206
new_exc = exceptions.TransportError(caught_exc)
203207
raise new_exc from caught_exc
204208

209+
def _clone(self):
210+
"""Creates an independent copy of this request adapter.
211+
212+
Clones the connection settings, trace configurations, and session defaults
213+
(headers, cookies, basic auth, and timeouts).
214+
215+
Only standard `aiohttp.TCPConnector` and `aiohttp.UnixConnector` connectors
216+
are supported. The DNS resolver is not copied to avoid closing shared resolver
217+
resources.
218+
219+
Returns:
220+
google.auth.transport._aiohttp_requests.Request: A new request adapter.
221+
222+
Raises:
223+
google.auth.exceptions.TransportError: If the transport is closed, or if the
224+
session uses an unsupported connector.
225+
"""
226+
if getattr(self, "_closed", False):
227+
raise exceptions.TransportError("Cannot clone a closed transport.")
228+
229+
if not self.session:
230+
new_session = aiohttp.ClientSession(
231+
auto_decompress=False,
232+
trust_env=True,
233+
)
234+
return Request(session=new_session)
235+
236+
session_kwargs: dict = {
237+
"auto_decompress": False,
238+
"trust_env": getattr(self.session, "_trust_env", True),
239+
}
240+
241+
# Copy underlying connection pool settings (SSL context, IP bindings, limits).
242+
orig_connector = getattr(self.session, "_connector", None)
243+
if orig_connector and not getattr(orig_connector, "closed", True):
244+
if isinstance(orig_connector, aiohttp.TCPConnector):
245+
# We explicitly do not copy the resolver. The connector
246+
# owns the resolver, and closing the cloned session would
247+
# close the shared resolver, breaking the original session.
248+
session_kwargs["connector"] = aiohttp.TCPConnector(
249+
ssl=getattr(orig_connector, "_ssl", None), # type: ignore
250+
limit=getattr(orig_connector, "_limit", 100),
251+
limit_per_host=getattr(orig_connector, "_limit_per_host", 0),
252+
force_close=getattr(orig_connector, "_force_close", False),
253+
local_addr=getattr(orig_connector, "_local_addr", None),
254+
)
255+
elif getattr(aiohttp, "UnixConnector", None) and isinstance(
256+
orig_connector, getattr(aiohttp, "UnixConnector")
257+
):
258+
path = getattr(orig_connector, "_path", None)
259+
if path:
260+
session_kwargs["connector"] = aiohttp.UnixConnector(
261+
path=path,
262+
limit=getattr(orig_connector, "_limit", 100),
263+
force_close=getattr(orig_connector, "_force_close", False),
264+
)
265+
else:
266+
raise exceptions.TransportError(
267+
f"Unsupported connector type for cloning: {type(orig_connector)}"
268+
)
269+
270+
# Preserve distributed tracing configurations.
271+
trace_configs = getattr(self.session, "_trace_configs", None)
272+
if trace_configs:
273+
session_kwargs["trace_configs"] = list(trace_configs)
274+
275+
# Copy session-level defaults (headers, cookies, auth, timeout).
276+
for attr_name, kwarg_name in [
277+
("_default_headers", "headers"),
278+
("_cookie_jar", "cookie_jar"),
279+
("_default_auth", "auth"),
280+
("_timeout", "timeout"),
281+
("_json_serialize", "json_serialize"),
282+
]:
283+
val = getattr(self.session, attr_name, None)
284+
if val is not None:
285+
session_kwargs[kwarg_name] = val
286+
287+
return Request(session=aiohttp.ClientSession(**session_kwargs)) # type: ignore
288+
289+
async def close(self):
290+
"""Cleanly release the underlying aiohttp ClientSession resources."""
291+
if not getattr(self, "_closed", False) and self.session:
292+
await self.session.close()
293+
self._closed = True
294+
205295

206296
class AuthorizedSession(aiohttp.ClientSession):
207297
"""This is an async implementation of the Authorized Session class. We utilize an

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.