jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
import asyncio
import ipaddress
from pathlib import Path
import socket
from functools import lru_cache, wraps
from typing import Any, Awaitable, Callable, Coroutine, Literal, T, Tuple
import httpx
def get_version():
    """Return the package version stored in ``version.txt`` beside this module.

    Surrounding whitespace (including the trailing newline) is stripped.
    """
    version_path = Path(__file__).parent / 'version.txt'
    return version_path.read_text().strip()
# Package version string, read from version.txt once at import time.
__version__ = get_version()
def is_public_ip(ip: str) -> bool:
    """Return True only if *ip* is an IP literal routable on the public internet.

    Parameters:
    - ip (str): Candidate address in IPv4 or IPv6 textual form.

    Returns:
    - bool: False for anything that is not a valid IP literal, and for
      private, loopback, link-local, multicast, reserved, unspecified or
      otherwise non-global addresses.
    """
    try:
        ip_obj = ipaddress.ip_address(ip)
    except ValueError:
        # Not an IP literal at all, so it can never be treated as public.
        return False

    # `is_global` additionally rejects the shared address space
    # 100.64.0.0/10 (RFC 6598), which `is_private` does not flag. The
    # explicit checks stay because multicast/reserved ranges are not
    # covered by `is_global` alone.
    return ip_obj.is_global and not (
        ip_obj.is_private
        or ip_obj.is_loopback
        or ip_obj.is_link_local
        or ip_obj.is_multicast
        or ip_obj.is_reserved
        or ip_obj.is_unspecified
    )
def lru_cache_async(maxsize: int = 256):
    """Build a decorator that memoizes calls to a coroutine function.

    The decorated callable schedules the coroutine with
    ``asyncio.create_task`` and the resulting task object is what
    ``functools.lru_cache`` stores, keyed on the call arguments. Repeated
    calls with the same (hashable) arguments therefore receive the same
    task, which can be awaited again for its result.

    Parameters:
    - maxsize (int): Maximum number of cached entries, forwarded to ``lru_cache``.
    """

    def decorator(
        async_func: Callable[..., Coroutine[Any, Any, T]],
    ) -> Callable[..., Awaitable[T]]:
        def schedule(*args: Any, **kwargs: Any) -> Awaitable[T]:
            # create_task needs a running event loop at call time.
            return asyncio.create_task(async_func(*args, **kwargs))

        # Copy the coroutine function's metadata first, then add the cache.
        with_metadata = wraps(async_func)(schedule)
        return lru_cache(maxsize=maxsize)(with_metadata)

    return decorator
@lru_cache_async()
async def async_resolve_hostname_google(hostname: str) -> list[str]:
    """Resolve *hostname* through Google's DNS-over-HTTPS JSON endpoint.

    Queries both A and AAAA records and collects the ``data`` field of
    every entry in each response's ``Answer`` list. The values are not
    validated here; callers are expected to filter them.

    Parameters:
    - hostname (str): The host name to look up.

    Returns:
    - list[str]: The collected answer values, or an empty list if any
      request or response parsing step fails (best-effort lookup).
    """
    async with httpx.AsyncClient() as client:
        try:
            ips: list[str] = []
            for record_type in ("A", "AAAA"):
                # Pass the hostname via `params` so it is URL-encoded and
                # cannot inject extra query parameters.
                response = await client.get(
                    "https://dns.google/resolve",
                    params={"name": hostname, "type": record_type},
                )
                ips.extend(
                    answer["data"] for answer in response.json().get("Answer", [])
                )
            return ips
        except Exception:
            # Deliberately best-effort: any failure means "no addresses".
            return []
async def async_validate_url(hostname: str) -> str:
    """Resolve *hostname* and return the first public IP address found.

    The system resolver is tried first; if none of its results is public,
    the addresses returned by ``async_resolve_hostname_google`` are checked.

    Parameters:
    - hostname (str): The host name to validate.

    Returns:
    - str: The first resolved address for which ``is_public_ip`` is true.

    Raises:
    - ValueError: If the system resolver cannot resolve the host name, or
      if no resolved address is public.
    """
    # get_running_loop() is the recommended accessor inside a coroutine.
    loop = asyncio.get_running_loop()
    try:
        addrinfo = await loop.getaddrinfo(hostname, None)
    except socket.gaierror as e:
        raise ValueError(f"Unable to resolve hostname {hostname}: {e}") from e

    for family, _, _, _, sockaddr in addrinfo:
        ip_address = sockaddr[0]
        if family in (socket.AF_INET, socket.AF_INET6) and is_public_ip(ip_address):
            return ip_address

    # Fall back to the DNS-over-HTTPS results when the local resolver
    # produced no public address.
    for ip_address in await async_resolve_hostname_google(hostname):
        if is_public_ip(ip_address):
            return ip_address

    raise ValueError(f"Hostname {hostname} failed validation")
class AsyncSecureTransport(httpx.AsyncHTTPTransport):
    """HTTP transport that connects to a pre-validated IP address.

    The request URL's host is replaced with ``verified_ip`` so the
    connection goes to that address, while the ``Host`` header and the
    ``sni_hostname`` extension keep the original host name.
    """

    def __init__(self, verified_ip: str):
        """Store the already-validated IP address to connect to."""
        self.verified_ip = verified_ip
        super().__init__()

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        """Redirect *request* to the verified IP and send it.

        Parameters:
        - request (httpx.Request): The outgoing request; it is modified in place.

        Returns:
        - httpx.Response: The response produced by the parent transport.
        """
        original_host = request.url.host
        request.url = request.url.copy_with(host=self.verified_ip)
        request.headers["Host"] = original_host
        # Merge rather than replace: the request may already carry
        # extensions (httpx stores the timeout configuration there), and
        # overwriting the dict would silently discard them.
        request.extensions = {**request.extensions, "sni_hostname": original_host}
        return await super().handle_async_request(request)
async def get(
    url: str,
    domain_whitelist: list[str] | None = None,
    _transport: httpx.AsyncBaseTransport | Literal[False] | None = None,
    **kwargs,
) -> httpx.Response:
    """
    Perform an async HTTP GET request; this is the module's main entry point.
    Unless the host is whitelisted or a transport is supplied, the host name is
    validated first and the request is sent through AsyncSecureTransport.
    Redirects are never followed.
    Parameters:
    - url (str): The URL to fetch.
    - domain_whitelist (list[str] | None): Host names that skip validation and use httpx's default transport.
    - _transport (httpx.AsyncBaseTransport | Literal[False] | None): Explicit transport to use; overrides domain_whitelist. Pass False to force the default transport without validation.
    - **kwargs: Extra keyword arguments forwarded to httpx.AsyncClient.get().
    Raises:
    - ValueError: If the URL has no host name.
    """
    host = httpx.URL(url).host
    if not host:
        raise ValueError(f"URL {url} does not have a valid hostname")

    allowed_hosts = [] if domain_whitelist is None else domain_whitelist

    if _transport:
        # A caller-supplied transport always wins.
        chosen_transport = _transport
    elif _transport is False or host in allowed_hosts:
        # Explicit opt-out or whitelisted host: use httpx's default transport.
        chosen_transport = None
    else:
        chosen_transport = AsyncSecureTransport(await async_validate_url(host))

    async with httpx.AsyncClient(transport=chosen_transport) as client:
        return await client.get(url, follow_redirects=False, **kwargs)