File size: 6,601 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
from __future__ import annotations
import http
import ssl as ssl_module
import urllib.parse
from typing import Any, Awaitable, Callable, Literal
from werkzeug.exceptions import NotFound
from werkzeug.routing import Map, RequestRedirect
from ..http11 import Request, Response
from .server import Server, ServerConnection, serve
__all__ = ["route", "unix_route", "Router"]
class Router:
    """
    WebSocket router supporting :func:`route`.

    It matches the request target of each incoming handshake against a
    :class:`werkzeug.routing.Map`. It then records the matched endpoint and
    the captured URL parameters on the connection, so :meth:`handler` can
    dispatch to them.

    Args:
        url_map: Mapping of URL patterns to connection handlers.
        server_name: Name of the server as seen by clients. If :obj:`None`,
            the value of the request's ``Host`` header is used.
        url_scheme: Scheme used when binding the URL map; :func:`route`
            passes ``"ws"`` or ``"wss"``.

    """

    def __init__(
        self,
        url_map: Map,
        server_name: str | None = None,
        url_scheme: str = "ws",
    ) -> None:
        self.url_map = url_map
        self.server_name = server_name
        self.url_scheme = url_scheme
        # Flag every rule as a WebSocket rule so users don't need to pass
        # ``websocket=True`` to each Rule. Only rules already in the map at
        # construction time are updated.
        for rule in self.url_map.iter_rules():
            rule.websocket = True

    def get_server_name(self, connection: ServerConnection, request: Request) -> str:
        """
        Return the server name used to bind the URL map.

        This is the configured ``server_name`` when set. Otherwise it is the
        request's ``Host`` header. NOTE(review): a missing ``Host`` header
        presumably raises a lookup error from the headers mapping — confirm.

        """
        if self.server_name is None:
            return request.headers["Host"]
        else:
            return self.server_name

    def redirect(self, connection: ServerConnection, url: str) -> Response:
        """Build a ``302 Found`` response whose ``Location`` header is ``url``."""
        response = connection.respond(http.HTTPStatus.FOUND, f"Found at {url}")
        response.headers["Location"] = url
        return response

    def not_found(self, connection: ServerConnection) -> Response:
        """Build a ``404 Not Found`` response."""
        return connection.respond(http.HTTPStatus.NOT_FOUND, "Not Found")

    @staticmethod
    def _split_target(target: str) -> tuple[str, str]:
        """
        Split an HTTP request target into ``(path, query)``.

        :func:`urllib.parse.urlparse` is not used for origin-form targets
        (``/path?query``). It treats a leading ``//`` as a network location,
        reporting ``/b`` as the path of ``//a/b``. It also splits ``;params``
        off the last segment, reporting ``/a/b`` as the path of ``/a/b;c``.

        """
        if target.startswith("/"):
            # Origin-form. Drop a fragment if a client sends one, then split
            # off the query; this mirrors urlsplit's fragment/query order.
            path, _, query = target.partition("#")[0].partition("?")
            return path, query
        # Absolute-form (e.g. ``ws://host/path?query``). urlsplit, unlike
        # urlparse, keeps ``;params`` in the path.
        parsed = urllib.parse.urlsplit(target)
        return parsed.path, parsed.query

    def route_request(
        self, connection: ServerConnection, request: Request
    ) -> Response | None:
        """
        Route incoming request.

        Returns a redirect response when the URL map asks for one, or a
        "Not Found" response when no rule matches. Otherwise, stores the
        matched endpoint and its keyword arguments on ``connection`` and
        returns :obj:`None` to let the handshake proceed.

        """
        url_map_adapter = self.url_map.bind(
            server_name=self.get_server_name(connection, request),
            url_scheme=self.url_scheme,
        )
        path_info, query_args = self._split_target(request.path)
        try:
            handler, kwargs = url_map_adapter.match(
                path_info=path_info,
                query_args=query_args,
            )
        except RequestRedirect as redirect:
            return self.redirect(connection, redirect.new_url)
        except NotFound:
            return self.not_found(connection)
        # Stash the routing result; handler() reads it once the connection opens.
        connection.handler, connection.handler_kwargs = handler, kwargs
        return None

    async def handler(self, connection: ServerConnection) -> None:
        """Handle a connection with the endpoint chosen by :meth:`route_request`."""
        return await connection.handler(connection, **connection.handler_kwargs)
