|
"""Reader for WebSocket protocol versions 13 and 8.""" |
|
|
|
import asyncio |
|
import builtins |
|
from collections import deque |
|
from typing import Deque, Final, Optional, Set, Tuple, Union |
|
|
|
from ..base_protocol import BaseProtocol |
|
from ..compression_utils import ZLibDecompressor |
|
from ..helpers import _EXC_SENTINEL, set_exception |
|
from ..streams import EofStream |
|
from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask |
|
from .models import ( |
|
WS_DEFLATE_TRAILING, |
|
WebSocketError, |
|
WSCloseCode, |
|
WSMessage, |
|
WSMsgType, |
|
) |
|
|
|
# Integer values of every WSCloseCode member.  _handle_frame rejects close
# codes below 3000 that are not in this set; codes >= 3000 are accepted as-is.
ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode}

# States of the WebSocketReader._feed_data parser.  A frame is read as
# header -> (extended) payload length -> mask key (masked frames only) -> payload.
# Plain ints rather than an enum; presumably so a compiled/typed build of this
# module can treat them as C integers (see ``cython_int`` below) - TODO confirm.
READ_HEADER = 1
READ_PAYLOAD_LENGTH = 2
READ_PAYLOAD_MASK = 3
READ_PAYLOAD = 4

# Message types used when queueing assembled text/binary messages.
WS_MSG_TYPE_BINARY = WSMsgType.BINARY
WS_MSG_TYPE_TEXT = WSMsgType.TEXT

# Raw integer opcodes, unpacked from WSMsgType so they can be compared as ints.
# OP_CODE_NOT_SET is a sentinel meaning "no fragmented message in progress";
# real opcodes come from a 4-bit field (0..15), so -1 can never collide.
OP_CODE_NOT_SET = -1
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
OP_CODE_TEXT = WSMsgType.TEXT.value
OP_CODE_BINARY = WSMsgType.BINARY.value
OP_CODE_CLOSE = WSMsgType.CLOSE.value
OP_CODE_PING = WSMsgType.PING.value
OP_CODE_PONG = WSMsgType.PONG.value

# Pre-built return values of WebSocketReader.feed_data.  The first element is
# True when parsing failed; the second is always empty here.
EMPTY_FRAME_ERROR = (True, b"")
EMPTY_FRAME = (False, b"")

# Tri-state for WebSocketReader._compressed: NOT_SET before any data frame has
# been seen, otherwise whether the current data message had RSV1 set.
COMPRESSED_NOT_SET = -1
COMPRESSED_FALSE = 0
COMPRESSED_TRUE = 1

# Bound ``tuple.__new__``; used to build WSMessage instances directly from a
# tuple of fields.  Assumes WSMessage is a tuple subclass (NamedTuple) - TODO
# confirm against .models.
TUPLE_NEW = tuple.__new__

