|
from __future__ import annotations |
|
|
|
import io |
|
import math |
|
import sys |
|
import typing |
|
import warnings |
|
|
|
import anyio |
|
from anyio.abc import ObjectReceiveStream, ObjectSendStream |
|
|
|
from starlette.types import Receive, Scope, Send |
|
|
|
# Emitted once at import time: merely importing this module counts as using
# the deprecated middleware.
warnings.warn(
    message=(
        "starlette.middleware.wsgi is deprecated and will be removed in a future release. "
        "Please refer to https://github.com/abersheeran/a2wsgi as a replacement."
    ),
    category=DeprecationWarning,
)
|
|
|
|
|
def build_environ(scope: Scope, body: bytes) -> dict[str, typing.Any]:
    """
    Builds a scope and request body into a WSGI environ object.

    Args:
        scope: The ASGI HTTP connection scope. ``method``, ``path``,
            ``query_string`` and ``http_version`` are read unconditionally;
            ``root_path``, ``scheme``, ``server``, ``client`` and ``headers``
            are optional.
        body: The complete request body, exposed to the WSGI application as
            ``wsgi.input``.

    Returns:
        A new dict containing the CGI-style variables, the ``wsgi.*`` keys and
        one entry per distinct request header.
    """
    # WSGI carries paths as "bytes-as-unicode" strings: the UTF-8 bytes of the
    # path, decoded as latin-1.
    script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
    path_info = scope["path"].encode("utf8").decode("latin1")
    # PATH_INFO is relative to the mount point, so drop the SCRIPT_NAME prefix.
    if path_info.startswith(script_name):
        path_info = path_info[len(script_name) :]

    environ: dict[str, typing.Any] = {
        "REQUEST_METHOD": scope["method"],
        "SCRIPT_NAME": script_name,
        "PATH_INFO": path_info,
        "QUERY_STRING": scope["query_string"].decode("ascii"),
        "SERVER_PROTOCOL": f"HTTP/{scope['http_version']}",
        "wsgi.version": (1, 0),
        "wsgi.url_scheme": scope.get("scheme", "http"),
        "wsgi.input": io.BytesIO(body),
        "wsgi.errors": sys.stdout,
        "wsgi.multithread": True,
        "wsgi.multiprocess": True,
        "wsgi.run_once": False,
    }

    # Server name and port are required in WSGI but optional in ASGI.
    server_name, server_port = (scope.get("server") or ("localhost", 80))[:2]
    environ["SERVER_NAME"] = server_name
    # CGI variables must be strings (PEP 3333); ASGI supplies the port as an
    # int, or as None when no port applies, in which case "0" is used.
    environ["SERVER_PORT"] = str(server_port or 0)

    # The client address is optional in ASGI.
    if scope.get("client"):
        environ["REMOTE_ADDR"] = scope["client"][0]

    # Translate the request headers into their CGI names.
    for name, value in scope.get("headers", []):
        name = name.decode("latin1")
        if name == "content-length":
            corrected_name = "CONTENT_LENGTH"
        elif name == "content-type":
            corrected_name = "CONTENT_TYPE"
        else:
            corrected_name = f"HTTP_{name}".upper().replace("-", "_")

        # Decode as latin-1 so that any byte value is accepted.
        value = value.decode("latin1")
        # Repeated headers are folded into one comma-separated value.
        if corrected_name in environ:
            value = environ[corrected_name] + "," + value
        environ[corrected_name] = value
    return environ
|
|
|
|
|
class WSGIMiddleware:
    """ASGI application that hands each HTTP request to a wrapped WSGI callable."""

    def __init__(self, app: typing.Callable[..., typing.Any]) -> None:
        """Remember the WSGI callable that will serve the requests."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Serve one request through a fresh ``WSGIResponder``.

        Only ``http`` scopes are accepted; any other scope type trips the
        assertion.
        """
        assert scope["type"] == "http"
        # A responder holds per-request state, so a new one is built each time.
        await WSGIResponder(self.app, scope)(receive, send)
