File size: 4,504 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import asyncio
import ipaddress
from pathlib import Path
import socket
from functools import lru_cache, wraps
from typing import Any, Awaitable, Callable, Coroutine, Literal, T, Tuple
import httpx
def get_version() -> str:
    """Return the package version recorded in ``version.txt``.

    The file is looked up in the directory that contains this module.

    Returns:
        The file's contents with leading and trailing whitespace removed.

    Raises:
        OSError: If ``version.txt`` is missing or cannot be read.
    """
    version_file = Path(__file__).parent / 'version.txt'
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    return version_file.read_text(encoding="utf-8").strip()
# Package version string, computed once at import time by get_version().
__version__ = get_version()
def is_public_ip(ip: str) -> bool:
    """Return ``True`` only if *ip* is a globally routable IP address.

    Args:
        ip: Textual IPv4 or IPv6 address.

    Returns:
        ``False`` when *ip* cannot be parsed as an IP address, or when the
        address is private, loopback, link-local, multicast, reserved, or
        otherwise not globally reachable; ``True`` otherwise.
    """
    try:
        ip_obj = ipaddress.ip_address(ip)
    except ValueError:
        # Not an IP address at all (e.g. a hostname or an empty string).
        return False

    return not (
        ip_obj.is_private
        or ip_obj.is_loopback
        or ip_obj.is_link_local
        or ip_obj.is_multicast
        or ip_obj.is_reserved
        # ``is_private`` does not cover the RFC 6598 shared address space
        # (100.64.0.0/10, used for carrier-grade NAT), which is not publicly
        # routable either; ``is_global`` is False for that range.
        or not ip_obj.is_global
    )
def lru_cache_async(maxsize: int = 256):
    """Build a decorator that memoizes calls to a coroutine function.

    For each distinct set of call arguments the coroutine is started once as
    an ``asyncio`` task, and that task object is what ``functools.lru_cache``
    stores. Repeated calls with the same arguments get the same task back and
    can await it again.

    Args:
        maxsize: Maximum number of entries kept by the underlying LRU cache.

    Returns:
        A decorator that turns a coroutine function into a cached,
        task-returning callable.

    Notes:
        * Arguments must be hashable, as required by ``lru_cache``.
        * The decorated callable must be invoked while an event loop is
          running, because it calls ``asyncio.create_task``.
        * The task is cached regardless of how it finishes, so a task that
          raised or was cancelled is returned again on later calls.
    """
    def decorator(
        async_func: Callable[..., Coroutine[Any, Any, T]],
    ) -> Callable[..., Awaitable[T]]:
        def spawn_task(*args: Any, **kwargs: Any) -> Awaitable[T]:
            # Schedule the coroutine right away; the task is the cached value.
            return asyncio.create_task(async_func(*args, **kwargs))

        # Copy the coroutine function's metadata first, then add the cache.
        named_spawner = wraps(async_func)(spawn_task)
        return lru_cache(maxsize=maxsize)(named_spawner)

    return decorator
@lru_cache_async()
async def async_resolve_hostname_google(hostname: str) -> list[str]:
    """Resolve *hostname* through Google's DNS-over-HTTPS JSON endpoint.

    Two lookups are made against ``https://dns.google/resolve``: one for
    ``A`` records and one for ``AAAA`` records.

    Args:
        hostname: The name to resolve.

    Returns:
        The ``data`` field of every entry in each response's ``Answer``
        list, ``A`` results first. The values are returned as-is and are
        not checked to be IP addresses, so callers must validate them.
        An empty list is returned if any request or response parsing step
        fails.
    """
    ips: list[str] = []
    async with httpx.AsyncClient() as client:
        try:
            for record_type in ("A", "AAAA"):
                response = await client.get(
                    "https://dns.google/resolve",
                    # Let httpx build the query string so the hostname is
                    # always URL-encoded and cannot inject extra parameters.
                    params={"name": hostname, "type": record_type},
                )
                ips.extend(
                    answer["data"]
                    for answer in response.json().get("Answer", [])
                )
        except Exception:
            # Deliberately best-effort: any network, decoding, or schema
            # problem means "no usable answer" rather than an error.
            return []
    return ips