# Plain ``int`` at runtime.  The name suggests it stands in for a C int when the
# module is compiled with Cython - NOTE(review): confirm against the build.
cython_int = int
|
|
|
|
|
class WebSocketDataQueue:
    """Message buffer sitting between the WebSocket reader and its consumer.

    The reader pushes parsed messages in with :meth:`feed_data`; the consumer
    pulls them out with :meth:`read`.  The queue also applies back-pressure:
    the underlying protocol is paused while more than ``2 * limit`` bytes are
    buffered and resumed once the total falls back below that mark.
    """

    def __init__(
        self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
    ) -> None:
        self._protocol = protocol
        self._loop = loop
        # Flow-control threshold (twice the requested limit) and the running
        # total of buffered ``size`` values it is compared against.
        self._limit = limit * 2
        self._size = 0
        # End-of-stream / failure state.
        self._eof = False
        self._exception: Union[BaseException, None] = None
        # Future a pending read() is waiting on, or None when nobody waits.
        self._waiter: Optional[asyncio.Future[None]] = None
        # FIFO of (message, size) pairs, plus pre-bound accessors for it.
        self._buffer: Deque[Tuple[WSMessage, int]] = deque()
        self._get_buffer = self._buffer.popleft
        self._put_buffer = self._buffer.append

    def is_eof(self) -> bool:
        """Return True once EOF was fed or an exception was set."""
        return self._eof

    def exception(self) -> Optional[BaseException]:
        """Return the stored exception (None if unset or cleared by feed_eof)."""
        return self._exception

    def set_exception(
        self,
        exc: BaseException,
        exc_cause: builtins.BaseException = _EXC_SENTINEL,
    ) -> None:
        """Mark the queue as failed and wake a pending reader with ``exc``."""
        self._eof = True
        self._exception = exc
        waiter = self._waiter
        if waiter is None:
            return
        self._waiter = None
        set_exception(waiter, exc, exc_cause)

    def _release_waiter(self) -> None:
        """Resolve and forget the pending read() future, if there is one."""
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.done():
                waiter.set_result(None)

    def feed_eof(self) -> None:
        """Mark the end of the stream and wake a pending reader.

        The stored exception is dropped too, so once the buffer is drained
        read() raises EofStream rather than a previously set exception.
        """
        self._eof = True
        self._release_waiter()
        self._exception = None

    def feed_data(self, data: "WSMessage", size: "cython_int") -> None:
        """Append ``data`` (accounted as ``size`` bytes) and wake a reader.

        Pauses the protocol when the buffered total exceeds the threshold.
        """
        self._size += size
        self._put_buffer((data, size))
        self._release_waiter()
        over_limit = self._size > self._limit
        if over_limit and not self._protocol._reading_paused:
            self._protocol.pause_reading()

    async def read(self) -> WSMessage:
        """Return the next message, waiting for one if the buffer is empty.

        Raises the stored exception, or EofStream, once the stream has ended
        and the buffer is empty.
        """
        if self._buffer or self._eof:
            return self._read_from_buffer()
        assert not self._waiter
        waiter = self._waiter = self._loop.create_future()
        try:
            await waiter
        except (asyncio.CancelledError, asyncio.TimeoutError):
            self._waiter = None
            raise
        return self._read_from_buffer()

    def _read_from_buffer(self) -> WSMessage:
        """Pop one buffered message, resuming the protocol if we drop below the limit."""
        if not self._buffer:
            if self._exception is not None:
                raise self._exception
            raise EofStream
        message, size = self._get_buffer()
        self._size -= size
        if self._size < self._limit and self._protocol._reading_paused:
            self._protocol.resume_reading()
        return message
