# File provenance: uploaded via huggingface_hub by jamtur01 (commit 9c6594c, verified).
"""Http related parsers and protocol."""
import asyncio
import sys
from typing import ( # noqa
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Iterable,
List,
NamedTuple,
Optional,
Union,
)
from multidict import CIMultiDict
from .abc import AbstractStreamWriter
from .base_protocol import BaseProtocol
from .client_exceptions import ClientConnectionResetError
from .compression_utils import ZLibCompressor
from .helpers import NO_EXTENSIONS
__all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11")

# Payloads smaller than this are coalesced with b"".join() and sent through a
# single transport.write() call; only larger payloads use writelines().
MIN_PAYLOAD_FOR_WRITELINES = 2048

# True on Python 3.13.0 and 3.13.1 (writelines fixed in 3.13.2).
IS_PY313_BEFORE_313_2 = (3, 13, 0) <= sys.version_info < (3, 13, 2)
# True on any Python older than 3.12.9 (includes all 3.12.x before the fix).
IS_PY_BEFORE_312_9 = sys.version_info < (3, 12, 9)
SKIP_WRITELINES = IS_PY313_BEFORE_313_2 or IS_PY_BEFORE_312_9
# writelines is not safe for use
# on Python 3.12+ until 3.12.9
# on Python 3.13+ until 3.13.2
# and on older versions it not any faster than write
# CVE-2024-12254: https://github.com/python/cpython/pull/127656
class HttpVersion(NamedTuple):
    """HTTP protocol version as an ordered (major, minor) pair.

    Being a NamedTuple, instances compare and sort like plain tuples,
    e.g. ``HttpVersion(1, 0) < HttpVersion(1, 1)``.
    """

    major: int
    minor: int
# Commonly used protocol versions.
HttpVersion10 = HttpVersion(1, 0)
HttpVersion11 = HttpVersion(1, 1)

