27
27
import socket
28
28
import sys
29
29
import threading
30
+ from dataclasses import dataclass
30
31
from types import TracebackType # noqa # used in type hints
31
- from typing import Awaitable , Dict , List , Optional , Tuple , Type , Union , cast
32
+ from typing import Any , Awaitable , Dict , List , Optional , Tuple , Type , Union , cast
32
33
33
34
from ._cache import DNSCache
34
35
from ._dns import DNSQuestion , DNSQuestionType
105
106
_REGISTER_BROADCASTS = 3
106
107
107
108
109
+ @dataclass
110
+ class _WrappedTransport :
111
+ """A wrapper for transports."""
112
+
113
+ transport : asyncio .DatagramTransport
114
+ is_ipv6 : bool
115
+ socket : socket .socket
116
+ fileno : int
117
+ sock_name : Any
118
+
119
+
120
+ def _make_wrapped_transport (transport : asyncio .DatagramTransport ) -> _WrappedTransport :
121
+ """Make a wrapped transport."""
122
+ sock : socket .socket = transport .get_extra_info ('socket' )
123
+ return _WrappedTransport (
124
+ transport = transport ,
125
+ is_ipv6 = sock .family == socket .AF_INET6 ,
126
+ socket = sock ,
127
+ fileno = sock .fileno (),
128
+ sock_name = sock .getsockname (),
129
+ )
130
+
131
+
108
132
class AsyncEngine :
109
133
"""An engine wraps sockets in the event loop."""
110
134
@@ -117,8 +141,8 @@ def __init__(
117
141
self .loop : Optional [asyncio .AbstractEventLoop ] = None
118
142
self .zc = zeroconf
119
143
self .protocols : List [AsyncListener ] = []
120
- self .readers : List [asyncio . DatagramTransport ] = []
121
- self .senders : List [asyncio . DatagramTransport ] = []
144
+ self .readers : List [_WrappedTransport ] = []
145
+ self .senders : List [_WrappedTransport ] = []
122
146
self .running_event : Optional [asyncio .Event ] = None
123
147
self ._listen_socket = listen_socket
124
148
self ._respond_sockets = respond_sockets
@@ -158,9 +182,9 @@ async def _async_create_endpoints(self) -> None:
158
182
for s in reader_sockets :
159
183
transport , protocol = await loop .create_datagram_endpoint (lambda : AsyncListener (self .zc ), sock = s )
160
184
self .protocols .append (cast (AsyncListener , protocol ))
161
- self .readers .append (cast (asyncio .DatagramTransport , transport ))
185
+ self .readers .append (_make_wrapped_transport ( cast (asyncio .DatagramTransport , transport ) ))
162
186
if s in sender_sockets :
163
- self .senders .append (cast (asyncio .DatagramTransport , transport ))
187
+ self .senders .append (_make_wrapped_transport ( cast (asyncio .DatagramTransport , transport ) ))
164
188
165
189
def _async_cache_cleanup (self ) -> None :
166
190
"""Periodic cache cleanup."""
@@ -186,8 +210,8 @@ def _async_shutdown(self) -> None:
186
210
"""Shutdown transports and sockets."""
187
211
assert self .running_event is not None
188
212
self .running_event .clear ()
189
- for transport in itertools .chain (self .senders , self .readers ):
190
- transport .close ()
213
+ for wrapped_transport in itertools .chain (self .senders , self .readers ):
214
+ wrapped_transport . transport .close ()
191
215
192
216
def close (self ) -> None :
193
217
"""Close from sync context.
@@ -221,7 +245,7 @@ def __init__(self, zc: 'Zeroconf') -> None:
221
245
self .zc = zc
222
246
self .data : Optional [bytes ] = None
223
247
self .last_time : float = 0
224
- self .transport : Optional [asyncio . DatagramTransport ] = None
248
+ self .transport : Optional [_WrappedTransport ] = None
225
249
self .sock_description : Optional [str ] = None
226
250
self ._deferred : Dict [str , List [DNSIncoming ]] = {}
227
251
self ._timers : Dict [str , asyncio .TimerHandle ] = {}
@@ -309,7 +333,7 @@ def handle_query_or_defer(
309
333
msg : DNSIncoming ,
310
334
addr : str ,
311
335
port : int ,
312
- transport : asyncio . DatagramTransport ,
336
+ transport : _WrappedTransport ,
313
337
v6_flow_scope : Union [Tuple [()], Tuple [int , int ]] = (),
314
338
) -> None :
315
339
"""Deal with incoming query packets. Provides a response if
@@ -341,7 +365,7 @@ def _respond_query(
341
365
msg : Optional [DNSIncoming ],
342
366
addr : str ,
343
367
port : int ,
344
- transport : asyncio . DatagramTransport ,
368
+ transport : _WrappedTransport ,
345
369
v6_flow_scope : Union [Tuple [()], Tuple [int , int ]] = (),
346
370
) -> None :
347
371
"""Respond to a query and reassemble any truncated deferred packets."""
@@ -362,27 +386,25 @@ def error_received(self, exc: Exception) -> None:
362
386
self .log_exception_once (exc , msg_str , exc )
363
387
364
388
def connection_made (self , transport : asyncio .BaseTransport ) -> None :
365
- self .transport = cast (asyncio .DatagramTransport , transport )
366
- sock_name = self .transport .get_extra_info ('sockname' )
367
- sock_fileno = self .transport .get_extra_info ('socket' ).fileno ()
368
- self .sock_description = f"{ sock_fileno } ({ sock_name } )"
389
+ wrapped_transport = _make_wrapped_transport (cast (asyncio .DatagramTransport , transport ))
390
+ self .transport = wrapped_transport
391
+ self .sock_description = f"{ wrapped_transport .fileno } ({ wrapped_transport .sock_name } )"
369
392
370
393
def connection_lost (self , exc : Optional [Exception ]) -> None :
371
394
"""Handle connection lost."""
372
395
373
396
374
397
def async_send_with_transport (
375
398
log_debug : bool ,
376
- transport : asyncio . DatagramTransport ,
399
+ transport : _WrappedTransport ,
377
400
packet : bytes ,
378
401
packet_num : int ,
379
402
out : DNSOutgoing ,
380
403
addr : Optional [str ],
381
404
port : int ,
382
405
v6_flow_scope : Union [Tuple [()], Tuple [int , int ]] = (),
383
406
) -> None :
384
- s = transport .get_extra_info ('socket' )
385
- ipv6_socket = s .family == socket .AF_INET6
407
+ ipv6_socket = transport .is_ipv6
386
408
if addr is None :
387
409
real_addr = _MDNS_ADDR6 if ipv6_socket else _MDNS_ADDR
388
410
else :
@@ -394,8 +416,8 @@ def async_send_with_transport(
394
416
'Sending to (%s, %d) via [socket %s (%s)] (%d bytes #%d) %r as %r...' ,
395
417
real_addr ,
396
418
port or _MDNS_PORT ,
397
- s .fileno () ,
398
- transport .get_extra_info ( 'sockname' ) ,
419
+ transport .fileno ,
420
+ transport .sock_name ,
399
421
len (packet ),
400
422
packet_num + 1 ,
401
423
out ,
@@ -404,9 +426,9 @@ def async_send_with_transport(
404
426
# Get flowinfo and scopeid for the IPV6 socket to create a complete IPv6
405
427
# address tuple: https://docs.python.org/3.6/library/socket.html#socket-families
406
428
if ipv6_socket and not v6_flow_scope :
407
- _ , _ , sock_flowinfo , sock_scopeid = s . getsockname ()
429
+ _ , _ , sock_flowinfo , sock_scopeid = transport . sock_name
408
430
v6_flow_scope = (sock_flowinfo , sock_scopeid )
409
- transport .sendto (packet , (real_addr , port or _MDNS_PORT , * v6_flow_scope ))
431
+ transport .transport . sendto (packet , (real_addr , port or _MDNS_PORT , * v6_flow_scope ))
410
432
411
433
412
434
class Zeroconf (QuietLogger ):
@@ -832,7 +854,7 @@ def handle_assembled_query(
832
854
packets : List [DNSIncoming ],
833
855
addr : str ,
834
856
port : int ,
835
- transport : asyncio . DatagramTransport ,
857
+ transport : _WrappedTransport ,
836
858
v6_flow_scope : Union [Tuple [()], Tuple [int , int ]] = (),
837
859
) -> None :
838
860
"""Respond to a (re)assembled query.
@@ -870,7 +892,7 @@ def send(
870
892
addr : Optional [str ] = None ,
871
893
port : int = _MDNS_PORT ,
872
894
v6_flow_scope : Union [Tuple [()], Tuple [int , int ]] = (),
873
- transport : Optional [asyncio . DatagramTransport ] = None ,
895
+ transport : Optional [_WrappedTransport ] = None ,
874
896
) -> None :
875
897
"""Sends an outgoing packet threadsafe."""
876
898
assert self .loop is not None
@@ -882,7 +904,7 @@ def async_send(
882
904
addr : Optional [str ] = None ,
883
905
port : int = _MDNS_PORT ,
884
906
v6_flow_scope : Union [Tuple [()], Tuple [int , int ]] = (),
885
- transport : Optional [asyncio . DatagramTransport ] = None ,
907
+ transport : Optional [_WrappedTransport ] = None ,
886
908
) -> None :
887
909
"""Sends an outgoing packet."""
888
910
if self .done :
0 commit comments