|
from __future__ import annotations |
|
|
|
import asyncio |
|
import email.utils |
|
import functools |
|
import http |
|
import inspect |
|
import logging |
|
import socket |
|
import warnings |
|
from collections.abc import Awaitable, Generator, Iterable, Sequence |
|
from types import TracebackType |
|
from typing import Any, Callable, Union, cast |
|
|
|
from ..asyncio.compatibility import asyncio_timeout |
|
from ..datastructures import Headers, HeadersLike, MultipleValuesError |
|
from ..exceptions import ( |
|
InvalidHandshake, |
|
InvalidHeader, |
|
InvalidMessage, |
|
InvalidOrigin, |
|
InvalidUpgrade, |
|
NegotiationError, |
|
) |
|
from ..extensions import Extension, ServerExtensionFactory |
|
from ..extensions.permessage_deflate import enable_server_permessage_deflate |
|
from ..headers import ( |
|
build_extension, |
|
parse_extension, |
|
parse_subprotocol, |
|
validate_subprotocols, |
|
) |
|
from ..http11 import SERVER |
|
from ..protocol import State |
|
from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol |
|
from .exceptions import AbortHandshake |
|
from .handshake import build_response, check_request |
|
from .http import read_request |
|
from .protocol import WebSocketCommonProtocol, broadcast |
|
|
|
|
|
__all__ = [
    "broadcast",
    "serve",
    "unix_serve",
    "WebSocketServerProtocol",
    "WebSocketServer",
]


# Accepted type for ``extra_headers``: either headers to add as-is to the
# opening handshake response, or a callable invoked with the request path and
# the request headers that returns the headers to add.
HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]]