# Optional async callback awaited with each raw body chunk before it is written.
_T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]
# Optional async callback awaited with the header mapping when headers are serialized.
_T_OnHeadersSent = Optional[Callable[["CIMultiDict[str]"], Awaitable[None]]]
class StreamWriter(AbstractStreamWriter):
    """Write an HTTP message (headers and body) to a transport.

    Serialized headers are buffered (``_headers_buf``) until the first body
    write or EOF so they can be coalesced with the initial payload into a
    single transport write.  Supports chunked transfer encoding and zlib
    compression, and enforces a declared Content-Length when ``length`` is set.
    """

    # Remaining Content-Length budget; None means no length limit is enforced.
    length: Optional[int] = None
    # True once chunked transfer encoding has been enabled.
    chunked: bool = False
    # Set once the message is complete; further EOF calls are no-ops.
    _eof: bool = False
    # Active compressor, or None when compression is disabled.
    _compress: Optional[ZLibCompressor] = None

    def __init__(
        self,
        protocol: BaseProtocol,
        loop: asyncio.AbstractEventLoop,
        on_chunk_sent: _T_OnChunkSent = None,
        on_headers_sent: _T_OnHeadersSent = None,
    ) -> None:
        """Initialize the writer.

        :param protocol: connection protocol that owns the transport.
        :param loop: event loop the connection runs on.
        :param on_chunk_sent: optional coroutine awaited with every raw body
            chunk before it is (possibly compressed and) written.
        :param on_headers_sent: optional coroutine awaited with the header
            mapping when header serialization is requested.
        """
        self._protocol = protocol
        self.loop = loop
        self._on_chunk_sent: _T_OnChunkSent = on_chunk_sent
        self._on_headers_sent: _T_OnHeadersSent = on_headers_sent
        # Serialized header bytes awaiting the first body write (small-payload
        # coalescing); None before write_headers() or once they are sent.
        self._headers_buf: Optional[bytes] = None
        # True once the buffered headers have been written to the transport.
        self._headers_written: bool = False

    @property
    def transport(self) -> Optional[asyncio.Transport]:
        """Underlying transport, or None when the connection is gone."""
        return self._protocol.transport

    @property
    def protocol(self) -> BaseProtocol:
        """The protocol this writer belongs to."""
        return self._protocol

    def enable_chunking(self) -> None:
        """Switch the body to chunked transfer encoding."""
        self.chunked = True

    def enable_compression(
        self, encoding: str = "deflate", strategy: Optional[int] = None
    ) -> None:
        """Enable body compression with the given zlib encoding/strategy."""
        self._compress = ZLibCompressor(encoding=encoding, strategy=strategy)

    def _write(self, chunk: Union[bytes, bytearray, memoryview]) -> None:
        # Low-level single write: update byte accounting, then fail fast if
        # the transport is gone or already closing.
        size = len(chunk)
        self.buffer_size += size
        self.output_size += size
        transport = self._protocol.transport
        if transport is None or transport.is_closing():
            raise ClientConnectionResetError("Cannot write to closing transport")
        transport.write(chunk)

    def _writelines(self, chunks: Iterable[bytes]) -> None:
        # Low-level multi-chunk write.  NOTE: *chunks* is iterated twice
        # (size accounting, then send), so callers must pass a reusable
        # sequence such as a tuple — all call sites in this class do.
        size = 0
        for chunk in chunks:
            size += len(chunk)
        self.buffer_size += size
        self.output_size += size
        transport = self._protocol.transport
        if transport is None or transport.is_closing():
            raise ClientConnectionResetError("Cannot write to closing transport")
        # Small payloads (or interpreters affected by CVE-2024-12254) are
        # joined and sent with a single write(); see SKIP_WRITELINES above.
        if SKIP_WRITELINES or size < MIN_PAYLOAD_FOR_WRITELINES:
            transport.write(b"".join(chunks))
        else:
            transport.writelines(chunks)

    def _write_chunked_payload(
        self, chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
    ) -> None:
        """Write a chunk with proper chunked encoding."""
        # Frame format: <hex size>\r\n<data>\r\n
        chunk_len_pre = f"{len(chunk):x}\r\n".encode("ascii")
        self._writelines((chunk_len_pre, chunk, b"\r\n"))

    def _send_headers_with_payload(
        self,
        chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"],
        is_eof: bool,
    ) -> None:
        """Send buffered headers with payload, coalescing into single write."""
        # Mark headers as written
        self._headers_written = True
        headers_buf = self._headers_buf
        self._headers_buf = None
        if TYPE_CHECKING:
            # Safe because callers (write() and write_eof()) only invoke this method
            # after checking that self._headers_buf is truthy
            assert headers_buf is not None
        if not self.chunked:
            # Non-chunked: coalesce headers with body
            if chunk:
                self._writelines((headers_buf, chunk))
            else:
                self._write(headers_buf)
            return
        # Coalesce headers with chunked data
        if chunk:
            chunk_len_pre = f"{len(chunk):x}\r\n".encode("ascii")
            if is_eof:
                # Final chunk: append the terminating zero-length chunk marker.
                self._writelines((headers_buf, chunk_len_pre, chunk, b"\r\n0\r\n\r\n"))
            else:
                self._writelines((headers_buf, chunk_len_pre, chunk, b"\r\n"))
        elif is_eof:
            # No body at all: headers followed directly by the chunked EOF marker.
            self._writelines((headers_buf, b"0\r\n\r\n"))
        else:
            self._write(headers_buf)

    async def write(
        self,
        chunk: Union[bytes, bytearray, memoryview],
        *,
        drain: bool = True,
        LIMIT: int = 0x10000,
    ) -> None:
        """
        Writes chunk of data to a stream.

        write_eof() indicates end of stream.
        writer can't be used after write_eof() method being called.
        write() return drain future.
        """
        if self._on_chunk_sent is not None:
            await self._on_chunk_sent(chunk)
        if isinstance(chunk, memoryview):
            if chunk.nbytes != len(chunk):
                # just reshape it: flatten to single bytes so len() == nbytes
                chunk = chunk.cast("c")
        if self._compress is not None:
            chunk = await self._compress.compress(chunk)
            if not chunk:
                # Compressor buffered everything internally; nothing to send yet.
                return
        if self.length is not None:
            chunk_len = len(chunk)
            if self.length >= chunk_len:
                self.length = self.length - chunk_len
            else:
                # Truncate to the remaining Content-Length budget; anything
                # beyond the declared length is silently dropped.
                chunk = chunk[: self.length]
                self.length = 0
                if not chunk:
                    return
        # Handle buffered headers for small payload optimization
        if self._headers_buf and not self._headers_written:
            # First body write: coalesce headers and chunk in one transport call.
            self._send_headers_with_payload(chunk, False)
            if drain and self.buffer_size > LIMIT:
                self.buffer_size = 0
                await self.drain()
            return
        if chunk:
            if self.chunked:
                self._write_chunked_payload(chunk)
            else:
                self._write(chunk)
            if drain and self.buffer_size > LIMIT:
                self.buffer_size = 0
                await self.drain()

    async def write_headers(
        self, status_line: str, headers: "CIMultiDict[str]"
    ) -> None:
        """Write headers to the stream.

        The serialized headers are only buffered here; they hit the transport
        on the first body write, send_headers(), set_eof(), or write_eof().
        """
        if self._on_headers_sent is not None:
            await self._on_headers_sent(headers)
        # status + headers
        buf = _serialize_headers(status_line, headers)
        self._headers_written = False
        self._headers_buf = buf

    def send_headers(self) -> None:
        """Force sending buffered headers if not already sent."""
        if not self._headers_buf or self._headers_written:
            return
        self._headers_written = True
        headers_buf = self._headers_buf
        self._headers_buf = None
        if TYPE_CHECKING:
            # Safe because we only enter this block when self._headers_buf is truthy
            assert headers_buf is not None
        self._write(headers_buf)

    def set_eof(self) -> None:
        """Indicate that the message is complete (synchronous EOF, no drain)."""
        if self._eof:
            return
        # If headers haven't been sent yet, send them now
        # This handles the case where there's no body at all
        if self._headers_buf and not self._headers_written:
            self._headers_written = True
            headers_buf = self._headers_buf
            self._headers_buf = None
            if TYPE_CHECKING:
                # Safe because we only enter this block when self._headers_buf is truthy
                assert headers_buf is not None
            # Combine headers and chunked EOF marker in a single write
            if self.chunked:
                self._writelines((headers_buf, b"0\r\n\r\n"))
            else:
                self._write(headers_buf)
        elif self.chunked and self._headers_written:
            # Headers already sent, just send the final chunk marker
            self._write(b"0\r\n\r\n")
        self._eof = True

    async def write_eof(self, chunk: bytes = b"") -> None:
        """Finish the message, optionally writing a final *chunk*.

        Flushes the compressor (when enabled), sends any still-buffered
        headers coalesced with the trailing data, terminates chunked bodies
        with the zero-length chunk marker, and drains the transport.
        """
        if self._eof:
            return
        if chunk and self._on_chunk_sent is not None:
            await self._on_chunk_sent(chunk)
        # Handle body/compression
        if self._compress:
            chunks: List[bytes] = []
            chunks_len = 0
            if chunk and (compressed_chunk := await self._compress.compress(chunk)):
                chunks_len = len(compressed_chunk)
                chunks.append(compressed_chunk)
            # Flushing always produces trailing bytes, so chunks is non-empty.
            flush_chunk = self._compress.flush()
            chunks_len += len(flush_chunk)
            chunks.append(flush_chunk)
            assert chunks_len
            # Send buffered headers with compressed data if not yet sent
            if self._headers_buf and not self._headers_written:
                self._headers_written = True
                headers_buf = self._headers_buf
                self._headers_buf = None
                if self.chunked:
                    # Coalesce headers with compressed chunked data
                    chunk_len_pre = f"{chunks_len:x}\r\n".encode("ascii")
                    self._writelines(
                        (headers_buf, chunk_len_pre, *chunks, b"\r\n0\r\n\r\n")
                    )
                else:
                    # Coalesce headers with compressed data
                    self._writelines((headers_buf, *chunks))
                await self.drain()
                self._eof = True
                return
            # Headers already sent, just write compressed data
            if self.chunked:
                chunk_len_pre = f"{chunks_len:x}\r\n".encode("ascii")
                self._writelines((chunk_len_pre, *chunks, b"\r\n0\r\n\r\n"))
            elif len(chunks) > 1:
                self._writelines(chunks)
            else:
                self._write(chunks[0])
            await self.drain()
            self._eof = True
            return
        # No compression - send buffered headers if not yet sent
        if self._headers_buf and not self._headers_written:
            # Use helper to send headers with payload
            self._send_headers_with_payload(chunk, True)
            await self.drain()
            self._eof = True
            return
        # Handle remaining body
        if self.chunked:
            if chunk:
                # Write final chunk with EOF marker
                self._writelines(
                    (f"{len(chunk):x}\r\n".encode("ascii"), chunk, b"\r\n0\r\n\r\n")
                )
            else:
                self._write(b"0\r\n\r\n")
            await self.drain()
            self._eof = True
            return
        if chunk:
            self._write(chunk)
        await self.drain()
        self._eof = True

    async def drain(self) -> None:
        """Flush the write buffer.

        The intended use is to write::

            await w.write(data)
            await w.drain()
        """
        protocol = self._protocol
        # Only wait when the transport has applied write backpressure.
        if protocol.transport is not None and protocol._paused:
            await protocol._drain_helper()
