|
from __future__ import annotations |
|
|
|
import json |
|
import typing |
|
|
|
from starlette import status |
|
from starlette._utils import is_async_callable |
|
from starlette.concurrency import run_in_threadpool |
|
from starlette.exceptions import HTTPException |
|
from starlette.requests import Request |
|
from starlette.responses import PlainTextResponse, Response |
|
from starlette.types import Message, Receive, Scope, Send |
|
from starlette.websockets import WebSocket |
|
|
|
|
|
class HTTPEndpoint:
    """Class-based ASGI endpoint that routes an HTTP request to a per-verb handler.

    Subclasses define lower-case methods named after HTTP verbs (``get``,
    ``post``, ...). Each handler is called with the ``Request`` and must return
    a ``Response``; handlers may be ``async def`` or plain functions (plain
    functions are run via ``run_in_threadpool``). A request whose method has no
    handler is answered by ``method_not_allowed`` (405).
    """

    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        assert scope["type"] == "http"
        self.scope = scope
        self.receive = receive
        self.send = send
        # Standard verbs this endpoint has a (non-None) handler for; used to
        # build the ``Allow`` header of a 405 response.
        self._allowed_methods = [
            method
            for method in ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS")
            if getattr(self, method.lower(), None) is not None
        ]

    def __await__(self) -> typing.Generator[typing.Any, None, None]:
        return self.dispatch().__await__()

    async def dispatch(self) -> None:
        """Call the handler matching the request method and send its response."""
        request = Request(self.scope, receive=self.receive)
        # HEAD falls back to the GET handler when no explicit ``head`` exists.
        handler_name = "get" if request.method == "HEAD" and not hasattr(self, "head") else request.method.lower()

        handler = self._get_handler(handler_name)
        if is_async_callable(handler):
            response = await handler(request)
        else:
            response = await run_in_threadpool(handler, request)
        await response(self.scope, self.receive, self.send)

    def _get_handler(self, handler_name: str) -> typing.Callable[[Request], typing.Any]:
        """Return the request handler called ``handler_name``, or ``method_not_allowed``.

        ``handler_name`` is derived from the client-supplied request method, so it
        must never resolve to private/dunder attributes, the endpoint's own
        plumbing (``dispatch``, ``method_not_allowed``, ...), or the ASGI state
        stored on the instance. Such names, missing handlers and non-callable
        attributes (e.g. a handler set to ``None``) all yield a 405 instead of
        crashing with a ``TypeError``.
        """
        if (
            handler_name.startswith("_")
            or hasattr(HTTPEndpoint, handler_name)
            or handler_name in ("scope", "receive", "send")
        ):
            return self.method_not_allowed
        handler = getattr(self, handler_name, None)
        if not callable(handler):
            return self.method_not_allowed
        return handler

    async def method_not_allowed(self, request: Request) -> Response:
        """Produce a 405 response listing the allowed methods in the ``Allow`` header.

        Raises:
            HTTPException: when ``"app"`` is present in the scope.
        """
        headers = {"Allow": ", ".join(self._allowed_methods)}
        # Inside an application (``"app"`` in scope) raise instead of responding,
        # so the surrounding app can turn it into a response (e.g. via its
        # exception handlers). A bare ASGI mount just gets a plain-text 405.
        if "app" in self.scope:
            raise HTTPException(status_code=405, headers=headers)
        return PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
|
|
|
|
|
class WebSocketEndpoint:
    """Class-based ASGI endpoint for a WebSocket connection.

    Subclasses override ``on_connect``, ``on_receive`` and ``on_disconnect``.
    ``encoding`` selects how incoming messages are decoded before being passed
    to ``on_receive``: ``"text"``, ``"bytes"``, ``"json"``, or ``None`` for the
    raw ``str``/``bytes`` payload.
    """

    encoding: str | None = None

    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        assert scope["type"] == "websocket"
        self.scope = scope
        self.receive = receive
        self.send = send

    def __await__(self) -> typing.Generator[typing.Any, None, None]:
        return self.dispatch().__await__()

    async def dispatch(self) -> None:
        """Run ``on_connect``, feed each received message to ``on_receive`` until
        a disconnect arrives, then always call ``on_disconnect`` with the close code.
        """
        websocket = WebSocket(self.scope, receive=self.receive, send=self.send)
        await self.on_connect(websocket)

        close_code = status.WS_1000_NORMAL_CLOSURE

        try:
            while True:
                message = await websocket.receive()
                if message["type"] == "websocket.receive":
                    data = await self.decode(websocket, message)
                    await self.on_receive(websocket, data)
                elif message["type"] == "websocket.disconnect":
                    close_code = int(message.get("code") or status.WS_1000_NORMAL_CLOSURE)
                    break
        except Exception:
            # Report an internal error to on_disconnect, then propagate the
            # original exception with its traceback intact.
            close_code = status.WS_1011_INTERNAL_ERROR
            raise
        finally:
            await self.on_disconnect(websocket, close_code)

    async def decode(self, websocket: WebSocket, message: Message) -> typing.Any:
        """Decode a ``websocket.receive`` message according to ``self.encoding``.

        Per the ASGI spec a missing ``text``/``bytes`` key is equivalent to
        ``None``, so presence is tested with ``is None`` rather than key
        membership or truthiness (an empty text frame ``""`` is valid data).

        Raises:
            RuntimeError: the payload does not match the configured encoding or
                is not valid UTF-8 JSON; the socket is closed with 1003 first.
        """
        text = message.get("text")

        if self.encoding == "text":
            if text is None:
                await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
                raise RuntimeError("Expected text websocket messages, but got bytes")
            return text

        if self.encoding == "bytes":
            data = message.get("bytes")
            if data is None:
                await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
                raise RuntimeError("Expected bytes websocket messages, but got text")
            return data

        if self.encoding == "json":
            try:
                if text is None:
                    # Binary frames must carry UTF-8 JSON; a decode failure is
                    # treated as malformed data, just like a JSON syntax error.
                    text = message["bytes"].decode("utf-8")
                return json.loads(text)
            except (json.JSONDecodeError, UnicodeDecodeError) as exc:
                await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
                raise RuntimeError("Malformed JSON data received.") from exc

        assert self.encoding is None, f"Unsupported 'encoding' attribute {self.encoding}"
        return text if text is not None else message["bytes"]

    async def on_connect(self, websocket: WebSocket) -> None:
        """Override to handle an incoming websocket connection"""
        await websocket.accept()

    async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None:
        """Override to handle an incoming websocket message"""

    async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
        """Override to handle a disconnecting websocket"""
|
|