# Early HTTP response returned by ``process_request`` to abort the opening
# handshake instead of upgrading the connection: (status, headers, body).
HTTPResponse = tuple[StatusLike, HeadersLike, bytes]
|
|
|
|
|
class WebSocketServerProtocol(WebSocketCommonProtocol):
    """
    WebSocket server connection.

    :class:`WebSocketServerProtocol` provides :meth:`recv` and :meth:`send`
    coroutines for receiving and sending messages.

    It supports asynchronous iteration to receive messages::

        async for message in websocket:
            await process(message)

    The iterator exits normally when the connection is closed with close code
    1000 (OK) or 1001 (going away) or without a close code. It raises
    a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection
    is closed with any other code.

    You may customize the opening handshake in a subclass by
    overriding :meth:`process_request` or :meth:`select_subprotocol`.

    Args:
        ws_server: WebSocket server that created this connection.

    See :func:`serve` for the documentation of ``ws_handler``, ``logger``, ``origins``,
    ``extensions``, ``subprotocols``, ``extra_headers``, and ``server_header``.

    See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
    documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
    ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.

    """

    is_client = False
    side = "server"

    def __init__(
        self,
        ws_handler: (
            Callable[[WebSocketServerProtocol], Awaitable[Any]]
            | Callable[[WebSocketServerProtocol, str], Awaitable[Any]]
        ),
        ws_server: WebSocketServer,
        *,
        logger: LoggerLike | None = None,
        origins: Sequence[Origin | None] | None = None,
        extensions: Sequence[ServerExtensionFactory] | None = None,
        subprotocols: Sequence[Subprotocol] | None = None,
        extra_headers: HeadersLikeOrCallable | None = None,
        server_header: str | None = SERVER,
        process_request: (
            Callable[[str, Headers], Awaitable[HTTPResponse | None]] | None
        ) = None,
        select_subprotocol: (
            Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] | None
        ) = None,
        open_timeout: float | None = 10,
        **kwargs: Any,
    ) -> None:
        if logger is None:
            logger = logging.getLogger("websockets.server")
        super().__init__(logger=logger, **kwargs)
        # Deprecated spelling: "" used to mean "no Origin header"; normalize it
        # to None, which is what process_origin() compares against.
        if origins is not None and "" in origins:
            warnings.warn("use None instead of '' in origins", DeprecationWarning)
            origins = [None if origin == "" else origin for origin in origins]
        # remove_path_argument is defined elsewhere in this module; presumably
        # it adapts handlers that still take a (websocket, path) signature so
        # that handler() can always call ws_handler(self) — TODO confirm.
        self.ws_handler = remove_path_argument(ws_handler)
        self.ws_server = ws_server
        self.origins = origins
        self.available_extensions = extensions
        self.available_subprotocols = subprotocols
        self.extra_headers = extra_headers
        self.server_header = server_header
        self._process_request = process_request
        self._select_subprotocol = select_subprotocol
        self.open_timeout = open_timeout

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """
        Register connection and initialize a task to handle it.

        """
        super().connection_made(transport)
        # Registration happens before the handler task is created so that the
        # server can find the task (via handler_task) when it shuts down.
        self.ws_server.register(self)
        self.handler_task = self.loop.create_task(self.handler())

    async def handler(self) -> None:
        """
        Handle the lifecycle of a WebSocket connection.

        Since this method doesn't have a caller able to handle exceptions, it
        attempts to log relevant ones and guarantees that the TCP connection is
        closed before exiting.

        """
        try:
            try:
                async with asyncio_timeout(self.open_timeout):
                    await self.handshake(
                        origins=self.origins,
                        available_extensions=self.available_extensions,
                        available_subprotocols=self.available_subprotocols,
                        extra_headers=self.extra_headers,
                    )
            # Timeouts and dropped connections leave nothing to answer; let the
            # outer handler close the transport.
            except asyncio.TimeoutError:
                raise
            except ConnectionError:
                raise
            except Exception as exc:
                # Translate the handshake failure into an HTTP error response.
                if isinstance(exc, AbortHandshake):
                    status, headers, body = exc.status, exc.headers, exc.body
                elif isinstance(exc, InvalidOrigin):
                    if self.debug:
                        self.logger.debug("! invalid origin", exc_info=True)
                    status, headers, body = (
                        http.HTTPStatus.FORBIDDEN,
                        Headers(),
                        f"Failed to open a WebSocket connection: {exc}.\n".encode(),
                    )
                elif isinstance(exc, InvalidUpgrade):
                    if self.debug:
                        self.logger.debug("! invalid upgrade", exc_info=True)
                    status, headers, body = (
                        http.HTTPStatus.UPGRADE_REQUIRED,
                        Headers([("Upgrade", "websocket")]),
                        (
                            f"Failed to open a WebSocket connection: {exc}.\n"
                            f"\n"
                            f"You cannot access a WebSocket server directly "
                            f"with a browser. You need a WebSocket client.\n"
                        ).encode(),
                    )
                elif isinstance(exc, InvalidHandshake):
                    if self.debug:
                        self.logger.debug("! invalid handshake", exc_info=True)
                    # Include every exception in the __cause__ chain in the
                    # response body so the client sees the root cause.
                    exc_chain = cast(BaseException, exc)
                    exc_str = f"{exc_chain}"
                    while exc_chain.__cause__ is not None:
                        exc_chain = exc_chain.__cause__
                        exc_str += f"; {exc_chain}"
                    status, headers, body = (
                        http.HTTPStatus.BAD_REQUEST,
                        Headers(),
                        f"Failed to open a WebSocket connection: {exc_str}.\n".encode(),
                    )
                else:
                    # Unexpected error: log it but don't leak details to the client.
                    self.logger.error("opening handshake failed", exc_info=True)
                    status, headers, body = (
                        http.HTTPStatus.INTERNAL_SERVER_ERROR,
                        Headers(),
                        (
                            b"Failed to open a WebSocket connection.\n"
                            b"See server log for more information.\n"
                        ),
                    )

                headers.setdefault("Date", email.utils.formatdate(usegmt=True))
                if self.server_header:
                    headers.setdefault("Server", self.server_header)

                headers.setdefault("Content-Length", str(len(body)))
                headers.setdefault("Content-Type", "text/plain")
                headers.setdefault("Connection", "close")

                self.write_http_response(status, headers, body)
                self.logger.info(
                    "connection rejected (%d %s)", status.value, status.phrase
                )
                await self.close_transport()
                return

            try:
                await self.ws_handler(self)
            except Exception:
                self.logger.error("connection handler failed", exc_info=True)
                if not self.closed:
                    self.fail_connection(1011)
                raise

            try:
                await self.close()
            except ConnectionError:
                raise
            except Exception:
                self.logger.error("closing handshake failed", exc_info=True)
                raise

        except Exception:
            # Last-ditch attempt to avoid leaking the TCP connection on errors.
            try:
                self.transport.close()
            except Exception:
                pass

        finally:
            # Unregister when the handler task terminates: registration is tied
            # to this task's lifetime because the server waits for the handler
            # tasks of registered connections before it finishes closing.
            self.ws_server.unregister(self)
            self.logger.info("connection closed")

    async def read_http_request(self) -> tuple[str, Headers]:
        """
        Read request line and headers from the HTTP request.

        If the request contains a body, it may be read from ``self.reader``
        after this coroutine returns.

        Raises:
            InvalidMessage: If the HTTP message is malformed or isn't an
                HTTP/1.1 GET request.

        """
        try:
            path, headers = await read_request(self.reader)
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            raise InvalidMessage("did not receive a valid HTTP request") from exc

        if self.debug:
            self.logger.debug("< GET %s HTTP/1.1", path)
            for key, value in headers.raw_items():
                self.logger.debug("< %s: %s", key, value)

        self.path = path
        self.request_headers = headers

        return path, headers

    def write_http_response(
        self, status: http.HTTPStatus, headers: Headers, body: bytes | None = None
    ) -> None:
        """
        Write status line and headers to the HTTP response.

        This method is also able to write a response body. It writes to the
        transport synchronously and doesn't wait for the data to be flushed.

        """
        self.response_headers = headers

        if self.debug:
            self.logger.debug("> HTTP/1.1 %d %s", status.value, status.phrase)
            for key, value in headers.raw_items():
                self.logger.debug("> %s: %s", key, value)
            if body is not None:
                self.logger.debug("> [body] (%d bytes)", len(body))

        # The status line and headers are expected to be ASCII, so a plain
        # encode() of the serialized head is sufficient.
        response = f"HTTP/1.1 {status.value} {status.phrase}\r\n"
        response += str(headers)

        self.transport.write(response.encode())

        if body is not None:
            self.transport.write(body)

    async def process_request(
        self, path: str, request_headers: Headers
    ) -> HTTPResponse | None:
        """
        Intercept the HTTP request and return an HTTP response if appropriate.

        You may override this method in a :class:`WebSocketServerProtocol`
        subclass, for example:

        * to return an HTTP 200 OK response on a given path; then a load
          balancer can use this path for a health check;
        * to authenticate the request and return an HTTP 401 Unauthorized or an
          HTTP 403 Forbidden when authentication fails.

        You may also override this method with the ``process_request``
        argument of :func:`serve` and :class:`WebSocketServerProtocol`. This
        is equivalent, except ``process_request`` won't have access to the
        protocol instance, so it can't store information for later use.

        :meth:`process_request` is expected to complete quickly. If it may run
        for a long time, then it should await :meth:`wait_closed` and exit if
        :meth:`wait_closed` completes, or else it could prevent the server
        from shutting down.

        Args:
            path: Request path, including optional query string.
            request_headers: Request headers.

        Returns:
            tuple[StatusLike, HeadersLike, bytes] | None: :obj:`None` to
            continue the WebSocket handshake normally.

            An HTTP response, represented by a 3-uple of the response status,
            headers, and body, to abort the WebSocket handshake and return
            that HTTP response instead.

        """
        if self._process_request is not None:
            response = self._process_request(path, request_headers)
            if isinstance(response, Awaitable):
                return await response
            else:
                # Backwards compatibility: synchronous process_request callables
                # are still accepted but deprecated.
                warnings.warn(
                    "declare process_request as a coroutine", DeprecationWarning
                )
                return response
        return None

    @staticmethod
    def process_origin(
        headers: Headers, origins: Sequence[Origin | None] | None = None
    ) -> Origin | None:
        """
        Handle the Origin HTTP request header.

        Args:
            headers: Request headers.
            origins: Optional list of acceptable origins.

        Returns:
            The value of the Origin header, or :obj:`None` if it is absent.

        Raises:
            InvalidHeader: If the request contains several Origin headers.
            InvalidOrigin: If the origin isn't acceptable.

        """
        # RFC 6454 section 7.3: the user agent must not send more than one
        # Origin header field.
        try:
            origin = headers.get("Origin")
        except MultipleValuesError as exc:
            raise InvalidHeader("Origin", "multiple values") from exc
        if origin is not None:
            origin = cast(Origin, origin)
        if origins is not None:
            if origin not in origins:
                raise InvalidOrigin(origin)
        return origin

    @staticmethod
    def process_extensions(
        headers: Headers,
        available_extensions: Sequence[ServerExtensionFactory] | None,
    ) -> tuple[str | None, list[Extension]]:
        """
        Handle the Sec-WebSocket-Extensions HTTP request header.

        Accept or reject each extension proposed in the client request.
        Negotiate parameters for accepted extensions.

        Return the Sec-WebSocket-Extensions HTTP response header and the list
        of accepted extensions.

        :rfc:`6455` leaves the rules up to the specification of each
        extension.

        To provide this level of flexibility, for each extension proposed by
        the client, we check for a match with each extension available in the
        server configuration. If no match is found, the extension is ignored.

        If several variants of the same extension are proposed by the client,
        it may be accepted several times, which won't make sense in general.
        Extensions must implement their own requirements. For this purpose,
        the list of previously accepted extensions is provided.

        This process doesn't allow the server to reorder extensions. It can
        only select a subset of the extensions proposed by the client.

        Other requirements, for example related to mandatory extensions or the
        order of extensions, may be implemented by overriding this method.

        Args:
            headers: Request headers.
            available_extensions: Optional list of supported extensions.

        Raises:
            InvalidHandshake: To abort the handshake with an HTTP 400 error.

        """
        response_header_value: str | None = None

        extension_headers: list[ExtensionHeader] = []
        accepted_extensions: list[Extension] = []

        header_values = headers.get_all("Sec-WebSocket-Extensions")

        if header_values and available_extensions:
            # Each header value may list several extensions; flatten them in
            # the order the client sent them. A comprehension avoids the
            # quadratic cost of sum(list_of_lists, []).
            parsed_header_values: list[ExtensionHeader] = [
                extension_header
                for header_value in header_values
                for extension_header in parse_extension(header_value)
            ]

            for name, request_params in parsed_header_values:
                for ext_factory in available_extensions:
                    # Skip non-matching extensions based on their name.
                    if ext_factory.name != name:
                        continue

                    # Skip non-matching extensions based on their params.
                    try:
                        response_params, extension = ext_factory.process_request_params(
                            request_params, accepted_extensions
                        )
                    except NegotiationError:
                        continue

                    # Add matching extension to the final list.
                    extension_headers.append((name, response_params))
                    accepted_extensions.append(extension)

                    # Stop at the first matching server factory for this
                    # client proposal.
                    break

                # If we didn't break from the loop, no extension in our list
                # matched what the client sent. The extension is declined.

        # Serialize the response header only if something was accepted.
        if extension_headers:
            response_header_value = build_extension(extension_headers)

        return response_header_value, accepted_extensions

    def process_subprotocol(
        self, headers: Headers, available_subprotocols: Sequence[Subprotocol] | None
    ) -> Subprotocol | None:
        """
        Handle the Sec-WebSocket-Protocol HTTP request header.

        Return Sec-WebSocket-Protocol HTTP response header, which is the same
        as the selected subprotocol.

        Args:
            headers: Request headers.
            available_subprotocols: Optional list of supported subprotocols.

        Raises:
            InvalidHandshake: To abort the handshake with an HTTP 400 error.

        """
        subprotocol: Subprotocol | None = None

        header_values = headers.get_all("Sec-WebSocket-Protocol")

        if header_values and available_subprotocols:
            # Flatten all header values, preserving client order, without the
            # quadratic cost of sum(list_of_lists, []).
            parsed_header_values: list[Subprotocol] = [
                parsed
                for header_value in header_values
                for parsed in parse_subprotocol(header_value)
            ]

            subprotocol = self.select_subprotocol(
                parsed_header_values, available_subprotocols
            )

        return subprotocol

    def select_subprotocol(
        self,
        client_subprotocols: Sequence[Subprotocol],
        server_subprotocols: Sequence[Subprotocol],
    ) -> Subprotocol | None:
        """
        Pick a subprotocol among those supported by the client and the server.

        If several subprotocols are available, select the preferred subprotocol
        by giving equal weight to the preferences of the client and the server.
        When several subprotocols have the same combined rank, the one the
        client listed first wins, so the result is deterministic.

        If no subprotocol is available, proceed without a subprotocol.

        You may provide a ``select_subprotocol`` argument to :func:`serve` or
        :class:`WebSocketServerProtocol` to override this logic. For example,
        you could reject the handshake if the client doesn't support a
        particular subprotocol, rather than accept the handshake without that
        subprotocol.

        Args:
            client_subprotocols: List of subprotocols offered by the client.
            server_subprotocols: List of subprotocols available on the server.

        Returns:
            Selected subprotocol, if a common subprotocol was found.

            :obj:`None` to continue without a subprotocol.

        """
        if self._select_subprotocol is not None:
            return self._select_subprotocol(client_subprotocols, server_subprotocols)

        server_set = set(server_subprotocols)
        # Walk common subprotocols in (deduplicated) client order. min() keeps
        # the first element among equal keys, so ties go to the client's
        # preference. Sorting a set here would make ties depend on string hash
        # randomization, i.e. vary between processes.
        candidates = [
            subprotocol
            for subprotocol in dict.fromkeys(client_subprotocols)
            if subprotocol in server_set
        ]
        if not candidates:
            return None
        return min(
            candidates,
            key=lambda p: client_subprotocols.index(p) + server_subprotocols.index(p),
        )

    async def handshake(
        self,
        origins: Sequence[Origin | None] | None = None,
        available_extensions: Sequence[ServerExtensionFactory] | None = None,
        available_subprotocols: Sequence[Subprotocol] | None = None,
        extra_headers: HeadersLikeOrCallable | None = None,
    ) -> str:
        """
        Perform the server side of the opening handshake.

        Args:
            origins: List of acceptable values of the Origin HTTP header;
                include :obj:`None` if the lack of an origin is acceptable.
            available_extensions: List of supported extensions, in order in
                which they should be tried.
            available_subprotocols: List of supported subprotocols, in order of
                decreasing preference.
            extra_headers: Arbitrary HTTP headers to add to the response when
                the handshake succeeds.

        Returns:
            path of the URI of the request.

        Raises:
            InvalidHandshake: If the handshake fails.

        """
        path, request_headers = await self.read_http_request()

        # Hook for customizing request handling, for example checking
        # authentication or treating some paths as plain HTTP endpoints.
        early_response_awaitable = self.process_request(path, request_headers)
        if isinstance(early_response_awaitable, Awaitable):
            early_response = await early_response_awaitable
        else:
            # Backwards compatibility with subclasses that override
            # process_request as a regular method.
            warnings.warn("declare process_request as a coroutine", DeprecationWarning)
            early_response = early_response_awaitable

        # The connection may drop while process_request is running.
        if self.state is State.CLOSED:
            # handler() re-raises ConnectionError subclasses without sending
            # an HTTP error response.
            raise BrokenPipeError("connection closed during opening handshake")

        # Turn the response into a 503 error if the server is shutting down.
        if not self.ws_server.is_serving():
            early_response = (
                http.HTTPStatus.SERVICE_UNAVAILABLE,
                [],
                b"Server is shutting down.\n",
            )

        if early_response is not None:
            raise AbortHandshake(*early_response)

        key = check_request(request_headers)

        self.origin = self.process_origin(request_headers, origins)

        extensions_header, self.extensions = self.process_extensions(
            request_headers, available_extensions
        )

        protocol_header = self.subprotocol = self.process_subprotocol(
            request_headers, available_subprotocols
        )

        response_headers = Headers()

        build_response(response_headers, key)

        if extensions_header is not None:
            response_headers["Sec-WebSocket-Extensions"] = extensions_header

        if protocol_header is not None:
            response_headers["Sec-WebSocket-Protocol"] = protocol_header

        if callable(extra_headers):
            extra_headers = extra_headers(path, self.request_headers)
        if extra_headers is not None:
            response_headers.update(extra_headers)

        response_headers.setdefault("Date", email.utils.formatdate(usegmt=True))
        # Same truthiness check as handler(): None or "" omits the header
        # rather than sending an empty (invalid) Server header.
        if self.server_header:
            response_headers.setdefault("Server", self.server_header)

        self.write_http_response(http.HTTPStatus.SWITCHING_PROTOCOLS, response_headers)

        self.logger.info("connection open")

        self.connection_open()

        return path
