"""Writer thread."""

import logging
from typing import TYPE_CHECKING, Callable, Optional

from wandb.proto import wandb_internal_pb2 as pb
from wandb.proto import wandb_telemetry_pb2 as tpb

from ..interface.interface_queue import InterfaceQueue
from ..lib import proto_util, telemetry
from . import context, datastore, flow_control
from .settings_static import SettingsStatic

# Queue is imported for type checkers only; the code visible in this module
# refers to it exclusively inside quoted annotations.
if TYPE_CHECKING:
    from queue import Queue


# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
class WriteManager:
    """Persist records to the local datastore and route them to the sender.

    Records enter through :meth:`write`, which dispatches on the populated
    ``record_type`` field to a ``write_<record_type>`` method when one exists
    and to :meth:`_write` otherwise.  :meth:`_write` appends the record to the
    datastore (unless it is marked ``control.local``) and then either hands it
    to the flow-control object or puts it directly on the sender queue.
    """

    # Collaborators handed in by the caller.
    _settings: SettingsStatic
    _record_q: "Queue[pb.Record]"
    _result_q: "Queue[pb.Result]"
    _sender_q: "Queue[pb.Record]"
    _interface: InterfaceQueue
    _context_keeper: context.ContextKeeper

    # State owned by the writer.  ``_ds`` and ``_flow_control`` stay ``None``
    # until open() has run.
    _ds: Optional[datastore.DataStore]
    _flow_control: Optional[flow_control.FlowControl]
    _status_report: Optional["pb.StatusReportRequest"]
    _record_num: int
    _telemetry_obj: tpb.TelemetryRecord
    _telemetry_overflow: bool
    _use_flow_control: bool

    def __init__(
        self,
        settings: SettingsStatic,
        record_q: "Queue[pb.Record]",
        result_q: "Queue[pb.Result]",
        sender_q: "Queue[pb.Record]",
        interface: InterfaceQueue,
        context_keeper: context.ContextKeeper,
    ) -> None:
        """Store the collaborators and reset the writer state.

        Args:
            settings: Static settings.  This class reads
                ``x_flow_control_disabled``, ``_offline`` and ``sync_file``.
            record_q: Stored on the instance; no method of this class reads it.
            result_q: Queue that ``_respond_result`` puts results on.
            sender_q: Queue that ``_forward_record`` puts records on.
            interface: Used to build request records via ``_make_request``.
            context_keeper: Receives every forwarded record
                (``add_from_record``) and cancel requests (``cancel``).
        """
        self._settings = settings
        self._record_q = record_q
        self._result_q = result_q
        self._sender_q = sender_q
        self._interface = interface
        self._context_keeper = context_keeper

        self._ds = None
        self._flow_control = None
        self._status_report = None
        self._record_num = 0
        self._telemetry_obj = tpb.TelemetryRecord()
        self._telemetry_overflow = False
        # Flow control is used only when it is not disabled in the settings
        # and the run is not offline.
        self._use_flow_control = not (
            self._settings.x_flow_control_disabled or self._settings._offline
        )

    def open(self) -> None:
        """Open the datastore for writing and create the flow-control object.

        The datastore is opened at ``settings.sync_file``.  The flow-control
        object is wired to this instance's ``_write_record``,
        ``_forward_record``, ``_pause_marker`` and ``_recover_records``
        callbacks.
        """
        self._ds = datastore.DataStore()
        self._ds.open_for_write(self._settings.sync_file)
        self._flow_control = flow_control.FlowControl(
            settings=self._settings,
            write_record=self._write_record,
            forward_record=self._forward_record,
            pause_marker=self._pause_marker,
            recover_records=self._recover_records,
        )

    def _forward_record(self, record: "pb.Record") -> None:
        """Register the record with the context keeper, then queue it for the sender."""
        self._context_keeper.add_from_record(record)
        self._sender_q.put(record)

    def _send_mark(self) -> None:
        """Build a ``SenderMarkRequest`` record and forward it to the sender."""
        sender_mark = pb.SenderMarkRequest()
        record = self._interface._make_request(sender_mark=sender_mark)
        self._forward_record(record)

    def _maybe_send_telemetry(self) -> None:
        """Forward a telemetry record flagging flow-control overflow, at most once.

        The first call sets ``feature.flow_control_overflow`` on the writer's
        telemetry object and forwards it; every later call returns immediately.
        """
        if self._telemetry_overflow:
            return
        self._telemetry_overflow = True
        with telemetry.context(obj=self._telemetry_obj) as tel:
            tel.feature.flow_control_overflow = True
        telemetry_record = pb.TelemetryRecordRequest(telemetry=self._telemetry_obj)
        record = self._interface._make_request(telemetry_record=telemetry_record)
        self._forward_record(record)

    def _pause_marker(self) -> None:
        """Flow-control ``pause_marker`` callback.

        Reports the overflow through telemetry (first call only) and then
        forwards a sender mark.
        """
        self._maybe_send_telemetry()
        self._send_mark()

    def _write_record(self, record: "pb.Record") -> int:
        """Number the record, append it to the datastore and return its end offset.

        Args:
            record: Record to persist.  It is mutated: its record number is
                assigned before the write and its end offset after it.

        Returns:
            The end offset reported by the datastore for this write.

        Raises:
            AssertionError: If the datastore has not been opened, or if the
                datastore write returns ``None``.
        """
        assert self._ds

        self._record_num += 1
        proto_util._assign_record_num(record, self._record_num)
        ret = self._ds.write(record)
        assert ret is not None

        # Only the end offset is needed; the start and flush offsets are ignored.
        _start_offset, end_offset, _flush_offset = ret
        proto_util._assign_end_offset(record, end_offset)
        return end_offset

    def _ensure_flushed(self, offset: int) -> None:
        """Ask the datastore, if one is open, to ensure it is flushed up to ``offset``."""
        if self._ds:
            self._ds.ensure_flushed(offset)

    def _recover_records(self, start: int, end: int) -> None:
        """Flow-control ``recover_records`` callback.

        Forwards a ``SenderReadRequest`` covering ``start`` to ``end``.

        Args:
            start: Value placed in the request's ``start_offset``.
            end: Value placed in the request's ``final_offset``.
        """
        sender_read = pb.SenderReadRequest(start_offset=start, final_offset=end)
        record = self._interface._make_request(sender_read=sender_read)
        # Flush up to ``end`` before the request is forwarded.
        # NOTE(review): presumably so the sender can read the range back from
        # the file; confirm against the sender's SenderReadRequest handling.
        self._ensure_flushed(end)
        self._forward_record(record)

    def _write(self, record: "pb.Record") -> None:
        """Persist a record and route it towards the sender.

        The datastore and flow-control object are created on first use.
        Records marked ``control.local`` are not written to the datastore.
        """
        if not self._ds:
            self.open()
        assert self._flow_control

        if not record.control.local:
            self._write_record(record)

        if self._use_flow_control:
            self._flow_control.flow(record)
        elif not self._settings._offline or record.control.always_send:
            # Flow control is off: forward directly when online; when offline
            # only records flagged ``always_send`` are forwarded.
            self._forward_record(record)

    def write(self, record: "pb.Record") -> None:
        """Dispatch a record to its handler.

        A method named ``write_<record_type>`` is used when this object has
        one; otherwise the record goes to :meth:`_write`.

        Raises:
            AssertionError: If the record has no ``record_type`` set.
        """
        record_type = record.WhichOneof("record_type")
        assert record_type
        writer_str = "write_" + record_type
        write_handler: Callable[[pb.Record], None] = getattr(
            self, writer_str, self._write
        )
        write_handler(record)

    def write_request(self, record: "pb.Record") -> None:
        """Dispatch a request record to its handler.

        A method named ``write_request_<request_type>`` is used when this
        object has one; otherwise the record goes to :meth:`_write`.

        Raises:
            AssertionError: If the request has no ``request_type`` set.
        """
        request_type = record.request.WhichOneof("request_type")
        assert request_type
        write_request_str = "write_request_" + request_type
        write_request_handler: Optional[Callable[[pb.Record], None]] = getattr(
            self, write_request_str, None
        )
        if write_request_handler:
            return write_request_handler(record)
        self._write(record)

    def write_request_run_status(self, record: "pb.Record") -> None:
        """Answer a run-status request on the result queue.

        When a status report has been stored, the response carries its
        ``sync_time``, the number of records written so far as
        ``sync_items_total``, and that number minus the report's
        ``record_num`` as ``sync_items_pending``.  The record itself is
        neither written to the datastore nor forwarded.
        """
        result = proto_util._result_from_record(record)
        if self._status_report:
            result.response.run_status_response.sync_time.CopyFrom(
                self._status_report.sync_time
            )
            send_record_num = self._status_report.record_num
            result.response.run_status_response.sync_items_total = self._record_num
            result.response.run_status_response.sync_items_pending = (
                self._record_num - send_record_num
            )
        # A result is always sent, with or without status-report details.
        self._respond_result(result)

    def write_request_status_report(self, record: "pb.Record") -> None:
        """Store the latest status report, then process the record like any other."""
        self._status_report = record.request.status_report
        self._write(record)

    def write_request_cancel(self, record: "pb.Record") -> None:
        """Cancel the context named by the request's ``cancel_slot``.

        The record itself is neither written to the datastore nor forwarded.
        """
        cancel_id = record.request.cancel.cancel_slot
        self._context_keeper.cancel(cancel_id)

    def _respond_result(self, result: "pb.Result") -> None:
        """Put a result on the result queue."""
        self._result_q.put(result)

    def finish(self) -> None:
        """Flush flow control and close the datastore.

        Either step is skipped when the corresponding object was never
        created.  The datastore is closed even if the flush raises, so the
        open datastore is not leaked; the exception still propagates.
        """
        try:
            if self._flow_control:
                self._flow_control.flush()
        finally:
            if self._ds:
                self._ds.close()

    def debounce(self) -> None:
        """Do nothing; the writer has no debounce work."""
        pass
|
|