|
import asyncio |
|
import socket |
|
import weakref |
|
from typing import Any, Dict, Final, List, Optional, Tuple, Type, Union |
|
|
|
from .abc import AbstractResolver, ResolveResult |
|
|
|
# Public names exported by this module.
__all__ = ("ThreadedResolver", "AsyncResolver", "DefaultResolver")


# aiodns is an optional dependency.
# ``aiodns_default`` is True only when aiodns imports successfully AND its
# DNSResolver exposes getaddrinfo(), which AsyncResolver.resolve calls.
# It selects which class DefaultResolver points at, at the end of this module.
try:
    import aiodns

    aiodns_default = hasattr(aiodns.DNSResolver, "getaddrinfo")
except ImportError:
    # Without aiodns, AsyncResolver.__init__ raises RuntimeError.
    aiodns = None
    aiodns_default = False
|
|
|
|
|
# Attached as ``flags`` to every ResolveResult.
# It marks the returned host and port as numeric values
# (AI_NUMERICHOST | AI_NUMERICSERV).
_NUMERIC_SOCKET_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV

# Passed to getnameinfo() so the address and port come back in numeric form
# rather than being reverse-resolved.
_NAME_SOCKET_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV

# getaddrinfo() flag used by both resolvers.
# On platforms that define socket.AI_MASK, only the AI_ADDRCONFIG bits
# permitted by that mask are kept.
_AI_ADDRCONFIG = (
    socket.AI_ADDRCONFIG & socket.AI_MASK
    if hasattr(socket, "AI_MASK")
    else socket.AI_ADDRCONFIG
)
|
|
|
|
|
class ThreadedResolver(AbstractResolver):
    """Threaded resolver.

    Uses an Executor for synchronous getaddrinfo() calls.
    concurrent.futures.ThreadPoolExecutor is used by default.
    """

    def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        # Without an explicit loop, bind to the currently running one.
        # asyncio.get_running_loop() raises RuntimeError if none is running.
        self._loop = loop or asyncio.get_running_loop()

    async def resolve(
        self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
    ) -> List[ResolveResult]:
        """Resolve *host* through the event loop's getaddrinfo().

        Args:
            host: Host name or address to look up.
            port: Port passed to getaddrinfo().
            family: Address family to request (e.g. AF_INET, AF_INET6).

        Returns:
            One ResolveResult per SOCK_STREAM address returned, each with the
            resolved host string, integer port, address family and protocol.
            IPv6 entries whose sockaddr is not a 4-tuple are skipped.
        """
        infos = await self._loop.getaddrinfo(
            host,
            port,
            type=socket.SOCK_STREAM,
            family=family,
            flags=_AI_ADDRCONFIG,
        )

        hosts: List[ResolveResult] = []
        for addr_family, _, proto, _, address in infos:
            if addr_family == socket.AF_INET6:
                if len(address) < 4:
                    # Not a (host, port, flowinfo, scope_id) sockaddr.
                    # Presumably IPv6 is unsupported by this Python build or
                    # disabled on the host.
                    # The guard must cover index 3, because address[3] is read
                    # just below.
                    continue
                if address[3]:
                    # A non-zero scope_id means a link-local address.
                    # getnameinfo() is used so the returned host string keeps
                    # its scope. Such addresses are rare, so ordinary ones
                    # skip this extra call.
                    resolved_host, service = await self._loop.getnameinfo(
                        address, _NAME_SOCKET_FLAGS
                    )
                    resolved_port = int(service)
                else:
                    resolved_host, resolved_port = address[:2]
            else:
                assert addr_family == socket.AF_INET
                resolved_host, resolved_port = address
            hosts.append(
                ResolveResult(
                    hostname=host,
                    host=resolved_host,
                    port=resolved_port,
                    family=addr_family,
                    proto=proto,
                    flags=_NUMERIC_SOCKET_FLAGS,
                )
            )

        return hosts

    async def close(self) -> None:
        """Do nothing: this resolver holds no resources of its own."""