|
|
|
|
|
class WebSocketServer:
    """
    WebSocket server returned by :func:`serve`.

    This class mirrors the API of :class:`~asyncio.Server`.

    It keeps track of WebSocket connections in order to close them properly
    when shutting down.

    Args:
        logger: Logger for this server.
            It defaults to ``logging.getLogger("websockets.server")``.
            See the :doc:`logging guide <../../topics/logging>` for details.

    """

    def __init__(self, logger: LoggerLike | None = None) -> None:
        if logger is None:
            logger = logging.getLogger("websockets.server")
        self.logger = logger

        # Connections registered by register() and removed by unregister().
        self.websockets: set[WebSocketServerProtocol] = set()

        # Task running _close(); created at most once by close().
        self.close_task: asyncio.Task[None] | None = None

        # Resolved by _close() once shutdown is complete; awaited by
        # wait_closed(). Created in wrap() because it needs the event loop.
        self.closed_waiter: asyncio.Future[None]

    def wrap(self, server: asyncio.base_events.Server) -> None:
        """
        Attach to a given :class:`~asyncio.Server`.

        Since :meth:`~asyncio.loop.create_server` doesn't support injecting a
        custom ``Server`` class, the easiest solution that doesn't rely on
        private :mod:`asyncio` APIs is to:

        - instantiate a :class:`WebSocketServer`
        - give the protocol factory a reference to that instance
        - call :meth:`~asyncio.loop.create_server` with the factory
        - attach the resulting :class:`~asyncio.Server` with this method

        """
        self.server = server
        # socket.AF_UNIX isn't defined on platforms without Unix sockets
        # (e.g. Windows); comparing against None then never matches, so such
        # sockets fall through to the generic branch instead of raising.
        af_unix = getattr(socket, "AF_UNIX", None)
        for sock in server.sockets:
            if sock.family == socket.AF_INET:
                name = "%s:%d" % sock.getsockname()
            elif sock.family == socket.AF_INET6:
                name = "[%s]:%d" % sock.getsockname()[:2]
            elif sock.family == af_unix:
                name = sock.getsockname()
            # For any other address family, avoid crashing while logging.
            else:
                name = str(sock.getsockname())
            self.logger.info("server listening on %s", name)

        # Created here rather than in __init__ because it needs a reference to
        # the server's event loop.
        self.closed_waiter = server.get_loop().create_future()

    def register(self, protocol: WebSocketServerProtocol) -> None:
        """
        Register a connection with this server.

        """
        self.websockets.add(protocol)

    def unregister(self, protocol: WebSocketServerProtocol) -> None:
        """
        Unregister a connection with this server.

        Raises:
            KeyError: If ``protocol`` isn't registered.

        """
        self.websockets.remove(protocol)

    def close(self, close_connections: bool = True) -> None:
        """
        Close the server.

        This method returns immediately: it schedules a task that performs the
        shutdown. Await :meth:`wait_closed` to wait until it completes. The
        shutdown task:

        * Closes the underlying :class:`~asyncio.Server`. Once it stops
          serving, connections that are still in the opening handshake are
          rejected with an HTTP 503 (service unavailable) error.
        * When ``close_connections`` is :obj:`True`, which is the default,
          closes open WebSocket connections with close code 1001 (going away).
        * Waits until all connection handlers terminate.

        :meth:`close` is idempotent.

        """
        if self.close_task is None:
            self.close_task = self.get_loop().create_task(
                self._close(close_connections)
            )

    async def _close(self, close_connections: bool) -> None:
        """
        Implementation of :meth:`close`.

        This calls :meth:`~asyncio.Server.close` on the underlying
        :class:`~asyncio.Server` object to stop accepting new connections and
        then closes open connections with close code 1001.

        """
        self.logger.info("server closing")

        # Stop accepting new connections.
        self.server.close()

        # Yield once so that connections already accepted reach
        # connection_made() and call register() before we snapshot them.
        await asyncio.sleep(0)

        if close_connections:
            # Close connections past the opening handshake with code 1001.
            # Connections still in CONNECTING are handled by handshake(),
            # which answers 503 once the server is no longer serving.
            close_tasks = [
                asyncio.create_task(websocket.close(1001))
                for websocket in self.websockets
                if websocket.state is not State.CONNECTING
            ]
            # asyncio.wait() rejects an empty collection.
            if close_tasks:
                await asyncio.wait(close_tasks)

        # Wait until the underlying server is closed.
        await self.server.wait_closed()

        # Wait until all connection handlers terminate.
        # asyncio.wait() rejects an empty collection.
        if self.websockets:
            await asyncio.wait(
                [websocket.handler_task for websocket in self.websockets]
            )

        # Tell wait_closed() to return.
        self.closed_waiter.set_result(None)

        self.logger.info("server closed")

    async def wait_closed(self) -> None:
        """
        Wait until the server is closed.

        When :meth:`wait_closed` returns, all TCP connections are closed and
        all connection handlers have returned.

        To ensure a fast shutdown, a connection handler should always be
        awaiting at least one of:

        * :meth:`~WebSocketServerProtocol.recv`: when the connection is closed,
          it raises :exc:`~websockets.exceptions.ConnectionClosedOK`;
        * :meth:`~WebSocketServerProtocol.wait_closed`: when the connection is
          closed, it returns.

        Then the connection handler is immediately notified of the shutdown;
        it can clean up and exit.

        """
        # shield() so that cancelling a waiter doesn't cancel the shared future.
        await asyncio.shield(self.closed_waiter)

    def get_loop(self) -> asyncio.AbstractEventLoop:
        """
        See :meth:`asyncio.Server.get_loop`.

        """
        return self.server.get_loop()

    def is_serving(self) -> bool:
        """
        See :meth:`asyncio.Server.is_serving`.

        """
        return self.server.is_serving()

    async def start_serving(self) -> None:
        """
        See :meth:`asyncio.Server.start_serving`.

        Typical use::

            server = await serve(..., start_serving=False)
            # perform additional setup here...
            # ... then start the server
            await server.start_serving()

        """
        await self.server.start_serving()

    async def serve_forever(self) -> None:
        """
        See :meth:`asyncio.Server.serve_forever`.

        Typical use::

            server = await serve(...)
            # this coroutine doesn't return
            # canceling it stops the server
            await server.serve_forever()

        This is an alternative to using :func:`serve` as an asynchronous context
        manager. Shutdown is triggered by canceling :meth:`serve_forever`
        instead of exiting a :func:`serve` context.

        """
        await self.server.serve_forever()

    @property
    def sockets(self) -> Iterable[socket.socket]:
        """
        See :attr:`asyncio.Server.sockets`.

        """
        return self.server.sockets

    async def __aenter__(self) -> WebSocketServer:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        # Exiting the context closes the server and waits for full shutdown.
        self.close()
        await self.wait_closed()
