|
"""Router - handle message router (base class). |
|
|
|
Router to manage responses. |
|
|
|
""" |
|
|
|
from __future__ import annotations |
|
|
|
import logging |
|
import threading |
|
from abc import abstractmethod |
|
from typing import TYPE_CHECKING |
|
|
|
from wandb.proto import wandb_internal_pb2 as pb |
|
from wandb.proto import wandb_server_pb2 as spb |
|
from wandb.sdk import mailbox |
|
|
|
if TYPE_CHECKING: |
|
from queue import Queue |
|
|
|
|
|
logger = logging.getLogger("wandb") |
|
|
|
|
|
class MessageRouterClosedError(Exception): |
|
"""Router has been closed.""" |
|
|
|
|
|
class MessageRouter: |
|
_request_queue: Queue[pb.Record] |
|
_response_queue: Queue[pb.Result] |
|
_mailbox: mailbox.Mailbox | None |
|
|
|
def __init__(self, mailbox: mailbox.Mailbox | None = None) -> None: |
|
self._mailbox = mailbox |
|
self._lock = threading.Lock() |
|
|
|
self._join_event = threading.Event() |
|
self._thread = threading.Thread(target=self.message_loop) |
|
self._thread.name = "MsgRouterThr" |
|
self._thread.daemon = True |
|
self._thread.start() |
|
|
|
@abstractmethod |
|
def _read_message(self) -> pb.Result | spb.ServerResponse | None: |
|
raise NotImplementedError |
|
|
|
@abstractmethod |
|
def _send_message(self, record: pb.Record) -> None: |
|
raise NotImplementedError |
|
|
|
def message_loop(self) -> None: |
|
try: |
|
while not self._join_event.is_set(): |
|
try: |
|
msg = self._read_message() |
|
except EOFError: |
|
|
|
|
|
logger.warning("EOFError seen in message_loop") |
|
break |
|
except MessageRouterClosedError as e: |
|
logger.warning("message_loop has been closed", exc_info=e) |
|
break |
|
if not msg: |
|
continue |
|
self._handle_msg_rcv(msg) |
|
|
|
finally: |
|
if self._mailbox: |
|
self._mailbox.close() |
|
|
|
def join(self) -> None: |
|
self._join_event.set() |
|
self._thread.join() |
|
|
|
def _handle_msg_rcv(self, msg: pb.Result | spb.ServerResponse) -> None: |
|
if not self._mailbox: |
|
return |
|
|
|
if isinstance(msg, pb.Result) and msg.control.mailbox_slot: |
|
self._mailbox.deliver( |
|
spb.ServerResponse( |
|
request_id=msg.control.mailbox_slot, |
|
result_communicate=msg, |
|
) |
|
) |
|
elif isinstance(msg, spb.ServerResponse) and msg.request_id: |
|
self._mailbox.deliver(msg) |
|
|