# jamtur01's picture
# Upload folder using huggingface_hub
# 9c6594c verified
import asyncio
import socket
import weakref
from typing import Any, Dict, Final, List, Optional, Tuple, Type, Union
from .abc import AbstractResolver, ResolveResult
__all__ = ("ThreadedResolver", "AsyncResolver", "DefaultResolver")
# aiodns is an optional dependency. ``aiodns_default`` is True only when the
# installed version exposes ``DNSResolver.getaddrinfo``.
try:
    import aiodns

    aiodns_default = hasattr(aiodns.DNSResolver, "getaddrinfo")
except ImportError:  # pragma: no cover
    aiodns = None  # type: ignore[assignment]
    aiodns_default = False

# Flags attached to every resolve result.
_NUMERIC_SOCKET_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV
# Flags passed to getnameinfo() so host and service come back in numeric form.
_NAME_SOCKET_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
# AI_ADDRCONFIG, restricted to the bits in AI_MASK when the platform's socket
# module defines that mask.
_AI_ADDRCONFIG = (
    socket.AI_ADDRCONFIG & socket.AI_MASK
    if hasattr(socket, "AI_MASK")
    else socket.AI_ADDRCONFIG
)
class ThreadedResolver(AbstractResolver):
    """Threaded resolver.

    Uses an Executor for synchronous getaddrinfo() calls.
    concurrent.futures.ThreadPoolExecutor is used by default.
    """

    def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        """Bind the resolver to *loop*, or to the running loop when omitted."""
        self._loop = loop or asyncio.get_running_loop()

    async def resolve(
        self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
    ) -> List[ResolveResult]:
        """Resolve *host* through the event loop's ``getaddrinfo()``.

        Args:
            host: Host name to look up.
            port: Port handed to ``getaddrinfo()``.
            family: Address family to request.

        Returns:
            One ``ResolveResult`` per usable address, in the order
            ``getaddrinfo()`` returned them. IPv6 entries whose address
            is not a full 4-tuple are skipped.
        """
        infos = await self._loop.getaddrinfo(
            host,
            port,
            type=socket.SOCK_STREAM,
            family=family,
            flags=_AI_ADDRCONFIG,
        )

        hosts: List[ResolveResult] = []
        for family, _, proto, _, address in infos:
            if family == socket.AF_INET6:
                # address[3] is read below, so require the complete
                # (host, port, flowinfo, scope_id) tuple before indexing it.
                if len(address) < 4:
                    # IPv6 is not supported by Python build,
                    # or IPv6 is not enabled in the host
                    continue
                if address[3]:
                    # This is essential for link-local IPv6 addresses.
                    # LL IPv6 is a VERY rare case. Strictly speaking, we should use
                    # getnameinfo() unconditionally, but performance makes sense.
                    resolved_host, _port = await self._loop.getnameinfo(
                        address, _NAME_SOCKET_FLAGS
                    )
                    port = int(_port)
                else:
                    resolved_host, port = address[:2]
            else:  # IPv4
                assert family == socket.AF_INET
                resolved_host, port = address  # type: ignore[misc]
            hosts.append(
                ResolveResult(
                    hostname=host,
                    host=resolved_host,
                    port=port,
                    family=family,
                    proto=proto,
                    flags=_NUMERIC_SOCKET_FLAGS,
                )
            )

        return hosts

    async def close(self) -> None:
        """No-op."""
        pass
