|
from __future__ import annotations |
|
|
|
import typing |
|
|
|
import anyio |
|
|
|
from starlette._utils import collapse_excgroups |
|
from starlette.requests import ClientDisconnect, Request |
|
from starlette.responses import AsyncContentStream, Response |
|
from starlette.types import ASGIApp, Message, Receive, Scope, Send |
|
|
|
# Type of ``call_next``: given the request, returns (awaitably) the downstream response.
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
# Type of a dispatch function: (request, call_next) -> awaitable response.
DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]]
# Generic result type for the ``wrap`` helper used inside BaseHTTPMiddleware.
T = typing.TypeVar("T")
|
|
|
|
|
class _CachedRequest(Request):
    """
    Request handed to the dispatch function, which also feeds the body downstream.

    If the user calls Request.body() from their dispatch function,
    we cache the entire request body in memory and pass that to downstream middlewares.
    If they instead consume Request.stream() to completion, all we do is send an
    empty body so that downstream things don't hang forever.
    If the dispatch function reads neither, ``wrapped_receive`` pulls the body
    through ``self.stream()`` one chunk per call and forwards each chunk.
    """

    def __init__(self, scope: Scope, receive: Receive):
        super().__init__(scope, receive)
        # True once an "http.disconnect" has been returned by wrapped_receive.
        self._wrapped_rcv_disconnected = False
        # True once the final body message (more_body False) has been returned.
        self._wrapped_rcv_consumed = False
        # NOTE(review): not read anywhere in this file as shown; wrapped_receive
        # creates a fresh self.stream() on each call instead. Possibly dead state.
        self._wrapped_rc_stream = self.stream()

    async def wrapped_receive(self) -> Message:
        """Return the next ASGI receive message for the downstream app.

        Depending on what the dispatch function already read, this returns
        one of three things:

        - the body cached by ``body()``, as a single message;
        - an empty final body, if the stream was already consumed;
        - the next chunk read through ``stream()``.

        Once the body has been delivered, only ``http.disconnect`` is returned.

        Raises:
            RuntimeError: If, after the body has been delivered, the
                underlying receive yields anything other than a disconnect.
        """
        # State 1: a disconnect was already delivered downstream.
        if self._wrapped_rcv_disconnected:
            # Repeat it immediately instead of waiting on the underlying
            # receive for another one.
            return {"type": "http.disconnect"}

        # State 2: body fully delivered, disconnect not delivered yet.
        if self._wrapped_rcv_consumed:
            # _is_disconnected comes from the base Request; presumably it is set
            # once that class has already observed a disconnect (TODO confirm).
            # If so, there is no need to wait for another message.
            if self._is_disconnected:
                self._wrapped_rcv_disconnected = True
                return {"type": "http.disconnect"}

            # Otherwise block on the underlying receive; only a disconnect
            # is expected at this point.
            msg = await self.receive()
            if msg["type"] != "http.disconnect":
                # Anything else means the message flow went wrong somewhere.
                raise RuntimeError(f"Unexpected message received: {msg['type']}")
            self._wrapped_rcv_disconnected = True
            return msg

        # State 3: body not yet (fully) delivered downstream.
        if getattr(self, "_body", None) is not None:
            # body() was called: replay the cached body in one final message,
            # even if the client has disconnected since.
            self._wrapped_rcv_consumed = True
            return {
                "type": "http.request",
                "body": self._body,
                "more_body": False,
            }
        elif self._stream_consumed:
            # stream() was read to completion without caching: send an empty
            # final body so downstream apps don't hang waiting for more.
            self._wrapped_rcv_consumed = True
            return {
                "type": "http.request",
                "body": b"",
                "more_body": False,
            }
        else:
            # Nothing read yet (or the stream is only partially read).
            # Pull one chunk through a fresh stream() generator.
            try:
                stream = self.stream()
                chunk = await stream.__anext__()
                # _stream_consumed decides whether this was the last chunk.
                self._wrapped_rcv_consumed = self._stream_consumed
                return {
                    "type": "http.request",
                    "body": chunk,
                    "more_body": not self._stream_consumed,
                }
            except ClientDisconnect:
                # The client went away mid-body: report it downstream.
                self._wrapped_rcv_disconnected = True
                return {"type": "http.disconnect"}