|
|
|
|
|
class Serve: |
|
""" |
|
Start a WebSocket server listening on ``host`` and ``port``. |
|
|
|
Whenever a client connects, the server creates a |
|
:class:`WebSocketServerProtocol`, performs the opening handshake, and |
|
delegates to the connection handler, ``ws_handler``. |
|
|
|
The handler receives the :class:`WebSocketServerProtocol` and uses it to |
|
send and receive messages. |
|
|
|
Once the handler completes, either normally or with an exception, the |
|
server performs the closing handshake and closes the connection. |
|
|
|
Awaiting :func:`serve` yields a :class:`WebSocketServer`. This object |
|
provides a :meth:`~WebSocketServer.close` method to shut down the server:: |
|
|
|
# set this future to exit the server |
|
stop = asyncio.get_running_loop().create_future() |
|
|
|
server = await serve(...) |
|
await stop |
|
server.close() |
|
await server.wait_closed() |
|
|
|
:func:`serve` can be used as an asynchronous context manager. Then, the |
|
server is shut down automatically when exiting the context:: |
|
|
|
# set this future to exit the server |
|
stop = asyncio.get_running_loop().create_future() |
|
|
|
async with serve(...): |
|
await stop |
|
|
|
Args: |
|
ws_handler: Connection handler. It receives the WebSocket connection, |
|
which is a :class:`WebSocketServerProtocol`, in argument. |
|
host: Network interfaces the server binds to. |
|
See :meth:`~asyncio.loop.create_server` for details. |
|
port: TCP port the server listens on. |
|
See :meth:`~asyncio.loop.create_server` for details. |
|
create_protocol: Factory for the :class:`asyncio.Protocol` managing |
|
the connection. It defaults to :class:`WebSocketServerProtocol`. |
|
Set it to a wrapper or a subclass to customize connection handling. |
|
logger: Logger for this server. |
|
It defaults to ``logging.getLogger("websockets.server")``. |
|
See the :doc:`logging guide <../../topics/logging>` for details. |
|
compression: The "permessage-deflate" extension is enabled by default. |
|
Set ``compression`` to :obj:`None` to disable it. See the |
|
:doc:`compression guide <../../topics/compression>` for details. |
|
origins: Acceptable values of the ``Origin`` header, for defending |
|
against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` |
|
in the list if the lack of an origin is acceptable. |
|
extensions: List of supported extensions, in order in which they |
|
should be negotiated and run. |
|
subprotocols: List of supported subprotocols, in order of decreasing |
|
preference. |
|
extra_headers (HeadersLike | Callable[[str, Headers] | HeadersLike]): |
|
Arbitrary HTTP headers to add to the response. This can be |
|
a :data:`~websockets.datastructures.HeadersLike` or a callable |
|
taking the request path and headers in arguments and returning |
|
a :data:`~websockets.datastructures.HeadersLike`. |
|
server_header: Value of the ``Server`` response header. |
|
It defaults to ``"Python/x.y.z websockets/X.Y"``. |
|
Setting it to :obj:`None` removes the header. |
|
process_request (Callable[[str, Headers], \ |
|
Awaitable[tuple[StatusLike, HeadersLike, bytes] | None]] | None): |
|
Intercept HTTP request before the opening handshake. |
|
See :meth:`~WebSocketServerProtocol.process_request` for details. |
|
select_subprotocol: Select a subprotocol supported by the client. |
|
See :meth:`~WebSocketServerProtocol.select_subprotocol` for details. |
|
open_timeout: Timeout for opening connections in seconds. |
|
:obj:`None` disables the timeout. |
|
|
|
See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the |
|
documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, |
|
``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``. |
|
|
|
Any other keyword arguments are passed the event loop's |
|
:meth:`~asyncio.loop.create_server` method. |
|
|
|
For example: |
|
|
|
* You can set ``ssl`` to a :class:`~ssl.SSLContext` to enable TLS. |
|
|
|
* You can set ``sock`` to a :obj:`~socket.socket` that you created |
|
outside of websockets. |
|
|
|
Returns: |
|
WebSocket server. |
|
|
|
""" |
|
|
|
def __init__(
    self,
    ws_handler: (
        Callable[[WebSocketServerProtocol], Awaitable[Any]]
        | Callable[[WebSocketServerProtocol, str], Awaitable[Any]]
    ),
    host: str | Sequence[str] | None = None,
    port: int | None = None,
    *,
    create_protocol: Callable[..., WebSocketServerProtocol] | None = None,
    logger: LoggerLike | None = None,
    compression: str | None = "deflate",
    origins: Sequence[Origin | None] | None = None,
    extensions: Sequence[ServerExtensionFactory] | None = None,
    subprotocols: Sequence[Subprotocol] | None = None,
    extra_headers: HeadersLikeOrCallable | None = None,
    server_header: str | None = SERVER,
    process_request: (
        Callable[[str, Headers], Awaitable[HTTPResponse | None]] | None
    ) = None,
    select_subprotocol: (
        Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] | None
    ) = None,
    open_timeout: float | None = 10,
    ping_interval: float | None = 20,
    ping_timeout: float | None = 20,
    close_timeout: float | None = None,
    max_size: int | None = 2**20,
    max_queue: int | None = 2**5,
    read_limit: int = 2**16,
    write_limit: int = 2**16,
    **kwargs: Any,
) -> None:
    # Normalize deprecated arguments, build the per-connection protocol
    # factory, and prepare (but do not start) the listening server. The
    # server is actually created when this object is awaited.

    # Deprecated alias: ``close_timeout`` used to be called ``timeout``.
    timeout: float | None = kwargs.pop("timeout", None)
    if timeout is None:
        timeout = 10
    else:
        warnings.warn("rename timeout to close_timeout", DeprecationWarning)

    # An explicit ``close_timeout`` takes precedence over ``timeout``.
    if close_timeout is None:
        close_timeout = timeout

    # Deprecated alias: ``create_protocol`` used to be called ``klass``.
    klass: type[WebSocketServerProtocol] | None = kwargs.pop("klass", None)
    if klass is None:
        klass = WebSocketServerProtocol
    else:
        warnings.warn("rename klass to create_protocol", DeprecationWarning)

    # An explicit ``create_protocol`` takes precedence over ``klass``.
    if create_protocol is None:
        create_protocol = klass

    # Popped so it is forwarded to the protocol rather than to create_server().
    legacy_recv: bool = kwargs.pop("legacy_recv", False)

    # Deprecated ``loop`` argument. When it is omitted, fall back to
    # asyncio.get_event_loop(); the protocol only receives an explicit loop.
    _loop: asyncio.AbstractEventLoop | None = kwargs.pop("loop", None)
    if _loop is None:
        loop = asyncio.get_event_loop()
    else:
        loop = _loop
        warnings.warn("remove loop argument", DeprecationWarning)

    ws_server = WebSocketServer(logger=logger)

    # Connections are considered secure when an ``ssl`` argument is passed
    # through to the event loop's create_server().
    secure = kwargs.get("ssl") is not None

    if compression == "deflate":
        extensions = enable_server_permessage_deflate(extensions)
    elif compression is not None:
        raise ValueError(f"unsupported compression: {compression}")

    if subprotocols is not None:
        validate_subprotocols(subprotocols)

    # create_protocol can no longer be None here; the cast informs type checkers.
    create_protocol = cast(Callable[..., WebSocketServerProtocol], create_protocol)
    factory = functools.partial(
        create_protocol,
        # Adapt handlers that still accept a second ``path`` argument.
        remove_path_argument(ws_handler),
        ws_server,
        host=host,
        port=port,
        secure=secure,
        open_timeout=open_timeout,
        ping_interval=ping_interval,
        ping_timeout=ping_timeout,
        close_timeout=close_timeout,
        max_size=max_size,
        max_queue=max_queue,
        read_limit=read_limit,
        write_limit=write_limit,
        loop=_loop,
        legacy_recv=legacy_recv,
        origins=origins,
        extensions=extensions,
        subprotocols=subprotocols,
        extra_headers=extra_headers,
        server_header=server_header,
        process_request=process_request,
        select_subprotocol=select_subprotocol,
        logger=logger,
    )

    # ``unix=True`` is set by unix_serve(): listen on a filesystem path
    # instead of a host/port pair.
    if kwargs.pop("unix", False):
        path: str | None = kwargs.pop("path", None)
        # Validate explicitly instead of using ``assert``. An assert is
        # stripped under ``python -O``, which would let host/port slip
        # through silently.
        if host is not None or port is not None:
            raise ValueError(
                "host and port must not be specified when serving on a Unix socket"
            )
        create_server = functools.partial(
            loop.create_unix_server, factory, path, **kwargs
        )
    else:
        create_server = functools.partial(
            loop.create_server, factory, host, port, **kwargs
        )

    # Deferred until the object is awaited; see __await_impl__.
    self._create_server = create_server
    self.ws_server = ws_server