class AsyncResolver(AbstractResolver):
    """Use the `aiodns` package to make asynchronous DNS lookups"""

    def __init__(
        self,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Create the resolver.

        Args:
            loop: Event loop to use; the running loop when omitted.
            *args: Extra positional arguments for ``aiodns.DNSResolver``.
            **kwargs: Extra keyword arguments for ``aiodns.DNSResolver``.

        Raises:
            RuntimeError: If the ``aiodns`` package could not be imported.
        """
        if aiodns is None:
            raise RuntimeError("Resolver requires aiodns library")

        self._loop = loop or asyncio.get_running_loop()
        self._manager: Optional[_DNSResolverManager] = None
        if args or kwargs:
            # If custom args are provided, create a dedicated resolver instance
            # This means each AsyncResolver with custom args gets its own
            # aiodns.DNSResolver instance
            self._resolver = aiodns.DNSResolver(*args, **kwargs)
        else:
            # Use the shared resolver from the manager for default arguments
            self._manager = _DNSResolverManager()
            self._resolver = self._manager.get_resolver(self, self._loop)

        # resolve() calls DNSResolver.getaddrinfo; when the resolver does not
        # offer it, fall back to DNSResolver.query. This applies to dedicated
        # and shared resolvers alike.
        if not hasattr(self._resolver, "getaddrinfo"):
            self.resolve = self._resolve_with_query  # type: ignore

    async def resolve(
        self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
    ) -> List[ResolveResult]:
        """Resolve *host* with ``DNSResolver.getaddrinfo()``.

        Args:
            host: Host name to look up.
            port: Port handed to ``getaddrinfo()``.
            family: Address family to request.

        Returns:
            One ``ResolveResult`` per node in the response.

        Raises:
            OSError: If the lookup raises ``aiodns.error.DNSError`` or the
                response contains no nodes.
        """
        try:
            resp = await self._resolver.getaddrinfo(
                host,
                port=port,
                type=socket.SOCK_STREAM,
                family=family,
                flags=_AI_ADDRCONFIG,
            )
        except aiodns.error.DNSError as exc:
            # args[1] holds the message; only read it when it exists.
            msg = exc.args[1] if len(exc.args) > 1 else "DNS lookup failed"
            raise OSError(None, msg) from exc

        hosts: List[ResolveResult] = []
        for node in resp.nodes:
            address: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] = node.addr
            family = node.family
            # Always report the port of this node's own address, including
            # on the link-local path below.
            port = address[1]
            if family == socket.AF_INET6:
                if len(address) > 3 and address[3]:
                    # This is essential for link-local IPv6 addresses.
                    # LL IPv6 is a VERY rare case. Strictly speaking, we should use
                    # getnameinfo() unconditionally, but performance makes sense.
                    result = await self._resolver.getnameinfo(
                        (address[0].decode("ascii"), *address[1:]),
                        _NAME_SOCKET_FLAGS,
                    )
                    resolved_host = result.node
                else:
                    resolved_host = address[0].decode("ascii")
            else:  # IPv4
                assert family == socket.AF_INET
                resolved_host = address[0].decode("ascii")
            hosts.append(
                ResolveResult(
                    hostname=host,
                    host=resolved_host,
                    port=port,
                    family=family,
                    proto=0,
                    flags=_NUMERIC_SOCKET_FLAGS,
                )
            )

        if not hosts:
            raise OSError(None, "DNS lookup failed")

        return hosts

    async def _resolve_with_query(
        self, host: str, port: int = 0, family: int = socket.AF_INET
    ) -> List[Dict[str, Any]]:
        """Resolve *host* with ``DNSResolver.query()``.

        Issues an ``AAAA`` query when *family* is ``AF_INET6`` and an ``A``
        query otherwise. The requested *port* and *family* are copied
        unchanged into every result.

        Raises:
            OSError: If the query raises ``aiodns.error.DNSError`` or
                returns no records.
        """
        qtype: Final = "AAAA" if family == socket.AF_INET6 else "A"

        try:
            resp = await self._resolver.query(host, qtype)
        except aiodns.error.DNSError as exc:
            # args[1] holds the message; only read it when it exists.
            msg = exc.args[1] if len(exc.args) > 1 else "DNS lookup failed"
            raise OSError(None, msg) from exc

        hosts = [
            {
                "hostname": host,
                "host": rr.host,
                "port": port,
                "family": family,
                "proto": 0,
                "flags": socket.AI_NUMERICHOST,
            }
            for rr in resp
        ]

        if not hosts:
            raise OSError(None, "DNS lookup failed")

        return hosts

    async def close(self) -> None:
        """Release the shared resolver, or cancel the dedicated one."""
        if self._manager:
            # Release the resolver from the manager if using the shared resolver
            self._manager.release_resolver(self, self._loop)
            self._manager = None  # Clear reference to manager
            self._resolver = None  # type: ignore[assignment]  # Clear reference to resolver
            return
        # Otherwise cancel our dedicated resolver
        if self._resolver is not None:
            self._resolver.cancel()
        self._resolver = None  # type: ignore[assignment]  # Clear reference
class _DNSResolverManager:
    """Process-wide registry of shared aiodns.DNSResolver objects.

    Each event loop gets one shared resolver for AsyncResolver instances
    created without custom arguments. The class is a singleton: every
    call to the constructor returns the same object.
    """

    _instance: Optional["_DNSResolverManager"] = None

    def __new__(cls) -> "_DNSResolverManager":
        """Return the singleton, creating and initialising it on first use."""
        if cls._instance is not None:
            return cls._instance
        cls._instance = super().__new__(cls)
        cls._instance._init()
        return cls._instance

    def _init(self) -> None:
        """Set up the per-loop bookkeeping."""
        # Weak keys so that an event loop can be garbage collected even
        # while it still has an entry here.
        self._loop_data: weakref.WeakKeyDictionary[
            asyncio.AbstractEventLoop,
            tuple["aiodns.DNSResolver", weakref.WeakSet["AsyncResolver"]],
        ] = weakref.WeakKeyDictionary()

    def get_resolver(
        self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop
    ) -> "aiodns.DNSResolver":
        """Return the shared resolver for *loop*, creating it if needed.

        Args:
            client: The AsyncResolver asking for the resolver; it is
                recorded as a user of the loop's resolver.
            loop: The event loop the resolver belongs to.
        """
        if loop in self._loop_data:
            # Reuse what was registered for this loop earlier.
            shared, users = self._loop_data[loop]
        else:
            # First request for this loop: build the resolver and an empty
            # (weak) set of users, then remember both.
            shared = aiodns.DNSResolver(loop=loop)
            users: weakref.WeakSet["AsyncResolver"] = weakref.WeakSet()
            self._loop_data[loop] = (shared, users)
        users.add(client)
        return shared

    def release_resolver(
        self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop
    ) -> None:
        """Stop tracking *client* as a user of the resolver for *loop*.

        When the last tracked user is gone, the resolver is cancelled and
        the loop's entry is dropped.

        Args:
            client: The AsyncResolver being closed.
            loop: The event loop whose resolver the client was using.
        """
        try:
            shared, users = self._loop_data[loop]
        except KeyError:
            # Nothing registered for this loop.
            return
        users.discard(client)
        if users:
            # Other clients still rely on the resolver.
            return
        if shared is not None:
            shared.cancel()
        del self._loop_data[loop]
_DefaultType = Type[Union[AsyncResolver, ThreadedResolver]]

# Resolver class exported as the default: AsyncResolver when
# ``aiodns_default`` is true, ThreadedResolver otherwise.
DefaultResolver: _DefaultType
if aiodns_default:
    DefaultResolver = AsyncResolver
else:
    DefaultResolver = ThreadedResolver