jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
"""Router - handle message router (base class).
Router to manage responses.
"""
from __future__ import annotations
import logging
import threading
from abc import abstractmethod
from typing import TYPE_CHECKING
from wandb.proto import wandb_internal_pb2 as pb
from wandb.proto import wandb_server_pb2 as spb
from wandb.sdk import mailbox
if TYPE_CHECKING:
from queue import Queue
# Uses the shared "wandb" logger rather than a per-module ``__name__`` logger.
logger = logging.getLogger(name="wandb")
class MessageRouterClosedError(Exception):
    """Raised when the router has been closed.

    ``MessageRouter.message_loop`` logs this error and stops reading when a
    subclass's ``_read_message`` raises it.
    """
class MessageRouter:
    """Base class that pumps incoming messages into a mailbox.

    ``__init__`` starts a daemon thread named ``MsgRouterThr`` that runs
    ``message_loop``. That loop repeatedly calls the subclass-provided
    ``_read_message`` and hands every truthy message to ``_handle_msg_rcv``.
    It stops when ``join`` is called, or when ``_read_message`` raises
    ``EOFError`` or ``MessageRouterClosedError``.

    NOTE(review): because the thread starts inside ``__init__``, subclasses
    presumably must set up whatever ``_read_message`` relies on *before*
    calling ``super().__init__()``. Confirm this against the subclasses.
    """

    _request_queue: Queue[pb.Record]
    _response_queue: Queue[pb.Result]
    _mailbox: mailbox.Mailbox | None

    def __init__(self, mailbox: mailbox.Mailbox | None = None) -> None:
        """Remember the mailbox and launch the background reader thread.

        Args:
            mailbox: Where routed responses are delivered. When it is falsy,
                received messages are dropped.
        """
        self._mailbox = mailbox
        # Not used by this base class itself; presumably shared with subclasses.
        self._lock = threading.Lock()
        self._join_event = threading.Event()

        reader = threading.Thread(target=self.message_loop, daemon=True)
        reader.name = "MsgRouterThr"
        self._thread = reader
        reader.start()

    @abstractmethod
    def _read_message(self) -> pb.Result | spb.ServerResponse | None:
        """Return the next incoming message, or a falsy value if there is none.

        Must be implemented by subclasses. Raising ``EOFError`` or
        ``MessageRouterClosedError`` makes ``message_loop`` exit.
        """
        raise NotImplementedError

    @abstractmethod
    def _send_message(self, record: pb.Record) -> None:
        """Send ``record``; the destination is defined by the subclass."""
        raise NotImplementedError

    def message_loop(self) -> None:
        """Read and dispatch messages until asked to stop.

        This is the body of the reader thread. On exit, for any reason,
        the mailbox (if one is set) is closed.
        """
        try:
            while not self._join_event.is_set():
                try:
                    incoming = self._read_message()
                except EOFError:
                    # An abnormal shutdown can destroy the underlying queue
                    # mid-read, which surfaces as EOFError; stop looping.
                    logger.warning("EOFError seen in message_loop")
                    break
                except MessageRouterClosedError as err:
                    logger.warning("message_loop has been closed", exc_info=err)
                    break

                if incoming:
                    self._handle_msg_rcv(incoming)
        finally:
            if self._mailbox:
                self._mailbox.close()

    def join(self) -> None:
        """Signal the reader thread to stop and wait for it to finish.

        The stop flag is checked only between reads. This therefore blocks
        until any in-progress ``_read_message`` call returns.
        """
        self._join_event.set()
        self._thread.join()

    def _handle_msg_rcv(self, msg: pb.Result | spb.ServerResponse) -> None:
        """Deliver ``msg`` to the mailbox when it carries a routing id.

        A ``pb.Result`` is wrapped in a ``spb.ServerResponse`` whose
        ``request_id`` is the result's ``control.mailbox_slot``. A
        ``spb.ServerResponse`` is delivered unchanged. The message is
        dropped if it has no id, or if no mailbox is set.
        """
        box = self._mailbox
        if not box:
            return

        if isinstance(msg, pb.Result):
            slot = msg.control.mailbox_slot
            if slot:
                box.deliver(
                    spb.ServerResponse(
                        request_id=slot,
                        result_communicate=msg,
                    )
                )
        elif isinstance(msg, spb.ServerResponse) and msg.request_id:
            box.deliver(msg)