"""Flow Control.

States:
    FORWARDING
    PAUSING

New messages:
    pb.SenderMarkRequest    writer -> sender (empty message)
    pb.StatusReportRequest  sender -> writer (reports current sender progress)
    pb.SenderReadRequest    writer -> sender (requests read of transaction log)

Thresholds:
    Threshold_High_MaxOutstandingData       - When above this, stop sending requests to sender
    Threshold_Mid_StartSendingReadRequests  - When below this, start sending read requests
    Threshold_Low_RestartSendingData        - When below this, start sending normal records

State machine:
    FORWARDING
      -> PAUSING if should_pause
         There is too much work outstanding to the sender thread, after the current request
         let's stop sending data.
    PAUSING
      -> FORWARDING if should_unpause
      -> PAUSING if should_recover
      -> PAUSING if should_quiesce

"""

import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional

from wandb.proto import wandb_internal_pb2 as pb
from wandb.sdk.lib import fsm

from .settings_static import SettingsStatic

if TYPE_CHECKING:
    from wandb.proto.wandb_internal_pb2 import Record

logger = logging.getLogger(__name__)

# By default we will allow 128 MiB of requests in the sender queue
# before falling back to the transaction log.
DEFAULT_THRESHOLD = 128 * 1024 * 1024  # 128 MiB


def _get_request_type(record: "Record") -> Optional[str]:
    """Return the name of the request carried by *record*, if any.

    Returns ``None`` when the record's ``record_type`` oneof is not
    ``"request"``; otherwise returns whatever the ``request_type`` oneof
    of ``record.request`` reports.
    """
    if record.WhichOneof("record_type") != "request":
        return None
    return record.request.WhichOneof("request_type")


def _is_control_record(record: "Record") -> bool:
    """Return the record's ``control.flow_control`` flag."""
    return record.control.flow_control


def _is_local_non_control_record(record: "Record") -> bool:
    """Return whether *record* is flagged ``local`` but not ``flow_control``."""
    return record.control.local and not record.control.flow_control


@dataclass
class StateContext:
    """Offsets shared between the FSM states as they hand over to each other."""

    # Value of `last_written_offset` at the time records were last forwarded
    # (or recovered) -- see `StateShared._update_forwarded_offset`.
    last_forwarded_offset: int = 0
    # Latest `sent_offset` taken from a status_report request.
    last_sent_offset: int = 0
    # Latest non-zero `control.end_offset` seen on an incoming record.
    last_written_offset: int = 0