def route(
    url_map: Map,
    *args: Any,
    server_name: str | None = None,
    ssl: ssl_module.SSLContext | Literal[True] | None = None,
    create_router: type[Router] | None = None,
    **kwargs: Any,
) -> Awaitable[Server]:
    """
    Create a WebSocket server dispatching connections to different handlers.

    This feature requires the third-party library `werkzeug`_:

    .. code-block:: console

        $ pip install werkzeug

    .. _werkzeug: https://werkzeug.palletsprojects.com/

    :func:`route` accepts the same arguments as
    :func:`~websockets.asyncio.server.serve`, except as described below.

    The first argument is a :class:`werkzeug.routing.Map` that maps URL patterns
    to connection handlers. In addition to the connection, handlers receive
    parameters captured in the URL as keyword arguments.

    Here's an example::

        import asyncio

        from websockets.asyncio.router import route
        from werkzeug.routing import Map, Rule

        async def channel_handler(websocket, channel_id):
            ...

        url_map = Map([
            Rule("/channel/<uuid:channel_id>", endpoint=channel_handler),
            ...
        ])

        # set this future to exit the server
        stop = asyncio.get_running_loop().create_future()

        async with route(url_map, ...) as server:
            await stop

    Refer to the documentation of :mod:`werkzeug.routing` for details.

    If you define redirects with ``Rule(..., redirect_to=...)`` in the URL map,
    you need additional configuration when the server runs behind a reverse
    proxy that modifies the ``Host`` header or terminates TLS:

    * Set ``server_name`` to the name of the server as seen by clients. When not
      provided, websockets uses the value of the ``Host`` header.
    * Set ``ssl=True`` to generate ``wss://`` URIs without actually enabling
      TLS. Under the hood, this binds the URL map with a ``url_scheme`` of
      ``wss`` instead of ``ws``.

    There is no need to specify ``websocket=True`` in each rule. It is added
    automatically.

    If you also pass ``process_request``, it runs before routing. When it
    returns a response, that response is sent and routing is skipped.
    When it returns :obj:`None`, the request is routed through the URL map.
    ``process_request`` may be a plain function or a coroutine function.

    Args:
        url_map: Mapping of URL patterns to connection handlers.
        server_name: Name of the server as seen by clients. If :obj:`None`,
            websockets uses the value of the ``Host`` header.
        ssl: Configuration for enabling TLS on the connection. Set it to
            :obj:`True` if a reverse proxy terminates TLS connections.
        create_router: Factory for the :class:`Router` dispatching requests to
            handlers. Set it to a wrapper or a subclass to customize routing.

    """
    url_scheme = "ws" if ssl is None else "wss"
    # ssl=True only switches the scheme used for redirects; only a real
    # SSLContext is forwarded to serve() to enable TLS.
    if ssl is not True and ssl is not None:
        kwargs["ssl"] = ssl

    if create_router is None:
        create_router = Router

    router = create_router(url_map, server_name, url_scheme)

    _process_request: (
        Callable[
            [ServerConnection, Request],
            Awaitable[Response | None] | Response | None,
        ]
        | None
    ) = kwargs.pop("process_request", None)
    if _process_request is None:
        process_request: Callable[
            [ServerConnection, Request],
            Awaitable[Response | None] | Response | None,
        ] = router.route_request
    else:

        async def process_request(
            connection: ServerConnection, request: Request
        ) -> Response | None:
            # Run the user's hook first; it may be sync or async.
            response = _process_request(connection, request)
            if isinstance(response, Awaitable):
                response = await response
            # A response from the hook short-circuits routing.
            if response is not None:
                return response
            return router.route_request(connection, request)

    return serve(router.handler, *args, process_request=process_request, **kwargs)
def unix_route(
    url_map: Map,
    path: str | None = None,
    **kwargs: Any,
) -> Awaitable[Server]:
    """
    Create a WebSocket Unix server dispatching connections to different handlers.

    :func:`unix_route` combines the behaviors of :func:`route` and
    :func:`~websockets.asyncio.server.unix_serve`.

    Args:
        url_map: Mapping of URL patterns to connection handlers.
        path: File system path to the Unix socket.

    """
    # Listen on a Unix socket instead of TCP; everything else goes to route().
    unix_options = {"unix": True, "path": path}
    return route(url_map, **unix_options, **kwargs)
|