|
|
|
|
|
class WebSocketReader:
    """Incremental WebSocket frame parser.

    Raw bytes are pushed in through :meth:`feed_data`.  Frame headers and
    payloads may be split across calls at any byte boundary: incomplete data
    is kept in ``_tail`` / ``_payload_fragments`` until the rest arrives.
    Every complete message (text, binary, close, ping, pong) is handed to
    ``queue.feed_data``.
    """

    def __init__(
        self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
    ) -> None:
        """Create a reader.

        :param queue: destination for parsed messages.
        :param max_msg_size: maximum message size in bytes; ``0`` disables
            the limit.
        :param compress: whether RSV1 (per-message compression) may be set on
            incoming frames; when False such frames are rejected.
        """
        self.queue = queue
        self._max_msg_size = max_msg_size

        # First exception raised while parsing; once set, feed_data stops
        # parsing further input.
        self._exc: Optional[Exception] = None
        # Payload of the fragmented data message assembled so far.
        self._partial = bytearray()
        self._state = READ_HEADER

        # Opcode of the fragmented message in progress (OP_CODE_NOT_SET if none).
        self._opcode: int = OP_CODE_NOT_SET
        # Header fields of the frame currently being read.
        self._frame_fin = False
        self._frame_opcode: int = OP_CODE_NOT_SET
        # Pieces of the current frame's payload received in earlier calls.
        self._payload_fragments: list[bytes] = []
        self._frame_payload_len = 0

        # Unparsed bytes carried over to the next feed_data() call.
        self._tail: bytes = b""
        self._has_mask = False
        self._frame_mask: Optional[bytes] = None
        self._payload_bytes_to_read = 0
        self._payload_len_flag = 0
        # Compression state (COMPRESSED_*) of the current data message.
        self._compressed: int = COMPRESSED_NOT_SET
        # Created lazily on the first compressed message.
        self._decompressobj: Optional[ZLibDecompressor] = None
        self._compress = compress

    def feed_eof(self) -> None:
        """Signal end of input by forwarding EOF to the queue."""
        self.queue.feed_eof()

    def feed_data(
        self, data: Union[bytes, bytearray, memoryview]
    ) -> Tuple[bool, bytes]:
        """Parse ``data`` and push any completed messages to the queue.

        Returns ``(False, b"")`` on success.  If parsing raises, the exception
        is remembered, forwarded to the queue via the ``set_exception`` helper
        and ``(True, b"")`` is returned; every later call then returns
        ``(True, data)`` without parsing anything.
        """
        if type(data) is not bytes:
            data = bytes(data)

        if self._exc is not None:
            return True, data

        try:
            self._feed_data(data)
        except Exception as exc:
            self._exc = exc
            set_exception(self.queue, exc)
            return EMPTY_FRAME_ERROR

        return EMPTY_FRAME

    def _handle_frame(
        self,
        fin: bool,
        opcode: Union[int, cython_int],
        payload: Union[bytes, bytearray],
        compressed: Union[int, cython_int],
    ) -> None:
        """Process one complete (already unmasked) frame.

        Text/binary/continuation frames are accumulated in ``_partial`` until
        a frame with ``fin`` set arrives; the assembled message is then
        inflated if ``compressed`` is truthy, validated and queued.  Close,
        ping and pong frames are queued immediately.

        :param fin: FIN bit of the frame.
        :param opcode: frame opcode.
        :param payload: unmasked frame payload.
        :param compressed: one of the ``COMPRESSED_*`` values.
        :raises WebSocketError: on protocol violations, oversized messages or
            invalid UTF-8 in text/close payloads.
        """
        msg: WSMessage
        if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
            # load text/binary
            if not fin:
                # Non-final fragment: remember the message opcode (only the
                # first fragment carries it) and buffer the payload.
                if opcode != OP_CODE_CONTINUATION:
                    self._opcode = opcode
                self._partial += payload
                # NOTE(review): ``>=`` rejects a message of exactly
                # max_msg_size bytes, while the decompressed check below uses
                # ``>``.
                if self._max_msg_size and len(self._partial) >= self._max_msg_size:
                    raise WebSocketError(
                        WSCloseCode.MESSAGE_TOO_BIG,
                        f"Message size {len(self._partial)} "
                        f"exceeds limit {self._max_msg_size}",
                    )
                return

            has_partial = bool(self._partial)
            if opcode == OP_CODE_CONTINUATION:
                # A final continuation must close a message that was started
                # by a non-fin text/binary frame.
                if self._opcode == OP_CODE_NOT_SET:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Continuation frame for non started message",
                    )
                opcode = self._opcode
                self._opcode = OP_CODE_NOT_SET
            # The previous frame was not final, so a continuation opcode was
            # expected here.
            elif has_partial:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    "The opcode in non-fin frame is expected "
                    f"to be zero, got {opcode!r}",
                )

            assembled_payload: Union[bytes, bytearray]
            if has_partial:
                assembled_payload = self._partial + payload
                self._partial.clear()
            else:
                assembled_payload = payload

            if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
                raise WebSocketError(
                    WSCloseCode.MESSAGE_TOO_BIG,
                    f"Message size {len(assembled_payload)} "
                    f"exceeds limit {self._max_msg_size}",
                )

            # Decompression must only happen once all fragments are received.
            if compressed:
                if not self._decompressobj:
                    self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
                # Ask for one byte more than the limit: if the decompressor can
                # produce it, the message is too big.  This guards against a
                # backend that returns max_length bytes while buffering more.
                payload_merged = self._decompressobj.decompress_sync(
                    assembled_payload + WS_DEFLATE_TRAILING,
                    (
                        self._max_msg_size + 1
                        if self._max_msg_size
                        else self._max_msg_size
                    ),
                )
                if self._max_msg_size and len(payload_merged) > self._max_msg_size:
                    raise WebSocketError(
                        WSCloseCode.MESSAGE_TOO_BIG,
                        f"Decompressed message exceeds size limit {self._max_msg_size}",
                    )
            elif type(assembled_payload) is bytes:
                payload_merged = assembled_payload
            else:
                payload_merged = bytes(assembled_payload)

            if opcode == OP_CODE_TEXT:
                try:
                    text = payload_merged.decode("utf-8")
                except UnicodeDecodeError as exc:
                    raise WebSocketError(
                        WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                    ) from exc

                # TUPLE_NEW builds the WSMessage tuple directly from its
                # fields, skipping the class constructor on this hot path.
                self.queue.feed_data(
                    TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
                    len(payload_merged),
                )
            else:
                self.queue.feed_data(
                    TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),
                    len(payload_merged),
                )
        elif opcode == OP_CODE_CLOSE:
            if len(payload) >= 2:
                close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
                # Codes below 3000 must be ones defined in WSCloseCode.
                if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        f"Invalid close code: {close_code}",
                    )
                try:
                    close_message = payload[2:].decode("utf-8")
                except UnicodeDecodeError as exc:
                    raise WebSocketError(
                        WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                    ) from exc
                msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, close_code, close_message))
            elif payload:
                # A one-byte close payload cannot hold a 2-byte close code.
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    f"Invalid close frame: {fin} {opcode} {payload!r}",
                )
            else:
                msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, ""))

            self.queue.feed_data(msg, 0)
        elif opcode == OP_CODE_PING:
            msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, ""))
            self.queue.feed_data(msg, len(payload))
        elif opcode == OP_CODE_PONG:
            msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, ""))
            self.queue.feed_data(msg, len(payload))
        else:
            raise WebSocketError(
                WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
            )

    def _feed_data(self, data: bytes) -> None:
        """Parse as many frames as possible from ``data``.

        The parser is a state machine: READ_HEADER -> READ_PAYLOAD_LENGTH ->
        READ_PAYLOAD_MASK (masked frames only) -> READ_PAYLOAD.  When input
        runs out mid-frame the current state is kept and unconsumed bytes are
        saved in ``_tail`` for the next call.

        :raises WebSocketError: on protocol violations or oversized messages.
        """
        if self._tail:
            data, self._tail = self._tail + data, b""

        start_pos: int = 0
        data_len = len(data)
        data_cstr = data

        while True:
            # read header
            if self._state == READ_HEADER:
                if data_len - start_pos < 2:
                    break
                first_byte = data_cstr[start_pos]
                second_byte = data_cstr[start_pos + 1]
                start_pos += 2

                fin = (first_byte >> 7) & 1
                rsv1 = (first_byte >> 6) & 1
                rsv2 = (first_byte >> 5) & 1
                rsv3 = (first_byte >> 4) & 1
                opcode = first_byte & 0xF

                # RSV bits MUST be 0 unless an extension defines them.  RSV1
                # is the per-message-compressed bit and is only acceptable
                # when compression is enabled.
                if rsv2 or rsv3 or (rsv1 and not self._compress):
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received frame with non-zero reserved bits",
                    )

                # Control frames (opcode > 0x7) must not be fragmented.
                if opcode > 0x7 and fin == 0:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received fragmented control frame",
                    )

                has_mask = (second_byte >> 7) & 1
                length = second_byte & 0x7F

                # Control frames MUST have a payload length of 125 bytes or less.
                if opcode > 0x7 and length > 125:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Control frame payload cannot be larger than 125 bytes",
                    )

                if opcode > 0x7:
                    # Control frames may be injected between the fragments of
                    # a data message (RFC 6455 5.4), so they must not touch the
                    # compression state of the message in progress.  They must
                    # never have RSV1 set (RFC 7692 6.1).
                    if rsv1:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Received frame with non-zero reserved bits",
                        )
                elif opcode != OP_CODE_CONTINUATION:
                    # The first frame of a data message decides whether the
                    # whole message is compressed.
                    self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
                elif rsv1:
                    # RSV1 is only allowed on the first fragment of a message.
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received frame with non-zero reserved bits",
                    )

                self._frame_fin = bool(fin)
                self._frame_opcode = opcode
                self._has_mask = bool(has_mask)
                self._payload_len_flag = length
                self._state = READ_PAYLOAD_LENGTH

            # read payload length
            if self._state == READ_PAYLOAD_LENGTH:
                len_flag = self._payload_len_flag
                if len_flag == 126:
                    # 16-bit big-endian extended length.
                    if data_len - start_pos < 2:
                        break
                    first_byte = data_cstr[start_pos]
                    second_byte = data_cstr[start_pos + 1]
                    start_pos += 2
                    self._payload_bytes_to_read = first_byte << 8 | second_byte
                elif len_flag > 126:
                    # 8-byte extended length.
                    if data_len - start_pos < 8:
                        break
                    self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
                    start_pos += 8
                else:
                    self._payload_bytes_to_read = len_flag

                # Reject an oversized data frame as soon as its length is known,
                # instead of buffering the whole payload first.  _handle_frame
                # would reject exactly these messages once complete, so this
                # only bounds memory use.
                if (
                    self._max_msg_size
                    and self._frame_opcode
                    in (OP_CODE_CONTINUATION, OP_CODE_TEXT, OP_CODE_BINARY)
                    and len(self._partial) + self._payload_bytes_to_read
                    >= self._max_msg_size
                ):
                    raise WebSocketError(
                        WSCloseCode.MESSAGE_TOO_BIG,
                        f"Message size "
                        f"{len(self._partial) + self._payload_bytes_to_read} "
                        f"exceeds limit {self._max_msg_size}",
                    )

                self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD

            # read payload mask
            if self._state == READ_PAYLOAD_MASK:
                if data_len - start_pos < 4:
                    break
                self._frame_mask = data_cstr[start_pos : start_pos + 4]
                start_pos += 4
                self._state = READ_PAYLOAD

            if self._state == READ_PAYLOAD:
                chunk_len = data_len - start_pos
                if self._payload_bytes_to_read >= chunk_len:
                    f_end_pos = data_len
                    self._payload_bytes_to_read -= chunk_len
                else:
                    f_end_pos = start_pos + self._payload_bytes_to_read
                    self._payload_bytes_to_read = 0

                had_fragments = self._frame_payload_len
                self._frame_payload_len += f_end_pos - start_pos
                f_start_pos = start_pos
                start_pos = f_end_pos

                if self._payload_bytes_to_read != 0:
                    # Frame incomplete: keep what we have for the next call.
                    self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
                    break

                payload: Union[bytes, bytearray]
                if had_fragments:
                    # Part of this frame arrived earlier; join all pieces.
                    self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
                    if self._has_mask:
                        assert self._frame_mask is not None
                        payload_bytearray = bytearray(b"".join(self._payload_fragments))
                        websocket_mask(self._frame_mask, payload_bytearray)
                        payload = payload_bytearray
                    else:
                        payload = b"".join(self._payload_fragments)
                    self._payload_fragments.clear()
                elif self._has_mask:
                    assert self._frame_mask is not None
                    payload_bytearray = data_cstr[f_start_pos:f_end_pos]  # type: ignore[assignment]
                    if type(payload_bytearray) is not bytearray:
                        # Slicing bytes yields bytes in pure Python, so this
                        # conversion always runs here; the check presumably
                        # exists for a compiled build - TODO confirm.
                        payload_bytearray = bytearray(payload_bytearray)
                    # Unmask in place.
                    websocket_mask(self._frame_mask, payload_bytearray)
                    payload = payload_bytearray
                else:
                    payload = data_cstr[f_start_pos:f_end_pos]

                self._handle_frame(
                    self._frame_fin, self._frame_opcode, payload, self._compressed
                )
                self._frame_payload_len = 0
                self._state = READ_HEADER

        # Keep any unconsumed bytes (a partial header/length/mask) for later.
        self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""
|
|