class FlowControl:
    """Two-state (FORWARDING / PAUSING) flow-control state machine.

    Every record passed to `flow` is fed into an `fsm.FsmWithContext` whose
    states are a `StateForwarding` and a `StatePausing` instance.
    """

    # Quoted so the annotation is not evaluated when the class is created.
    _fsm: "fsm.FsmWithContext[Record, StateContext]"

    def __init__(
        self,
        settings: "SettingsStatic",
        forward_record: Callable[["Record"], None],
        write_record: Callable[["Record"], int],
        pause_marker: Callable[[], None],
        recover_records: Callable[[int, int], None],
        _threshold_bytes_high: int = 0,
        _threshold_bytes_mid: int = 0,
        _threshold_bytes_low: int = 0,
    ) -> None:
        """Build both states and the transition table.

        Args:
            settings: Provides ``x_network_buffer``; when that is falsy,
                `DEFAULT_THRESHOLD` is used instead.
            forward_record: Called with each record that should be forwarded.
            write_record: Accepted for interface compatibility; not used by
                this class.
            pause_marker: Called by the FORWARDING -> PAUSING action.
            recover_records: Called with ``(start, end)`` offsets when the
                PAUSING state has written-but-not-forwarded data to recover.
            _threshold_bytes_high: Pause threshold override.
            _threshold_bytes_mid: Recover threshold override.
            _threshold_bytes_low: Unpause threshold override.

        Raises:
            AssertionError: If the resolved thresholds do not satisfy
                ``high > mid > low``.
        """
        # thresholds to define when to PAUSE, RESTART, FORWARDING
        # NOTE: if *any* override is left at 0, all three are replaced by
        # values derived from the settings (high, high // 2, high // 4).
        if (
            _threshold_bytes_high == 0
            or _threshold_bytes_mid == 0
            or _threshold_bytes_low == 0
        ):
            threshold = settings.x_network_buffer or DEFAULT_THRESHOLD
            _threshold_bytes_high = threshold
            _threshold_bytes_mid = threshold // 2
            _threshold_bytes_low = threshold // 4
        self._validate_thresholds(
            _threshold_bytes_high, _threshold_bytes_mid, _threshold_bytes_low
        )

        # FSM definition
        state_forwarding = StateForwarding(
            forward_record=forward_record,
            pause_marker=pause_marker,
            threshold_pause=_threshold_bytes_high,
        )
        state_pausing = StatePausing(
            forward_record=forward_record,
            recover_records=recover_records,
            threshold_recover=_threshold_bytes_mid,
            threshold_forward=_threshold_bytes_low,
        )
        # Each entry is (condition, target state, action).
        # NOTE(review): entries are presumably tried in the listed order by
        # the fsm module -- confirm against wandb.sdk.lib.fsm.
        self._fsm = fsm.FsmWithContext(
            states=[state_forwarding, state_pausing],
            table={
                StateForwarding: [
                    fsm.FsmEntry(
                        state_forwarding._should_pause,
                        StatePausing,
                        state_forwarding._pause,
                    ),
                ],
                StatePausing: [
                    fsm.FsmEntry(
                        state_pausing._should_unpause,
                        StateForwarding,
                        state_pausing._unpause,
                    ),
                    fsm.FsmEntry(
                        state_pausing._should_recover,
                        StatePausing,
                        state_pausing._recover,
                    ),
                    fsm.FsmEntry(
                        state_pausing._should_quiesce,
                        StatePausing,
                        state_pausing._quiesce,
                    ),
                ],
            },
        )

    @staticmethod
    def _validate_thresholds(high: int, mid: int, low: int) -> None:
        """Raise unless ``high > mid > low``.

        This is an explicit check rather than an ``assert`` statement so it
        still runs under ``python -O``. ``AssertionError`` is kept as the
        exception type to match what the previous ``assert`` raised.
        """
        if not (high > mid > low):
            raise AssertionError(
                "flow control thresholds must satisfy high > mid > low, "
                f"got high={high}, mid={mid}, low={low}"
            )

    def flush(self) -> None:
        """Do nothing (placeholder)."""
        # TODO(mempressure): what do we do here, how do we make sure we dont have work in pause state
        pass

    def flow(self, record: "Record") -> None:
        """Feed *record* into the state machine."""
        self._fsm.input(record)


class StateShared:
    """Offset bookkeeping and request dispatch common to both states."""

    _context: StateContext

    def __init__(self) -> None:
        self._context = StateContext()

    def _update_written_offset(self, record: "Record") -> None:
        """Record the record's ``control.end_offset`` when it is non-zero."""
        end_offset = record.control.end_offset
        if end_offset:
            self._context.last_written_offset = end_offset

    def _update_forwarded_offset(self) -> None:
        """Mark everything written so far as forwarded."""
        self._context.last_forwarded_offset = self._context.last_written_offset

    def _process(self, record: "Record") -> None:
        """Dispatch a request record to ``_process_<request_type>`` if defined.

        Records that are not requests, and requests without a matching
        handler method, are ignored.
        """
        request_type = _get_request_type(record)
        if not request_type:
            return
        process_str = f"_process_{request_type}"
        process_handler: Optional[Callable[[pb.Record], None]] = getattr(
            self, process_str, None
        )
        if not process_handler:
            return
        process_handler(record)

    def _process_status_report(self, record: "Record") -> None:
        """Store the ``sent_offset`` reported by a status_report request."""
        sent_offset = record.request.status_report.sent_offset
        self._context.last_sent_offset = sent_offset

    def on_exit(self, record: "Record") -> StateContext:
        """Hand this state's context out when the state is left."""
        return self._context

    def on_enter(self, record: "Record", context: StateContext) -> None:
        """Adopt the context handed over when the state is entered."""
        self._context = context

    @property
    def _behind_bytes(self) -> int:
        """Difference between the last forwarded and last sent offsets."""
        return self._context.last_forwarded_offset - self._context.last_sent_offset


