jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
"""Flow Control.
States:
FORWARDING
PAUSING
New messages:
pb.SenderMarkRequest writer -> sender (empty message)
pb.StatusReportRequest sender -> writer (reports current sender progress)
pb.SenderReadRequest writer -> sender (requests read of transaction log)
Thresholds:
Threshold_High_MaxOutstandingData - When above this, stop sending requests to sender
Threshold_Mid_StartSendingReadRequests - When below this, start sending read requests
Threshold_Low_RestartSendingData - When below this, start sending normal records
State machine:
FORWARDING
-> PAUSING if should_pause
There is too much work outstanding to the sender thread; after the current request
let's stop sending data.
PAUSING
-> FORWARDING if should_unpause
-> PAUSING if should_recover
-> PAUSING if should_quiesce
"""
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional
from wandb.proto import wandb_internal_pb2 as pb
from wandb.sdk.lib import fsm
from .settings_static import SettingsStatic
if TYPE_CHECKING:
from wandb.proto.wandb_internal_pb2 import Record
logger = logging.getLogger(__name__)
# By default we will allow 128 MiB of requests in the sender queue
# before falling back to the transaction log.
DEFAULT_THRESHOLD = 128 * 1024 * 1024 # 128 MiB
def _get_request_type(record: "Record") -> Optional[str]:
    """Return the name of the request carried by ``record``, if any.

    Args:
        record: Record whose ``record_type`` oneof is inspected.

    Returns:
        The value of the ``request_type`` oneof on ``record.request`` when the
        record's ``record_type`` oneof is ``"request"``; ``None`` otherwise.
    """
    # Only request records carry a nested ``request_type`` oneof worth reading.
    if record.WhichOneof("record_type") == "request":
        return record.request.WhichOneof("request_type")
    return None
def _is_control_record(record: "Record") -> bool:
    """Return the ``flow_control`` flag stored on ``record.control``."""
    control = record.control
    return control.flow_control
def _is_local_non_control_record(record: "Record") -> bool:
    """Tell whether ``record`` is flagged local but not flagged flow-control.

    Both flags are read from ``record.control``.
    """
    control = record.control
    return control.local and not control.flow_control
@dataclass
class StateContext:
    """Offsets shared between the flow-control states.

    All three offsets start at zero.
    """

    # Copied from ``last_written_offset`` whenever a state marks its
    # forwarding as caught up.
    last_forwarded_offset: int = 0
    # Taken from the ``sent_offset`` of status report requests.
    last_sent_offset: int = 0
    # Taken from a record's ``control.end_offset`` when that value is non-zero.
    last_written_offset: int = 0
