|
from __future__ import annotations |
|
|
|
import typing |
|
|
|
from starlette._exception_handler import ( |
|
ExceptionHandlers, |
|
StatusHandlers, |
|
wrap_app_handling_exceptions, |
|
) |
|
from starlette.exceptions import HTTPException, WebSocketException |
|
from starlette.requests import Request |
|
from starlette.responses import PlainTextResponse, Response |
|
from starlette.types import ASGIApp, Receive, Scope, Send |
|
from starlette.websockets import WebSocket |
|
|
|
|
|
class ExceptionMiddleware:
    """ASGI middleware that routes exceptions raised by the wrapped app to handlers.

    Handlers live in two registries: one keyed by exception class and one keyed
    by integer status code. For ``"http"`` and ``"websocket"`` scopes both
    registries are published on the scope under
    ``"starlette.exception_handlers"``. The downstream app is then run through
    ``wrap_app_handling_exceptions``, which this class delegates the actual
    catching and dispatch to. Every other scope type is forwarded untouched.
    """

    def __init__(
        self,
        app: ASGIApp,
        handlers: typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] | None = None,
        debug: bool = False,
    ) -> None:
        """Wrap *app* and register the built-in and any caller-supplied handlers.

        Args:
            app: The downstream ASGI application.
            handlers: Optional mapping of exception class or status code to
                handler. Each entry is passed through ``add_exception_handler``,
                so it may replace the built-in ``HTTPException`` /
                ``WebSocketException`` handlers.
            debug: Stored as ``self.debug``; not read anywhere in this class.
        """
        self.app = app
        self.debug = debug
        self._status_handlers: StatusHandlers = {}
        # Built-in defaults go in first so user-supplied entries can override them.
        self._exception_handlers: ExceptionHandlers = {
            HTTPException: self.http_exception,
            WebSocketException: self.websocket_exception,
        }
        if handlers is None:
            return
        for exc_or_status, callback in handlers.items():
            self.add_exception_handler(exc_or_status, callback)

    def add_exception_handler(
        self,
        exc_class_or_status_code: int | type[Exception],
        handler: typing.Callable[[Request, Exception], Response],
    ) -> None:
        """Register *handler* for an exception class or a status code.

        ``int`` keys (including ``bool``, an ``int`` subclass) go into the
        status-code registry. Any other key must be an ``Exception`` subclass
        and goes into the exception-class registry. Re-registering a key
        replaces the previous handler.

        Raises:
            AssertionError: a non-int key is not an ``Exception`` subclass
                (only when assertions are enabled).
            TypeError: raised by ``issubclass`` when a non-int key is not a
                class (only when assertions are enabled).
        """
        if not isinstance(exc_class_or_status_code, int):
            assert issubclass(exc_class_or_status_code, Exception)
            self._exception_handlers[exc_class_or_status_code] = handler
            return
        self._status_handlers[exc_class_or_status_code] = handler

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point: wrap http/websocket connections, forward the rest."""
        scope_kind = scope["type"]
        if scope_kind != "http" and scope_kind != "websocket":
            # e.g. lifespan: no handlers are installed, and the scope is not modified.
            await self.app(scope, receive, send)
            return

        # Publish both registries so wrap_app_handling_exceptions can look them up.
        scope["starlette.exception_handlers"] = (
            self._exception_handlers,
            self._status_handlers,
        )

        connection: Request | WebSocket = (
            Request(scope, receive, send)
            if scope_kind == "http"
            else WebSocket(scope, receive, send)
        )

        guarded_app = wrap_app_handling_exceptions(self.app, connection)
        await guarded_app(scope, receive, send)

    def http_exception(self, request: Request, exc: Exception) -> Response:
        """Default handler for ``HTTPException``.

        Returns a response with no body for status 204 or 304. For any other
        status it returns ``exc.detail`` as plain text. Both branches reuse the
        exception's status code and headers.
        """
        assert isinstance(exc, HTTPException)
        bodiless = exc.status_code in {204, 304}
        if not bodiless:
            return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers)
        return Response(status_code=exc.status_code, headers=exc.headers)

    async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
        """Default handler for ``WebSocketException``: close with its code and reason."""
        assert isinstance(exc, WebSocketException)
        close_code, close_reason = exc.code, exc.reason
        await websocket.close(code=close_code, reason=close_reason)
|
|