def _safe_header(string: str) -> str:
if "\r" in string or "\n" in string:
raise ValueError(
"Newline or carriage return detected in headers. "
"Potential header injection attack."
)
return string
def _py_serialize_headers(status_line: str, headers: "CIMultiDict[str]") -> bytes:
    """Serialize *status_line* and *headers* into an HTTP message head.

    Every header name and value is passed through ``_safe_header`` to reject
    CR/LF injection before being rendered as ``name: value`` lines.  The
    result is UTF-8 encoded and ends with the blank line that terminates the
    header block.
    """
    rendered = [
        "{}: {}".format(_safe_header(name), _safe_header(value))
        for name, value in headers.items()
    ]
    message_head = "{}\r\n{}\r\n\r\n".format(status_line, "\r\n".join(rendered))
    return message_head.encode("utf-8")
# Default to the pure-Python serializer; swapped for the C-accelerated
# implementation below when the extension is importable and extensions
# are not disabled.
_serialize_headers = _py_serialize_headers
try:
    import aiohttp._http_writer as _http_writer # type: ignore[import-not-found]
    _c_serialize_headers = _http_writer._serialize_headers
    if not NO_EXTENSIONS:
        _serialize_headers = _c_serialize_headers
except ImportError:
    # C extension not built/installed; keep the pure-Python fallback.
    pass