class FlowControl:
    """Feed records through a two-state (forwarding / pausing) state machine.

    The constructor wires a ``StateForwarding`` and a ``StatePausing`` instance
    into an ``fsm.FsmWithContext``; ``flow`` hands each record to that machine.
    """

    _fsm: fsm.FsmWithContext["Record", StateContext]

    def __init__(
        self,
        settings: SettingsStatic,
        forward_record: Callable[["Record"], None],
        write_record: Callable[["Record"], int],
        pause_marker: Callable[[], None],
        recover_records: Callable[[int, int], None],
        _threshold_bytes_high: int = 0,
        _threshold_bytes_mid: int = 0,
        _threshold_bytes_low: int = 0,
    ) -> None:
        """Build the states and the transition table.

        Args:
            settings: Source of ``x_network_buffer``, used as the base
                threshold when the explicit thresholds are not all provided.
            forward_record: Callback given to both states for forwarding a
                record.
            write_record: Accepted for interface compatibility; not used by
                this constructor.
            pause_marker: Callback invoked by the forwarding state's pause
                action.
            recover_records: Callback given to the pausing state, called with
                a ``(start, end)`` offset pair.
            _threshold_bytes_high: Pause threshold for the forwarding state.
            _threshold_bytes_mid: Recover threshold for the pausing state.
            _threshold_bytes_low: Forward (unpause) threshold for the pausing
                state.

        Raises:
            AssertionError: If the thresholds are not strictly decreasing
                (high > mid > low).
        """
        # thresholds to define when to PAUSE, RESTART, FORWARDING
        # If any explicit threshold is left at 0, all three are derived from a
        # single base value (settings override, else the module default).
        if 0 in (_threshold_bytes_high, _threshold_bytes_mid, _threshold_bytes_low):
            base = settings.x_network_buffer or DEFAULT_THRESHOLD
            _threshold_bytes_high = base
            _threshold_bytes_mid = base // 2
            _threshold_bytes_low = base // 4
        assert _threshold_bytes_high > _threshold_bytes_mid > _threshold_bytes_low

        # FSM definition
        forwarding = StateForwarding(
            forward_record=forward_record,
            pause_marker=pause_marker,
            threshold_pause=_threshold_bytes_high,
        )
        pausing = StatePausing(
            forward_record=forward_record,
            recover_records=recover_records,
            threshold_recover=_threshold_bytes_mid,
            threshold_forward=_threshold_bytes_low,
        )

        # Each entry is (condition, target state class, action).
        from_forwarding = [
            fsm.FsmEntry(forwarding._should_pause, StatePausing, forwarding._pause),
        ]
        # NOTE(review): entries are listed unpause, recover, quiesce; presumably
        # the fsm module evaluates them in list order — confirm against fsm.
        from_pausing = [
            fsm.FsmEntry(pausing._should_unpause, StateForwarding, pausing._unpause),
            fsm.FsmEntry(pausing._should_recover, StatePausing, pausing._recover),
            fsm.FsmEntry(pausing._should_quiesce, StatePausing, pausing._quiesce),
        ]

        self._fsm = fsm.FsmWithContext(
            states=[forwarding, pausing],
            table={
                StateForwarding: from_forwarding,
                StatePausing: from_pausing,
            },
        )

    def flush(self) -> None:
        """Do nothing (placeholder; see TODO)."""
        # TODO(mempressure): what do we do here, how do we make sure we dont have work in pause state
        pass

    def flow(self, record: "Record") -> None:
        """Submit ``record`` as the next input to the state machine."""
        self._fsm.input(record)
class StateShared:
    """Bookkeeping shared by the flow-control states.

    Holds a ``StateContext`` and provides helpers to update its offsets from
    incoming records.
    """

    _context: StateContext

    def __init__(self) -> None:
        """Start with a fresh, zeroed ``StateContext``."""
        self._context = StateContext()

    def _update_written_offset(self, record: "Record") -> None:
        """Record ``record.control.end_offset`` as last written, if non-zero."""
        offset = record.control.end_offset
        if not offset:
            return
        self._context.last_written_offset = offset

    def _update_forwarded_offset(self) -> None:
        """Mark everything written so far as forwarded."""
        ctx = self._context
        ctx.last_forwarded_offset = ctx.last_written_offset

    def _process(self, record: "Record") -> None:
        """Dispatch a request record to ``_process_<request_type>`` if defined.

        Records that are not requests, and request types without a matching
        handler method on this object, are ignored.
        """
        request_type = _get_request_type(record)
        if not request_type:
            return
        handler = getattr(self, f"_process_{request_type}", None)
        if handler:
            handler(record)

    def _process_status_report(self, record: "Record") -> None:
        """Store the ``sent_offset`` reported by a status report request."""
        self._context.last_sent_offset = record.request.status_report.sent_offset

    def on_exit(self, record: "Record") -> StateContext:
        """Return this state's context.

        NOTE(review): presumably called by the fsm module when leaving the
        state — confirm against fsm.
        """
        return self._context

    def on_enter(self, record: "Record", context: StateContext) -> None:
        """Adopt ``context`` as this state's context.

        NOTE(review): presumably called by the fsm module when entering the
        state — confirm against fsm.
        """
        self._context = context

    @property
    def _behind_bytes(self) -> int:
        """Last forwarded offset minus last sent offset."""
        ctx = self._context
        return ctx.last_forwarded_offset - ctx.last_sent_offset