|
|
|
|
|
class WSGIResponder:
    """Runs one WSGI application call and relays its response as ASGI messages.

    The WSGI callable executes in a worker thread. Every ASGI message it
    produces is pushed onto an in-memory object stream, and a task on the event
    loop drains that stream into the ASGI ``send`` callable.
    """

    stream_send: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
    stream_receive: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]

    def __init__(self, app: typing.Callable[..., typing.Any], scope: Scope) -> None:
        """Store the WSGI callable and ASGI scope and create the message stream."""
        self.app = app
        self.scope = scope
        self.status = None
        self.response_headers = None
        # Unbounded buffer: the worker thread never blocks on a slow consumer.
        self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf)
        self.response_started = False
        self.exc_info: typing.Any = None

    async def __call__(self, receive: Receive, send: Send) -> None:
        """Read the whole request body, run the WSGI app and stream its response.

        Raises:
            The exception described by the ``exc_info`` most recently passed to
            ``start_response``, if any, once the application has finished.
        """
        # Collect the chunks and join once instead of repeated bytes "+=".
        chunks: list[bytes] = []
        more_body = True
        while more_body:
            message = await receive()
            chunks.append(message.get("body", b""))
            more_body = message.get("more_body", False)
        environ = build_environ(self.scope, b"".join(chunks))

        async with anyio.create_task_group() as task_group:
            task_group.start_soon(self.sender, send)
            # Closing the send side when the worker thread is done lets the
            # sender task finish, which in turn lets the task group exit.
            async with self.stream_send:
                await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response)
        if self.exc_info is not None:
            # Unbound call: attaches the traceback to the exception instance
            # (exc_info[1]) and returns that instance, which is then raised.
            raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2])

    async def sender(self, send: Send) -> None:
        """Forward every queued message to ``send`` until the stream is closed."""
        async with self.stream_receive:
            async for message in self.stream_receive:
                await send(message)

    def start_response(
        self,
        status: str,
        response_headers: list[tuple[str, str]],
        exc_info: typing.Any = None,
    ) -> None:
        """WSGI ``start_response`` callable; runs inside the worker thread.

        Records ``exc_info`` on every call. Only the first call queues the
        ``http.response.start`` message; later calls leave the already queued
        status and headers untouched.

        Args:
            status: WSGI status string such as ``"200 OK"``.
            response_headers: ``(name, value)`` string pairs.
            exc_info: Optional ``sys.exc_info()`` style triple.
        """
        self.exc_info = exc_info
        if self.response_started:
            return
        self.response_started = True

        # Only the leading code is needed; tolerate a missing reason phrase.
        status_code = int(status.split(" ", 1)[0])
        # WSGI strings are latin-1 text (PEP 3333), so values are encoded as
        # such rather than rejected when they contain non-ASCII characters.
        # Header names stay ASCII and are lower-cased for ASGI.
        headers = [
            (name.strip().encode("ascii").lower(), value.strip().encode("latin-1"))
            for name, value in response_headers
        ]
        anyio.from_thread.run(
            self.stream_send.send,
            {
                "type": "http.response.start",
                "status": status_code,
                "headers": headers,
            },
        )

    def wsgi(
        self,
        environ: dict[str, typing.Any],
        start_response: typing.Callable[..., typing.Any],
    ) -> None:
        """Call the WSGI app in the worker thread and queue its body chunks.

        Each chunk is queued with ``more_body`` set, followed by one empty
        closing body message. The iterable's ``close()`` method, when present,
        is always called, as PEP 3333 requires, even if iteration or sending
        fails.
        """
        iterable = self.app(environ, start_response)
        try:
            for chunk in iterable:
                anyio.from_thread.run(
                    self.stream_send.send,
                    {"type": "http.response.body", "body": chunk, "more_body": True},
                )

            anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""})
        finally:
            close = getattr(iterable, "close", None)
            if close is not None:
                close()
|
|