|
|
|
|
|
class AsyncResolver(AbstractResolver):
    """Use the `aiodns` package to make asynchronous DNS lookups"""

    def __init__(
        self,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Create a resolver bound to *loop* (default: the running loop).

        Args:
            loop: Event loop to resolve on; defaults to the running loop.
            *args: Forwarded to ``aiodns.DNSResolver``.
            **kwargs: Forwarded to ``aiodns.DNSResolver``.

        If any extra arguments are given, this instance gets its own dedicated
        ``aiodns.DNSResolver``. Otherwise it shares one resolver per event loop
        through ``_DNSResolverManager``.

        Raises:
            RuntimeError: if aiodns is not installed.
        """
        if aiodns is None:
            raise RuntimeError("Resolver requires aiodns library")

        self._loop = loop or asyncio.get_running_loop()
        self._manager: Optional[_DNSResolverManager] = None

        if args or kwargs:
            # Custom settings cannot be shared, so build a private resolver.
            # NOTE(review): unlike the shared path below, this path never
            # installs the query()-based fallback. Confirm that is intended.
            self._resolver = aiodns.DNSResolver(*args, **kwargs)
            return

        self._manager = _DNSResolverManager()
        self._resolver = self._manager.get_resolver(self, self._loop)

        if not hasattr(self._resolver, "gethostbyname"):
            # Presumably an older aiodns: route resolve() through the
            # query()-based implementation instead.
            self.resolve = self._resolve_with_query  # type: ignore[method-assign]

    @staticmethod
    def _dns_error_message(exc: BaseException) -> Any:
        """Return the message carried by an aiodns ``DNSError``.

        The message is read from ``exc.args[1]``; the first arg is assumed to
        be an error code. If the exception carries fewer than two args, a
        generic message is returned instead of raising IndexError.
        """
        if len(exc.args) > 1:
            return exc.args[1]
        return "DNS lookup failed"

    async def resolve(
        self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
    ) -> List[ResolveResult]:
        """Resolve *host* with aiodns' getaddrinfo().

        Args:
            host: Host name or address to look up.
            port: Port passed to getaddrinfo().
            family: Address family to request.

        Returns:
            One ResolveResult per returned node, with a decoded host string
            and the node's port.

        Raises:
            OSError: if aiodns reports a DNS error or returns no addresses.
        """
        try:
            resp = await self._resolver.getaddrinfo(
                host,
                port=port,
                type=socket.SOCK_STREAM,
                family=family,
                flags=_AI_ADDRCONFIG,
            )
        except aiodns.error.DNSError as exc:
            raise OSError(None, self._dns_error_message(exc)) from exc

        hosts: List[ResolveResult] = []
        for node in resp.nodes:
            address: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] = node.addr
            family = node.family
            if family == socket.AF_INET6:
                if len(address) > 3 and address[3]:
                    # A non-zero scope_id means a link-local address.
                    # getnameinfo() is used so the returned host keeps its
                    # scope. Such addresses are rare, so ordinary ones skip
                    # this extra call.
                    result = await self._resolver.getnameinfo(
                        (address[0].decode("ascii"), *address[1:]),
                        _NAME_SOCKET_FLAGS,
                    )
                    resolved_host = result.node
                else:
                    resolved_host = address[0].decode("ascii")
            else:
                assert family == socket.AF_INET
                resolved_host = address[0].decode("ascii")
            # Take the port from the returned address on every path, including
            # the link-local one, matching ThreadedResolver.
            port = address[1]
            hosts.append(
                ResolveResult(
                    hostname=host,
                    host=resolved_host,
                    port=port,
                    family=family,
                    proto=0,
                    flags=_NUMERIC_SOCKET_FLAGS,
                )
            )

        if not hosts:
            raise OSError(None, "DNS lookup failed")

        return hosts

    async def _resolve_with_query(
        self, host: str, port: int = 0, family: int = socket.AF_INET
    ) -> List[Dict[str, Any]]:
        """Resolve *host* with aiodns' query() using an A or AAAA lookup.

        Installed as ``resolve`` when the resolver lacks ``gethostbyname``.

        Returns:
            One dict per record, with keys hostname/host/port/family/proto/flags.

        Raises:
            OSError: if aiodns reports a DNS error or returns no records.
        """
        qtype: Final = "AAAA" if family == socket.AF_INET6 else "A"

        try:
            resp = await self._resolver.query(host, qtype)
        except aiodns.error.DNSError as exc:
            raise OSError(None, self._dns_error_message(exc)) from exc

        hosts = [
            {
                "hostname": host,
                "host": rr.host,
                "port": port,
                "family": family,
                "proto": 0,
                "flags": socket.AI_NUMERICHOST,
            }
            for rr in resp
        ]

        if not hosts:
            raise OSError(None, "DNS lookup failed")

        return hosts

    async def close(self) -> None:
        """Release this resolver's DNS resources.

        A shared resolver is handed back to the manager, which cancels it once
        no clients remain. A dedicated resolver is cancelled directly.
        Calling close() again afterwards does nothing.
        """
        if self._manager:
            self._manager.release_resolver(self, self._loop)
            self._manager = None
            self._resolver = None
            return

        if self._resolver is not None:
            self._resolver.cancel()
            self._resolver = None
|
|
|
|
|
class _DNSResolverManager:
    """Process-wide registry of shared aiodns.DNSResolver objects.

    AsyncResolver instances created without custom resolver arguments share
    one aiodns.DNSResolver per event loop through this singleton. The shared
    resolver is cancelled once its last registered client releases it.
    """

    _instance: Optional["_DNSResolverManager"] = None

    def __new__(cls) -> "_DNSResolverManager":
        # Singleton: only the very first call builds and initialises state.
        if cls._instance is not None:
            return cls._instance
        created = super().__new__(cls)
        cls._instance = created
        created._init()
        return created

    def _init(self) -> None:
        # Keyed weakly by event loop, so an entry disappears when its loop is
        # garbage-collected.
        # Each value pairs that loop's shared resolver with a weak set of the
        # AsyncResolver clients currently registered against it.
        self._loop_data: weakref.WeakKeyDictionary[
            asyncio.AbstractEventLoop,
            tuple["aiodns.DNSResolver", weakref.WeakSet["AsyncResolver"]],
        ] = weakref.WeakKeyDictionary()

    def get_resolver(
        self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop
    ) -> "aiodns.DNSResolver":
        """Return the shared resolver for *loop*, creating it on first use.

        Args:
            client: The AsyncResolver asking for the resolver; it is recorded
                so the resolver stays alive until every client releases it.
            loop: The event loop the resolver runs on.
        """
        entry = self._loop_data.get(loop)
        if entry is None:
            entry = (aiodns.DNSResolver(loop=loop), weakref.WeakSet())
            self._loop_data[loop] = entry
        shared_resolver, clients = entry
        clients.add(client)
        return shared_resolver

    def release_resolver(
        self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop
    ) -> None:
        """Unregister *client*; cancel and drop the loop's resolver if unused.

        Args:
            client: The AsyncResolver being closed.
            loop: The event loop whose shared resolver it was using.
        """
        entry = self._loop_data.get(loop)
        if entry is None:
            return
        shared_resolver, clients = entry
        clients.discard(client)
        if clients:
            # Other clients still depend on this resolver.
            return
        if shared_resolver is not None:
            shared_resolver.cancel()
        del self._loop_data[loop]
|
|
|
|
|
# Either resolver class may serve as the module default.
_DefaultType = Type[Union[AsyncResolver, ThreadedResolver]]

# Prefer the aiodns-backed resolver only when aiodns imported and its
# DNSResolver exposes getaddrinfo() (see ``aiodns_default``).
# Otherwise fall back to the executor-based ThreadedResolver.
DefaultResolver: _DefaultType
if aiodns_default:
    DefaultResolver = AsyncResolver
else:
    DefaultResolver = ThreadedResolver
|
|