|
|
|
|
|
|
|
async def __aenter__(self) -> WebSocketServer:
    """Start the server when entering ``async with`` and return it."""
    started = await self
    return started
|
|
|
async def __aexit__(
    self,
    exc_type: type[BaseException] | None,
    exc_value: BaseException | None,
    traceback: TracebackType | None,
) -> None:
    """Close the server when leaving ``async with`` and wait until it is closed."""
    server = self.ws_server
    server.close()
    await server.wait_closed()
|
|
|
|
|
|
|
def __await__(self) -> Generator[Any, None, WebSocketServer]:
    # ``await serve(...)`` delegates to a native coroutine that starts the server.
    impl = self.__await_impl__()
    return impl.__await__()

async def __await_impl__(self) -> WebSocketServer:
    # Create the underlying asyncio server now, then attach it to the
    # WebSocketServer wrapper, which is what the caller receives.
    listener = await self._create_server()
    wrapper = self.ws_server
    wrapper.wrap(listener)
    return wrapper

# Reuse the same generator so ``yield from serve(...)`` works in
# generator-based coroutines.
__iter__ = __await__
|
|
|
|
|
# ``serve`` is the public name exported in ``__all__``; it is simply the
# ``Serve`` class, so ``serve(...)`` returns an object that can be awaited
# or used with ``async with`` to start the server.
serve = Serve
|
|
|
|
|
def unix_serve(
    # Handlers that still take a second ``path`` argument are accepted too;
    # serve() adapts them and emits a DeprecationWarning.
    ws_handler: (
        Callable[[WebSocketServerProtocol], Awaitable[Any]]
        | Callable[[WebSocketServerProtocol, str], Awaitable[Any]]
    ),
    path: str | None = None,
    **kwargs: Any,
) -> Serve:
    """
    Start a WebSocket server listening on a Unix socket.

    This function is identical to :func:`serve`, except the ``host`` and
    ``port`` arguments are replaced by ``path``. It is only available on Unix.

    Unrecognized keyword arguments are passed to the event loop's
    :meth:`~asyncio.loop.create_unix_server` method.

    It's useful for deploying a server behind a reverse proxy such as nginx.

    Args:
        path: File system path to the Unix socket.

    """
    # ``unix=True`` makes serve() listen with loop.create_unix_server(path)
    # instead of loop.create_server(host, port).
    return serve(ws_handler, path=path, unix=True, **kwargs)