class StateForwarding(StateShared):
    """State in which non-control records are forwarded as they arrive."""

    _forward_record: Callable[["Record"], None]
    _pause_marker: Callable[[], None]
    _threshold_pause: int

    def __init__(
        self,
        forward_record: Callable[["Record"], None],
        pause_marker: Callable[[], None],
        threshold_pause: int,
    ) -> None:
        """Store the callbacks and the pause threshold.

        Args:
            forward_record: Called with each non-control record in
                ``on_check``.
            pause_marker: Called by ``_pause``.
            threshold_pause: ``_should_pause`` is true once ``_behind_bytes``
                reaches this value.
        """
        super().__init__()
        self._threshold_pause = threshold_pause
        self._pause_marker = pause_marker
        self._forward_record = forward_record

    def _should_pause(self, record: "Record") -> bool:
        """Return True when the backlog has reached the pause threshold."""
        backlog = self._behind_bytes
        return backlog >= self._threshold_pause

    def _pause(self, record: "Record") -> None:
        """Invoke the pause-marker callback; ``record`` is not used."""
        self._pause_marker()

    def on_check(self, record: "Record") -> None:
        """Update offsets for ``record`` and forward it unless it is a control record."""
        self._update_written_offset(record)
        self._process(record)
        is_control = _is_control_record(record)
        if not is_control:
            self._forward_record(record)
        # Done for control records too: forwarded catches up to written.
        self._update_forwarded_offset()
class StatePausing(StateShared):
    """State in which incoming records are tracked but not forwarded on check.

    Forwarding happens only through ``_quiesce`` (directly, or via the
    ``_unpause`` / ``_recover`` actions).
    """

    _forward_record: Callable[["Record"], None]
    _recover_records: Callable[[int, int], None]
    _threshold_recover: int
    _threshold_forward: int

    def __init__(
        self,
        forward_record: Callable[["Record"], None],
        recover_records: Callable[[int, int], None],
        threshold_recover: int,
        threshold_forward: int,
    ) -> None:
        """Store the callbacks and thresholds.

        Args:
            forward_record: Called by ``_quiesce`` for local non-control
                records.
            recover_records: Called by ``_quiesce`` with ``(start, end)``
                offsets when forwarded and written offsets differ.
            threshold_recover: ``_should_recover`` is true while
                ``_behind_bytes`` is below this value.
            threshold_forward: ``_should_unpause`` is true while
                ``_behind_bytes`` is below this value.
        """
        super().__init__()
        self._threshold_forward = threshold_forward
        self._threshold_recover = threshold_recover
        self._recover_records = recover_records
        self._forward_record = forward_record

    def _should_unpause(self, record: "Record") -> bool:
        """Return True when the backlog is below the forward threshold."""
        backlog = self._behind_bytes
        return backlog < self._threshold_forward

    def _unpause(self, record: "Record") -> None:
        """Action for the unpause transition: delegate to ``_quiesce``."""
        self._quiesce(record)

    def _should_recover(self, record: "Record") -> bool:
        """Return True when the backlog is below the recover threshold."""
        backlog = self._behind_bytes
        return backlog < self._threshold_recover

    def _recover(self, record: "Record") -> None:
        """Action for the recover transition: delegate to ``_quiesce``."""
        self._quiesce(record)

    def _should_quiesce(self, record: "Record") -> bool:
        """Return True for records that are local and not flow-control."""
        return _is_local_non_control_record(record)

    def _quiesce(self, record: "Record") -> None:
        """Recover the not-yet-forwarded range, then catch up the offsets.

        Requests recovery of ``[last_forwarded_offset, last_written_offset]``
        when the two differ, forwards ``record`` itself if it is a local
        non-control record, and finally marks everything written as forwarded.
        """
        ctx = self._context
        recover_from = ctx.last_forwarded_offset
        recover_to = ctx.last_written_offset
        if recover_from != recover_to:
            self._recover_records(recover_from, recover_to)
        if _is_local_non_control_record(record):
            self._forward_record(record)
        self._update_forwarded_offset()

    def on_check(self, record: "Record") -> None:
        """Update offsets for ``record`` without forwarding it."""
        self._update_written_offset(record)
        self._process(record)