|
import asyncio |
|
import ipaddress |
|
from pathlib import Path |
|
import socket |
|
from functools import lru_cache, wraps |
|
from typing import Any, Awaitable, Callable, Coroutine, Literal, T, Tuple |
|
|
|
import httpx |
|
|
|
def get_version(): |
|
version_file = Path(__file__).parent / 'version.txt' |
|
with open(version_file, 'r') as f: |
|
return f.read().strip() |
|
|
|
__version__ = get_version() |
|
|
|
def is_public_ip(ip: str) -> bool: |
|
try: |
|
ip_obj = ipaddress.ip_address(ip) |
|
return not ( |
|
ip_obj.is_private |
|
or ip_obj.is_loopback |
|
or ip_obj.is_link_local |
|
or ip_obj.is_multicast |
|
or ip_obj.is_reserved |
|
) |
|
except ValueError: |
|
return False |
|
|
|
|
|
def lru_cache_async(maxsize: int = 256): |
|
def decorator( |
|
async_func: Callable[..., Coroutine[Any, Any, T]], |
|
) -> Callable[..., Awaitable[T]]: |
|
@lru_cache(maxsize=maxsize) |
|
@wraps(async_func) |
|
def wrapper(*args: Any, **kwargs: Any) -> Awaitable[T]: |
|
return asyncio.create_task(async_func(*args, **kwargs)) |
|
|
|
return wrapper |
|
|
|
return decorator |
|
|
|
|
|
@lru_cache_async() |
|
async def async_resolve_hostname_google(hostname: str) -> list[str]: |
|
async with httpx.AsyncClient() as client: |
|
try: |
|
response_v4 = await client.get( |
|
f"https://dns.google/resolve?name={hostname}&type=A" |
|
) |
|
response_v6 = await client.get( |
|
f"https://dns.google/resolve?name={hostname}&type=AAAA" |
|
) |
|
|
|
ips = [] |
|
for response in [response_v4.json(), response_v6.json()]: |
|
ips.extend([answer["data"] for answer in response.get("Answer", [])]) |
|
return ips |
|
except Exception: |
|
return [] |
|
|
|
|
|
async def async_validate_url(hostname: str) -> str: |
|
try: |
|
loop = asyncio.get_event_loop() |
|
addrinfo = await loop.getaddrinfo(hostname, None) |
|
except socket.gaierror as e: |
|
raise ValueError(f"Unable to resolve hostname {hostname}: {e}") from e |
|
|
|
for family, _, _, _, sockaddr in addrinfo: |
|
ip_address = sockaddr[0] |
|
if family in (socket.AF_INET, socket.AF_INET6) and is_public_ip(ip_address): |
|
return ip_address |
|
|
|
for ip_address in await async_resolve_hostname_google(hostname): |
|
if is_public_ip(ip_address): |
|
return ip_address |
|
|
|
raise ValueError(f"Hostname {hostname} failed validation") |
|
|
|
|
|
class AsyncSecureTransport(httpx.AsyncHTTPTransport): |
|
def __init__(self, verified_ip: str): |
|
self.verified_ip = verified_ip |
|
super().__init__() |
|
|
|
async def handle_async_request( |
|
self, |
|
request: httpx.Request |
|
) -> Tuple[int, bytes, bytes, httpx.Headers]: |
|
original_url = request.url |
|
original_host = original_url.host |
|
new_url = original_url.copy_with(host=self.verified_ip) |
|
request.url = new_url |
|
request.headers['Host'] = original_host |
|
request.extensions = {"sni_hostname": original_host} |
|
return await super().handle_async_request(request) |
|
|
|
async def get( |
|
url: str, |
|
domain_whitelist: list[str] | None = None, |
|
_transport: httpx.AsyncBaseTransport | Literal[False] | None = None, |
|
**kwargs, |
|
) -> httpx.Response: |
|
""" |
|
This is the main function that should be used to make async HTTP GET requests. |
|
It will automatically use a secure transport for non-whitelisted domains. |
|
|
|
Parameters: |
|
- url (str): The URL to make a GET request to. |
|
- domain_whitelist (list[str] | None): A list of domains to whitelist, which will not use a secure transport. |
|
- _transport (httpx.AsyncBaseTransport | Literal[False] | None): A custom transport to use for the request. Takes precedence over domain_whitelist. Set to False to use no transport. |
|
- **kwargs: Additional keyword arguments to pass to the httpx.AsyncClient.get() function. |
|
""" |
|
parsed_url = httpx.URL(url) |
|
hostname = parsed_url.host |
|
if not hostname: |
|
raise ValueError(f"URL {url} does not have a valid hostname") |
|
if domain_whitelist is None: |
|
domain_whitelist = [] |
|
|
|
if _transport: |
|
transport = _transport |
|
elif _transport is False or hostname in domain_whitelist: |
|
transport = None |
|
else: |
|
verified_ip = await async_validate_url(hostname) |
|
transport = AsyncSecureTransport(verified_ip) |
|
|
|
async with httpx.AsyncClient(transport=transport) as client: |
|
return await client.get(url, follow_redirects=False, **kwargs) |
|
|