class StateForwarding(StateShared):
    """FORWARDING state: pass non-control records straight through."""

    _forward_record: Callable[["Record"], None]
    _pause_marker: Callable[[], None]
    _threshold_pause: int

    def __init__(
        self,
        forward_record: Callable[["Record"], None],
        pause_marker: Callable[[], None],
        threshold_pause: int,
    ) -> None:
        super().__init__()
        self._forward_record = forward_record
        self._pause_marker = pause_marker
        self._threshold_pause = threshold_pause

    def _should_pause(self, record: "Record") -> bool:
        """Return True once the backlog reaches the pause threshold."""
        return self._behind_bytes >= self._threshold_pause

    def _pause(self, record: "Record") -> None:
        """Transition action: invoke the pause marker callback."""
        self._pause_marker()

    def on_check(self, record: "Record") -> None:
        """Track offsets, handle requests, and forward non-control records."""
        self._update_written_offset(record)
        self._process(record)
        if not _is_control_record(record):
            self._forward_record(record)
        # The forwarded offset advances even when the record was a control
        # record and therefore not forwarded.
        self._update_forwarded_offset()


class StatePausing(StateShared):
    """PAUSING state: stop forwarding and recover by offset range instead."""

    _forward_record: Callable[["Record"], None]
    _recover_records: Callable[[int, int], None]
    _threshold_recover: int
    _threshold_forward: int

    def __init__(
        self,
        forward_record: Callable[["Record"], None],
        recover_records: Callable[[int, int], None],
        threshold_recover: int,
        threshold_forward: int,
    ) -> None:
        super().__init__()
        self._forward_record = forward_record
        self._recover_records = recover_records
        self._threshold_recover = threshold_recover
        self._threshold_forward = threshold_forward

    def _should_unpause(self, record: "Record") -> bool:
        """Return True when the backlog is below the forward threshold."""
        return self._behind_bytes < self._threshold_forward

    def _unpause(self, record: "Record") -> None:
        """Transition action for PAUSING -> FORWARDING: quiesce."""
        self._quiesce(record)

    def _should_recover(self, record: "Record") -> bool:
        """Return True when the backlog is below the recover threshold."""
        return self._behind_bytes < self._threshold_recover

    def _recover(self, record: "Record") -> None:
        """Transition action for the recover self-transition: quiesce."""
        self._quiesce(record)

    def _should_quiesce(self, record: "Record") -> bool:
        """Return True for local records that are not flow-control records."""
        return _is_local_non_control_record(record)

    def _quiesce(self, record: "Record") -> None:
        """Recover the unforwarded range, then forward a local record.

        Requests recovery of ``[last_forwarded_offset, last_written_offset]``
        when the two differ, forwards *record* itself only if it is a local
        non-control record, and finally marks everything written as
        forwarded.
        """
        start = self._context.last_forwarded_offset
        end = self._context.last_written_offset
        if start != end:
            self._recover_records(start, end)
        if _is_local_non_control_record(record):
            self._forward_record(record)
        self._update_forwarded_offset()

    def on_check(self, record: "Record") -> None:
        """Track offsets and handle requests; nothing is forwarded here."""
        self._update_written_offset(record)
        self._process(record)