async def async_validate_url(hostname: str) -> str:
    """Resolve *hostname* and return the first public IP address found.

    The system resolver is consulted first. If none of its results is a
    public address, the addresses returned by
    ``async_resolve_hostname_google`` are tried as a fallback.

    Args:
        hostname: The host name to resolve and validate.

    Returns:
        The first address for which ``is_public_ip`` is true.

    Raises:
        ValueError: If the system resolver cannot resolve *hostname*, or if
            neither lookup yields a public IP address.
    """
    # Inside a coroutine the running loop is the one to use;
    # get_running_loop() is the supported way to obtain it.
    loop = asyncio.get_running_loop()
    try:
        addrinfo = await loop.getaddrinfo(hostname, None)
    except socket.gaierror as e:
        raise ValueError(f"Unable to resolve hostname {hostname}: {e}") from e

    for family, _, _, _, sockaddr in addrinfo:
        # The address is the first element of sockaddr for AF_INET/AF_INET6.
        ip_address = sockaddr[0]
        if family in (socket.AF_INET, socket.AF_INET6) and is_public_ip(ip_address):
            return ip_address

    # The local resolver produced no public address; try Google's resolver.
    for ip_address in await async_resolve_hostname_google(hostname):
        if is_public_ip(ip_address):
            return ip_address

    raise ValueError(f"Hostname {hostname} failed validation")
class AsyncSecureTransport(httpx.AsyncHTTPTransport):
    """HTTP transport that connects to one pre-validated IP address.

    Every request is re-targeted at ``verified_ip`` while the ``Host``
    header and the TLS SNI name keep referring to the host in the original
    URL, so the connection goes to the address that was validated rather
    than to whatever the name resolves to at connect time.
    """

    def __init__(self, verified_ip: str):
        """Store the IP address that all requests will be sent to.

        Args:
            verified_ip: The address to connect to instead of the URL host.
        """
        self.verified_ip = verified_ip
        super().__init__()

    async def handle_async_request(
        self,
        request: httpx.Request
    ) -> httpx.Response:
        """Send *request* to ``verified_ip`` on behalf of its original host.

        The request is modified in place: its URL host is replaced by the
        verified IP, its ``Host`` header is set from the original URL, and
        ``sni_hostname`` is added to its extensions.

        Args:
            request: The outgoing request.

        Returns:
            The response produced by the parent transport.
        """
        original_url = request.url
        original_host = original_url.host

        request.url = original_url.copy_with(host=self.verified_ip)

        # netloc is "<host>" or "<host>:<port>", so a non-default port is
        # kept in the Host header instead of being dropped.
        request.headers['Host'] = original_url.netloc.decode("ascii")

        # Merge rather than replace: the extensions already carry per-request
        # settings such as "timeout", which must survive.
        request.extensions = {**request.extensions, "sni_hostname": original_host}

        return await super().handle_async_request(request)
async def get(
    url: str,
    domain_whitelist: list[str] | None = None,
    _transport: httpx.AsyncBaseTransport | Literal[False] | None = None,
    **kwargs,
) -> httpx.Response:
    """
    Make an async HTTP GET request; this is the module's main entry point.

    Unless a transport is supplied or the host is whitelisted, the hostname is
    validated first and the request is sent through an AsyncSecureTransport
    bound to the validated IP. Redirects are not followed.

    Parameters:
    - url (str): The URL to request.
    - domain_whitelist (list[str] | None): Hostnames that skip validation and use httpx's default transport.
    - _transport (httpx.AsyncBaseTransport | Literal[False] | None): A transport to use as-is; it takes precedence over domain_whitelist. Pass False to use httpx's default transport without validation.
    - **kwargs: Extra keyword arguments forwarded to httpx.AsyncClient.get().

    Returns:
    - httpx.Response: The response to the GET request.

    Raises:
    - ValueError: If the URL has no hostname.
    """
    hostname = httpx.URL(url).host
    if not hostname:
        raise ValueError(f"URL {url} does not have a valid hostname")

    allowed_domains = [] if domain_whitelist is None else domain_whitelist

    if not _transport:
        if _transport is False or hostname in allowed_domains:
            # Validation explicitly disabled, or the host is trusted.
            transport = None
        else:
            pinned_ip = await async_validate_url(hostname)
            transport = AsyncSecureTransport(pinned_ip)
    else:
        # A caller-supplied transport wins over every other option.
        transport = _transport

    async with httpx.AsyncClient(transport=transport) as client:
        return await client.get(url, follow_redirects=False, **kwargs)
|