|
|
|
|
|
class BaseHTTPMiddleware:
    """Base class for HTTP middleware written as a single ``dispatch`` coroutine.

    For every ``http`` request, ``dispatch_func(request, call_next)`` is
    awaited and the response it returns is sent to the client. ``call_next``
    runs the wrapped ``app`` in a background task and returns a
    ``_StreamingResponse`` that relays the messages the app sends.
    Scopes of any other type are passed straight through to ``app``.

    Either subclass and override ``dispatch``, or pass a ``dispatch``
    callable to the constructor.
    """

    def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None:
        """
        Args:
            app: The downstream ASGI application being wrapped.
            dispatch: Optional dispatch callable. When omitted, the
                (overridable) ``self.dispatch`` method is used.
        """
        self.app = app
        self.dispatch_func = self.dispatch if dispatch is None else dispatch

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point: run ``dispatch_func`` and send its response.

        Raises:
            Exception: An exception raised by the downstream app is re-raised
                here, unless ``call_next`` already raised it to the dispatch
                function.
            RuntimeError: Raised from ``call_next`` when the app finishes
                without sending ``http.response.start``.
        """
        # Only HTTP requests go through dispatch; other scope types are
        # forwarded to the wrapped app untouched.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # This one request object is handed to dispatch_func. Its
        # wrapped_receive is used both by the downstream app (via
        # receive_or_disconnect) and when sending the final response.
        request = _CachedRequest(scope, receive)
        wrapped_receive = request.wrapped_receive
        # Set once the response returned by dispatch_func has been sent.
        # From then on the downstream app only receives "http.disconnect".
        response_sent = anyio.Event()
        # Exception raised by the downstream app, captured in coro().
        app_exc: Exception | None = None
        # True when call_next already raised app_exc to the dispatch
        # function, so it is not raised a second time at the end.
        exception_already_raised = False

        async def call_next(request: Request) -> Response:
            """Run the downstream app and return its response as a stream."""

            async def receive_or_disconnect() -> Message:
                # Receive callable given to the downstream app.
                if response_sent.is_set():
                    return {"type": "http.disconnect"}

                # Race the next incoming message against response_sent.
                # Whichever finishes first cancels the task group's scope,
                # which also cancels the other.
                async with anyio.create_task_group() as task_group:

                    async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
                        result = await func()
                        task_group.cancel_scope.cancel()
                        return result

                    task_group.start_soon(wrap, response_sent.wait)
                    message = await wrap(wrapped_receive)

                # If the wait won, the receive above was cancelled and
                # `message` was never bound, so this check must come first.
                if response_sent.is_set():
                    return {"type": "http.disconnect"}

                return message

            async def send_no_error(message: Message) -> None:
                # Send callable given to the downstream app. It pushes
                # messages into the memory stream that call_next reads.
                try:
                    await send_stream.send(message)
                except anyio.BrokenResourceError:
                    # recv_stream has been closed (done in __call__ right
                    # after the response is sent); drop the message.
                    return

            async def coro() -> None:
                nonlocal app_exc

                # Closing send_stream when the app returns makes the reader
                # below observe EndOfStream.
                with send_stream:
                    try:
                        await self.app(scope, receive_or_disconnect, send_no_error)
                    except Exception as exc:
                        # Stash it; it is re-raised either by call_next (no
                        # response started) or at the end of __call__.
                        app_exc = exc

            # Runs in the task group opened in __call__, not in this frame.
            task_group.start_soon(coro)

            try:
                message = await recv_stream.receive()
                info = message.get("info", None)
                # An "http.response.debug" message may come before the
                # response start. Keep its info and read the next message.
                if message["type"] == "http.response.debug" and info is not None:
                    message = await recv_stream.receive()
            except anyio.EndOfStream:
                # The app finished without starting a response.
                if app_exc is not None:
                    nonlocal exception_already_raised
                    exception_already_raised = True
                    raise app_exc
                raise RuntimeError("No response returned.")

            assert message["type"] == "http.response.start"

            async def body_stream() -> typing.AsyncGenerator[bytes, None]:
                # Yield non-empty body chunks until a message without
                # more_body arrives, or the stream ends.
                async for message in recv_stream:
                    assert message["type"] == "http.response.body"
                    body = message.get("body", b"")
                    if body:
                        yield body
                    if not message.get("more_body", False):
                        break

            response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info)
            response.raw_headers = message["headers"]
            return response

        # Channel from the downstream app (send_no_error) to call_next.
        streams: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream()
        send_stream, recv_stream = streams
        # NOTE(review): collapse_excgroups() is a project helper; presumably it
        # unwraps exception groups raised by the task group below. Confirm
        # in starlette._utils.
        with recv_stream, send_stream, collapse_excgroups():
            async with anyio.create_task_group() as task_group:
                response = await self.dispatch_func(request, call_next)
                await response(scope, wrapped_receive, send)
                # Response delivered. The downstream app now sees
                # disconnects, and any further sends hit the closed stream
                # and are dropped by send_no_error.
                response_sent.set()
                recv_stream.close()

        # Surface a downstream failure that dispatch_func never saw raised.
        if app_exc is not None and not exception_already_raised:
            raise app_exc

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Handle one request. Subclasses must override this unless ``dispatch`` is passed."""
        raise NotImplementedError()
|
|
|
|
|
class _StreamingResponse(Response):
    """Response that re-emits a downstream app's response to the client.

    It is built from the status and raw headers of a received
    ``http.response.start`` message, plus an async iterator over the body
    chunks that follow it.
    """

    def __init__(
        self,
        content: AsyncContentStream,
        status_code: int = 200,
        headers: typing.Mapping[str, str] | None = None,
        media_type: str | None = None,
        info: typing.Mapping[str, typing.Any] | None = None,
    ) -> None:
        """
        Args:
            content: Async iterable producing the body chunks.
            status_code: Status sent in the ``http.response.start`` message.
            headers: Optional headers, passed to ``init_headers``.
            media_type: Optional media type, stored before ``init_headers`` runs.
            info: Optional payload for a leading ``http.response.debug`` message.
        """
        # All plain attributes are set before init_headers() is called.
        self.status_code = status_code
        self.media_type = media_type
        self.info = info
        self.body_iterator = content
        self.init_headers(headers)
        # No background task by default. If one is assigned later, it is
        # awaited after the body has been sent.
        self.background = None

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Send the optional debug message, the start message, then the body."""
        debug_info = self.info
        if debug_info is not None:
            await send({"type": "http.response.debug", "info": debug_info})

        start = {"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers}
        await send(start)

        # Each produced chunk goes out with more_body=True. The body is then
        # closed with a separate empty terminating message.
        async for piece in self.body_iterator:
            await send({"type": "http.response.body", "body": piece, "more_body": True})
        await send({"type": "http.response.body", "body": b"", "more_body": False})

        task = self.background
        if task:
            await task()
|
|