|
from __future__ import annotations |
|
|
|
import asyncio |
|
import contextlib |
|
import logging |
|
import os |
|
import platform |
|
import signal |
|
import socket |
|
import sys |
|
import threading |
|
import time |
|
from collections.abc import Generator, Sequence |
|
from email.utils import formatdate |
|
from types import FrameType |
|
from typing import TYPE_CHECKING, Union |
|
|
|
import click |
|
|
|
from uvicorn.config import Config |
|
|
|
# The protocol classes are imported only for static type checkers; this block is
# skipped at runtime. That is safe because `from __future__ import annotations`
# (top of file) keeps every annotation that names them as an unevaluated string.
if TYPE_CHECKING:
    from uvicorn.protocols.http.h11_impl import H11Protocol
    from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
    from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
    from uvicorn.protocols.websockets.wsproto_impl import WSProtocol

    # Union of the HTTP and WebSocket protocol implementations a tracked
    # connection may be; used to annotate ServerState.connections.
    Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol]
|
|
|
# Signals for which Server.capture_signals installs Server.handle_exit.
HANDLED_SIGNALS = (signal.SIGINT, signal.SIGTERM)
if sys.platform == "win32":
    # SIGBREAK (CTRL+BREAK) is only defined on Windows.
    HANDLED_SIGNALS = (*HANDLED_SIGNALS, signal.SIGBREAK)

logger = logging.getLogger("uvicorn.error")
|
|
|
|
|
class ServerState:
    """
    Mutable state shared by every protocol instance that belongs to one server.
    """

    def __init__(self) -> None:
        # Headers sent by default on responses; Server.on_tick rebuilds this list.
        self.default_headers: list[tuple[bytes, bytes]] = []
        # Live protocol instances; Server.shutdown asks each of them to shut down.
        self.connections: set[Protocols] = set()
        # Running tasks that Server.shutdown waits for, then cancels on timeout.
        self.tasks: set[asyncio.Task[None]] = set()
        # Compared against config.limit_max_requests by Server.on_tick;
        # presumably incremented by the protocols (not visible here).
        self.total_requests = 0
