20
20
USA
21
21
"""
22
22
23
+ from __future__ import annotations
24
+
23
25
from heapq import heapify , heappop , heappush
24
- from typing import Dict , Iterable , List , Optional , Set , Tuple , Union , cast
26
+ from typing import Dict , Iterable , Union , cast
25
27
26
28
from ._dns import (
27
29
DNSAddress ,
@@ -66,8 +68,8 @@ class DNSCache:
66
68
67
69
def __init__ (self ) -> None :
68
70
self .cache : _DNSRecordCacheType = {}
69
- self ._expire_heap : List [ Tuple [float , DNSRecord ]] = []
70
- self ._expirations : Dict [DNSRecord , float ] = {}
71
+ self ._expire_heap : list [ tuple [float , DNSRecord ]] = []
72
+ self ._expirations : dict [DNSRecord , float ] = {}
71
73
self .service_cache : _DNSRecordCacheType = {}
72
74
73
75
# Functions prefixed with async_ are NOT threadsafe and must
@@ -135,7 +137,7 @@ def async_remove_records(self, entries: Iterable[DNSRecord]) -> None:
135
137
for entry in entries :
136
138
self ._async_remove (entry )
137
139
138
- def async_expire (self , now : _float ) -> List [DNSRecord ]:
140
+ def async_expire (self , now : _float ) -> list [DNSRecord ]:
139
141
"""Purge expired entries from the cache.
140
142
141
143
This function must be run in from event loop.
@@ -145,7 +147,7 @@ def async_expire(self, now: _float) -> List[DNSRecord]:
145
147
if not (expire_heap_len := len (self ._expire_heap )):
146
148
return []
147
149
148
- expired : List [DNSRecord ] = []
150
+ expired : list [DNSRecord ] = []
149
151
# Find any expired records and add them to the to-delete list
150
152
while self ._expire_heap :
151
153
when_record = self ._expire_heap [0 ]
@@ -182,7 +184,7 @@ def async_expire(self, now: _float) -> List[DNSRecord]:
182
184
self .async_remove_records (expired )
183
185
return expired
184
186
185
- def async_get_unique (self , entry : _UniqueRecordsType ) -> Optional [ DNSRecord ] :
187
+ def async_get_unique (self , entry : _UniqueRecordsType ) -> DNSRecord | None :
186
188
"""Gets a unique entry by key. Will return None if there is no
187
189
matching entry.
188
190
@@ -194,31 +196,31 @@ def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]:
194
196
return None
195
197
return store .get (entry )
196
198
197
- def async_all_by_details (self , name : _str , type_ : _int , class_ : _int ) -> List [DNSRecord ]:
199
+ def async_all_by_details (self , name : _str , type_ : _int , class_ : _int ) -> list [DNSRecord ]:
198
200
"""Gets all matching entries by details.
199
201
200
202
This function is not thread-safe and must be called from
201
203
the event loop.
202
204
"""
203
205
key = name .lower ()
204
206
records = self .cache .get (key )
205
- matches : List [DNSRecord ] = []
207
+ matches : list [DNSRecord ] = []
206
208
if records is None :
207
209
return matches
208
210
for record in records .values ():
209
211
if type_ == record .type and class_ == record .class_ :
210
212
matches .append (record )
211
213
return matches
212
214
213
- def async_entries_with_name (self , name : str ) -> List [DNSRecord ]:
215
+ def async_entries_with_name (self , name : str ) -> list [DNSRecord ]:
214
216
"""Returns a dict of entries whose key matches the name.
215
217
216
218
This function is not threadsafe and must be called from
217
219
the event loop.
218
220
"""
219
221
return self .entries_with_name (name )
220
222
221
- def async_entries_with_server (self , name : str ) -> List [DNSRecord ]:
223
+ def async_entries_with_server (self , name : str ) -> list [DNSRecord ]:
222
224
"""Returns a dict of entries whose key matches the server.
223
225
224
226
This function is not threadsafe and must be called from
@@ -230,7 +232,7 @@ def async_entries_with_server(self, name: str) -> List[DNSRecord]:
230
232
# event loop, however they all make copies so they significantly
231
233
# inefficient.
232
234
233
- def get (self , entry : DNSEntry ) -> Optional [ DNSRecord ] :
235
+ def get (self , entry : DNSEntry ) -> DNSRecord | None :
234
236
"""Gets an entry by key. Will return None if there is no
235
237
matching entry."""
236
238
if isinstance (entry , _UNIQUE_RECORD_TYPES ):
@@ -240,7 +242,7 @@ def get(self, entry: DNSEntry) -> Optional[DNSRecord]:
240
242
return cached_entry
241
243
return None
242
244
243
- def get_by_details (self , name : str , type_ : _int , class_ : _int ) -> Optional [ DNSRecord ] :
245
+ def get_by_details (self , name : str , type_ : _int , class_ : _int ) -> DNSRecord | None :
244
246
"""Gets the first matching entry by details. Returns None if no entries match.
245
247
246
248
Calling this function is not recommended as it will only
@@ -261,27 +263,27 @@ def get_by_details(self, name: str, type_: _int, class_: _int) -> Optional[DNSRe
261
263
return cached_entry
262
264
return None
263
265
264
- def get_all_by_details (self , name : str , type_ : _int , class_ : _int ) -> List [DNSRecord ]:
266
+ def get_all_by_details (self , name : str , type_ : _int , class_ : _int ) -> list [DNSRecord ]:
265
267
"""Gets all matching entries by details."""
266
268
key = name .lower ()
267
269
records = self .cache .get (key )
268
270
if records is None :
269
271
return []
270
272
return [entry for entry in list (records .values ()) if type_ == entry .type and class_ == entry .class_ ]
271
273
272
- def entries_with_server (self , server : str ) -> List [DNSRecord ]:
274
+ def entries_with_server (self , server : str ) -> list [DNSRecord ]:
273
275
"""Returns a list of entries whose server matches the name."""
274
276
if entries := self .service_cache .get (server .lower ()):
275
277
return list (entries .values ())
276
278
return []
277
279
278
- def entries_with_name (self , name : str ) -> List [DNSRecord ]:
280
+ def entries_with_name (self , name : str ) -> list [DNSRecord ]:
279
281
"""Returns a list of entries whose key matches the name."""
280
282
if entries := self .cache .get (name .lower ()):
281
283
return list (entries .values ())
282
284
return []
283
285
284
- def current_entry_with_name_and_alias (self , name : str , alias : str ) -> Optional [ DNSRecord ] :
286
+ def current_entry_with_name_and_alias (self , name : str , alias : str ) -> DNSRecord | None :
285
287
now = current_time_millis ()
286
288
for record in reversed (self .entries_with_name (name )):
287
289
if (
@@ -292,13 +294,13 @@ def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[D
292
294
return record
293
295
return None
294
296
295
- def names (self ) -> List [str ]:
297
+ def names (self ) -> list [str ]:
296
298
"""Return a copy of the list of current cache names."""
297
299
return list (self .cache )
298
300
299
301
def async_mark_unique_records_older_than_1s_to_expire (
300
302
self ,
301
- unique_types : Set [ Tuple [_str , _int , _int ]],
303
+ unique_types : set [ tuple [_str , _int , _int ]],
302
304
answers : Iterable [DNSRecord ],
303
305
now : _float ,
304
306
) -> None :
0 commit comments