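"""Router - handle message routing (base class).

Router to manage responses: a background thread reads messages through the
subclass-provided `_read_message` and delivers each response to the mailbox
slot waiting for it.

"""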

from __future__ import annotations

import logging
import threading
from abc import abstractmethod
from typing import TYPE_CHECKING

from wandb.proto import wandb_internal_pb2 as pb
from wandb.proto import wandb_server_pb2 as spb
from wandb.sdk import mailbox

if TYPE_CHECKING:
    from queue import Queue


logger = logging.getLogger("wandb")


class MessageRouterClosedError(Exception):
    """Router has been closed."""
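

# Illustrative sketch (hypothetical helper, not part of the wandb API) of the
# polling pattern that MessageRouter.message_loop depends on. join() only sets
# an event, and the loop checks that event between reads, so a reader that
# blocks forever would make join() hang. The sketch assumes a queue.Queue
# source polled with a short timeout; a timeout yields None, which the loop
# treats as "no message yet" before re-checking the stop event.
import queue
import threading
from typing import Callable, Optional


def _sketch_poll_loop(
    source: "queue.Queue[str]",
    stop: threading.Event,
    handle: Callable[[str], None],
    timeout: float = 0.1,
) -> None:
    """Hand each queued message to `handle` until `stop` is set."""
    while not stop.is_set():
        msg: Optional[str]
        try:
            msg = source.get(timeout=timeout)
        except queue.Empty:
            # Nothing arrived in time; fall through and re-check `stop`.
            msg = None
        if msg is None:
            continue
        handle(msg)


# Design note (sketch only): the timeout bounds how long a stop request can
# go unnoticed; a shorter timeout makes shutdown faster at the cost of more
# wakeups while idle.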


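class MessageRouter:
    """Base class that routes incoming responses to a mailbox.

    Subclasses implement `_read_message` and `_send_message` for a concrete
    transport. The reader thread starts inside `__init__`, so any state that
    `_read_message` relies on should be assigned before calling
    `super().__init__()`. Because this class does not derive from `abc.ABC`,
    the `@abstractmethod` markers are not enforced at instantiation.
    """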
    _request_queue: Queue[pb.Record]
    _response_queue: Queue[pb.Result]
    _mailbox: mailbox.Mailbox | None

    def __init__(self, mailbox: mailbox.Mailbox | None = None) -> None:
        self._mailbox = mailbox
        self._lock = threading.Lock()

        self._join_event = threading.Event()
        self._thread = threading.Thread(target=self.message_loop)
        self._thread.name = "MsgRouterThr"
        self._thread.daemon = True
        self._thread.start()

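    @abstractmethod
    def _read_message(self) -> pb.Result | spb.ServerResponse | None:
        """Return the next incoming message, or None if none is available.

        Implementations should return periodically (for example, None after
        a read timeout) so that `message_loop` can notice `join()`. Raising
        EOFError or MessageRouterClosedError stops the loop.
        """
        raise NotImplementedError

    @abstractmethod
    def _send_message(self, record: pb.Record) -> None:
        """Send `record` over the underlying transport."""
        raise NotImplementedError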

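    def message_loop(self) -> None:
        """Read and dispatch messages until `join()` is called.

        The loop also exits when the transport reports EOF or closure. The
        mailbox, if any, is closed when the loop exits.
        """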
        try:
            while not self._join_event.is_set():
                try:
                    msg = self._read_message()
                except EOFError:
                    # On abnormal shutdown the queue may be destroyed underneath
                    # us, resulting in EOFError; message_loop needs to exit.
                    logger.warning("EOFError seen in message_loop")
                    break
                except MessageRouterClosedError as e:
                    logger.warning("message_loop has been closed", exc_info=e)
                    break
                if not msg:
                    continue
                self._handle_msg_rcv(msg)

        finally:
            if self._mailbox:
                self._mailbox.close()

    def join(self) -> None:
        """Signal the reader thread to stop and wait for it to exit."""
        self._join_event.set()
        self._thread.join()

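    def _handle_msg_rcv(self, msg: pb.Result | spb.ServerResponse) -> None:
        """Deliver `msg` to the mailbox slot it answers.

        A `pb.Result` is wrapped in a `spb.ServerResponse` keyed by its
        `control.mailbox_slot`; a `spb.ServerResponse` is delivered as is.
        Messages without a slot or request id are dropped, as is every
        message when no mailbox is configured.
        """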
        if not self._mailbox:
            return

        if isinstance(msg, pb.Result) and msg.control.mailbox_slot:
            self._mailbox.deliver(
                spb.ServerResponse(
                    request_id=msg.control.mailbox_slot,
                    result_communicate=msg,
                )
            )
        elif isinstance(msg, spb.ServerResponse) and msg.request_id:
            self._mailbox.deliver(msg)
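

# Illustrative sketch (hypothetical, not wandb's Mailbox implementation) of
# the correlation step in MessageRouter._handle_msg_rcv. Each request is
# tagged with an id; when a response arrives, that id selects the slot a
# caller is waiting on, and responses with an empty id are dropped, as the
# `mailbox_slot` / `request_id` checks do. The _SketchSlot/_SketchMailbox
# names, the "req-N" id format, and the string payloads are assumptions.
import threading
from typing import Dict, Optional, Tuple


class _SketchSlot:
    """One-shot container a caller can block on until a response arrives."""

    def __init__(self) -> None:
        self._event = threading.Event()
        self._value: Optional[str] = None

    def put(self, value: str) -> None:
        self._value = value
        self._event.set()

    def wait(self, timeout: float) -> Optional[str]:
        # Returns None if no response was delivered within `timeout` seconds.
        return self._value if self._event.wait(timeout) else None


class _SketchMailbox:
    """Maps request ids to slots and routes responses back by id."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._slots: Dict[str, _SketchSlot] = {}
        self._counter = 0

    def reserve(self) -> Tuple[str, _SketchSlot]:
        # Allocate a fresh request id and the slot its response will fill.
        with self._lock:
            self._counter += 1
            request_id = f"req-{self._counter}"
            slot = _SketchSlot()
            self._slots[request_id] = slot
        return request_id, slot

    def deliver(self, request_id: str, payload: str) -> bool:
        # Drop responses with an empty or unknown id; report whether routed.
        if not request_id:
            return False
        with self._lock:
            slot = self._slots.pop(request_id, None)
        if slot is None:
            return False
        slot.put(payload)
        return True


# Design note (sketch only): popping the slot on delivery makes each slot
# one-shot, so a duplicate or late response for the same id is ignored.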