|
|
|
|
|
class Server:
    """Serve an ASGI application described by a ``Config``.

    Lifecycle, as driven by ``_serve``: lifespan startup and listener creation
    (``startup``), a periodic tick loop (``main_loop``), then a graceful
    ``shutdown``.
    """

    def __init__(self, config: Config) -> None:
        self.config = config
        self.server_state = ServerState()

        self.started = False
        self.should_exit = False
        self.force_exit = False
        self.last_notified = 0.0

        # Listening servers created by ``startup``. Initialised here so that
        # ``shutdown`` / ``_wait_tasks_to_complete`` do not raise AttributeError
        # when startup never got as far as creating any.
        self.servers: list[asyncio.base_events.Server] = []

        # Signals received while ``capture_signals`` was active; they are
        # re-raised after the original handlers have been restored.
        self._captured_signals: list[int] = []

    def run(self, sockets: list[socket.socket] | None = None) -> None:
        """Set up the event loop and block until the server has finished."""
        self.config.setup_event_loop()
        return asyncio.run(self.serve(sockets=sockets))

    async def serve(self, sockets: list[socket.socket] | None = None) -> None:
        """Serve until exit, routing the handled signals to ``handle_exit``."""
        with self.capture_signals():
            await self._serve(sockets)

    async def _serve(self, sockets: list[socket.socket] | None = None) -> None:
        """Load config, start up, run the main loop and shut down, with logging."""
        process_id = os.getpid()

        config = self.config
        if not config.loaded:
            config.load()

        self.lifespan = config.lifespan_class(config)

        message = "Started server process [%d]"
        color_message = "Started server process [" + click.style("%d", fg="cyan") + "]"
        logger.info(message, process_id, extra={"color_message": color_message})

        await self.startup(sockets=sockets)
        if self.should_exit:
            # Lifespan startup requested an exit; nothing was bound, so skip shutdown.
            return
        await self.main_loop()
        await self.shutdown(sockets=sockets)

        message = "Finished server process [%d]"
        color_message = "Finished server process [" + click.style("%d", fg="cyan") + "]"
        logger.info(message, process_id, extra={"color_message": color_message})

    async def startup(self, sockets: list[socket.socket] | None = None) -> None:
        """Run lifespan startup, then create the listening server(s).

        Listeners come from, in priority order: the ``sockets`` argument,
        ``config.fd``, ``config.uds``, or ``config.host``/``config.port``.
        Sets ``should_exit`` and returns early if lifespan startup asks to exit.
        Calls ``sys.exit(1)`` (after lifespan shutdown) if binding host/port
        raises ``OSError``.
        """
        await self.lifespan.startup()
        if self.lifespan.should_exit:
            self.should_exit = True
            return

        config = self.config

        def create_protocol(
            _loop: asyncio.AbstractEventLoop | None = None,
        ) -> asyncio.Protocol:
            # Protocol factory given to the event loop (called once per connection).
            return config.http_protocol_class(
                config=config,
                server_state=self.server_state,
                app_state=self.lifespan.state,
                _loop=_loop,
            )

        loop = asyncio.get_running_loop()

        listeners: Sequence[socket.SocketType]
        if sockets is not None:
            # Pre-bound sockets supplied by the caller.

            def _share_socket(
                sock: socket.SocketType,
            ) -> socket.SocketType:
                # Duplicate the socket through Windows' share/fromshare API.
                # ``socket.fromshare`` only exists on Windows, hence the local import.
                from socket import fromshare

                sock_data = sock.share(os.getpid())
                return fromshare(sock_data)

            # The platform cannot change between sockets, so decide once.
            share_sockets = config.workers > 1 and platform.system() == "Windows"
            self.servers = []
            for sock in sockets:
                if share_sockets:
                    sock = _share_socket(sock)
                server = await loop.create_server(create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog)
                self.servers.append(server)
            listeners = sockets

        elif config.fd is not None:
            # Serve on an already-open file descriptor, wrapped as a UNIX stream socket.
            sock = socket.fromfd(config.fd, socket.AF_UNIX, socket.SOCK_STREAM)
            server = await loop.create_server(create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog)
            assert server.sockets is not None
            listeners = server.sockets
            self.servers = [server]

        elif config.uds is not None:
            # Bind a UNIX domain socket. Re-apply the mode of a pre-existing path,
            # otherwise make the socket read/writable by everyone (0o666).
            uds_perms = 0o666
            if os.path.exists(config.uds):
                uds_perms = os.stat(config.uds).st_mode
            server = await loop.create_unix_server(
                create_protocol, path=config.uds, ssl=config.ssl, backlog=config.backlog
            )
            os.chmod(config.uds, uds_perms)
            assert server.sockets is not None
            listeners = server.sockets
            self.servers = [server]

        else:
            # Standard case: bind config.host / config.port.
            try:
                server = await loop.create_server(
                    create_protocol,
                    host=config.host,
                    port=config.port,
                    ssl=config.ssl,
                    backlog=config.backlog,
                )
            except OSError as exc:
                logger.error(exc)
                await self.lifespan.shutdown()
                sys.exit(1)

            assert server.sockets is not None
            listeners = server.sockets
            self.servers = [server]

        if sockets is None:
            self._log_started_message(listeners)
        # When ``sockets`` were supplied nothing is logged here; presumably the
        # code that bound them already reported the address (TODO confirm).

        self.started = True

    def _log_started_message(self, listeners: Sequence[socket.SocketType]) -> None:
        """Log where the server is listening, according to how it was bound."""
        config = self.config

        if config.fd is not None:
            sock = listeners[0]
            logger.info(
                "Uvicorn running on socket %s (Press CTRL+C to quit)",
                sock.getsockname(),
            )

        elif config.uds is not None:
            logger.info("Uvicorn running on unix socket %s (Press CTRL+C to quit)", config.uds)

        else:
            addr_format = "%s://%s:%d"
            host = "0.0.0.0" if config.host is None else config.host
            if ":" in host:
                # A colon means an IPv6 literal, which must be bracketed in a URL.
                addr_format = "%s://[%s]:%d"

            port = config.port
            if port == 0:
                # Port 0 lets the OS pick a free port; report the real one.
                port = listeners[0].getsockname()[1]

            protocol_name = "https" if config.ssl else "http"
            message = f"Uvicorn running on {addr_format} (Press CTRL+C to quit)"
            color_message = "Uvicorn running on " + click.style(addr_format, bold=True) + " (Press CTRL+C to quit)"
            logger.info(
                message,
                protocol_name,
                host,
                port,
                extra={"color_message": color_message},
            )

    async def main_loop(self) -> None:
        """Call ``on_tick`` every 0.1 s until it reports that the server should exit."""
        counter = 0
        should_exit = await self.on_tick(counter)
        while not should_exit:
            counter += 1
            # Wrap after 864000 ticks (about one day at 0.1 s) to keep the
            # counter bounded; the value is a multiple of 10, so the
            # ``counter % 10`` cadence in ``on_tick`` is unaffected.
            counter = counter % 864000
            await asyncio.sleep(0.1)
            should_exit = await self.on_tick(counter)

    async def on_tick(self, counter: int) -> bool:
        """Do periodic housekeeping and return ``True`` when the server should exit.

        On every 10th tick (roughly once per second, given ``main_loop``'s
        sleep) the default headers are rebuilt and ``callback_notify`` may run.
        """
        if counter % 10 == 0:
            current_time = time.time()
            current_date = formatdate(current_time, usegmt=True).encode()

            if self.config.date_header:
                date_header = [(b"date", current_date)]
            else:
                date_header = []

            self.server_state.default_headers = date_header + self.config.encoded_headers

            # Invoke callback_notify at most once per ``timeout_notify`` seconds.
            if self.config.callback_notify is not None:
                if current_time - self.last_notified > self.config.timeout_notify:
                    self.last_notified = current_time
                    await self.config.callback_notify()

        if self.should_exit:
            return True

        max_requests = self.config.limit_max_requests
        if max_requests is not None and self.server_state.total_requests >= max_requests:
            logger.warning(f"Maximum request limit of {max_requests} exceeded. Terminating process.")
            return True

        return False

    async def shutdown(self, sockets: list[socket.socket] | None = None) -> None:
        """Stop accepting connections, drain in-flight work, then run lifespan shutdown.

        Tasks still running after ``config.timeout_graceful_shutdown`` are
        cancelled. Lifespan shutdown is skipped when ``force_exit`` is set.
        """
        logger.info("Shutting down")

        # Stop accepting new connections.
        for server in self.servers:
            server.close()
        for sock in sockets or []:
            sock.close()

        # Ask every open connection to shut down. Iterate over a copy because
        # connections may leave the shared set while we iterate.
        for connection in list(self.server_state.connections):
            connection.shutdown()
        await asyncio.sleep(0.1)

        # Wait for connections and tasks, bounded by the graceful-shutdown timeout.
        try:
            await asyncio.wait_for(
                self._wait_tasks_to_complete(),
                timeout=self.config.timeout_graceful_shutdown,
            )
        except asyncio.TimeoutError:
            logger.error(
                "Cancel %s running task(s), timeout graceful shutdown exceeded",
                len(self.server_state.tasks),
            )
            for t in self.server_state.tasks:
                t.cancel(msg="Task cancelled, timeout graceful shutdown exceeded")

        # Send the lifespan shutdown event, unless the user forced an exit.
        if not self.force_exit:
            await self.lifespan.shutdown()

    async def _wait_tasks_to_complete(self) -> None:
        """Poll until connections and tasks are gone (or ``force_exit``), then await the servers' closing."""
        # Wait for open connections to close.
        if self.server_state.connections and not self.force_exit:
            msg = "Waiting for connections to close. (CTRL+C to force quit)"
            logger.info(msg)
            while self.server_state.connections and not self.force_exit:
                await asyncio.sleep(0.1)

        # Wait for background tasks to complete.
        if self.server_state.tasks and not self.force_exit:
            msg = "Waiting for background tasks to complete. (CTRL+C to force quit)"
            logger.info(msg)
            while self.server_state.tasks and not self.force_exit:
                await asyncio.sleep(0.1)

        for server in self.servers:
            await server.wait_closed()

    @contextlib.contextmanager
    def capture_signals(self) -> Generator[None, None, None]:
        """Route HANDLED_SIGNALS to ``handle_exit`` for the duration of the block.

        On exit, the previous handlers are restored and the captured signals are
        re-raised, most recent first. This is a no-op outside the main thread.
        """
        # Python only allows installing signal handlers from the main thread.
        if threading.current_thread() is not threading.main_thread():
            yield
            return

        # Use signal.signal, which returns the previous handler, so that the
        # original handlers can be restored afterwards.
        original_handlers = {sig: signal.signal(sig, self.handle_exit) for sig in HANDLED_SIGNALS}
        try:
            yield
        finally:
            for sig, handler in original_handlers.items():
                signal.signal(sig, handler)

        # Now that the original handlers are back, re-deliver the signals that
        # triggered the graceful shutdown so the expected default behaviour
        # still happens (LIFO order).
        for captured_signal in reversed(self._captured_signals):
            signal.raise_signal(captured_signal)

    def handle_exit(self, sig: int, frame: FrameType | None) -> None:
        """Signal handler: the first signal requests a graceful exit.

        A SIGINT received after an exit was already requested forces the exit.
        """
        self._captured_signals.append(sig)
        if self.should_exit and sig == signal.SIGINT:
            self.force_exit = True
        else:
            self.should_exit = True
|
|