|
|
|
|
|
def remove_path_argument(
    ws_handler: (
        Callable[[WebSocketServerProtocol], Awaitable[Any]]
        | Callable[[WebSocketServerProtocol, str], Awaitable[Any]]
    ),
) -> Callable[[WebSocketServerProtocol], Awaitable[Any]]:
    """
    Adapt a connection handler to the single-argument calling convention.

    A handler that can be called with one positional argument is returned
    unchanged. A handler that can only be called with two positional
    arguments triggers a :exc:`DeprecationWarning`. It is then wrapped so
    that ``websocket.path`` is supplied as the second argument. A handler
    matching neither shape is returned unchanged.

    """

    def accepts(*args: Any) -> bool:
        # Signature lookup stays inside the try: a TypeError from either
        # inspecting or binding means "doesn't accept these arguments".
        try:
            inspect.signature(ws_handler).bind(*args)
        except TypeError:
            return False
        return True

    if accepts(None) or not accepts(None, ""):
        return cast(
            Callable[[WebSocketServerProtocol], Awaitable[Any]],
            ws_handler,
        )

    warnings.warn("remove second argument of ws_handler", DeprecationWarning)

    two_arg_handler = cast(
        Callable[[WebSocketServerProtocol, str], Awaitable[Any]],
        ws_handler,
    )

    async def _ws_handler(websocket: WebSocketServerProtocol) -> Any:
        return await two_arg_handler(websocket, websocket.path)

    return _ws_handler
|
|