File size: 9,802 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
import queue
import socket
import threading
import time
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional

import wandb
from wandb.proto import wandb_internal_pb2 as pb
from wandb.proto import wandb_server_pb2 as spb
from wandb.sdk.internal.settings_static import SettingsStatic

from ..lib.sock_client import SockClient, SockClientClosedError
from .streams import StreamMux

if TYPE_CHECKING:
    from threading import Event

    from ..interface.interface_relay import InterfaceRelay


class ClientDict:
    """Lock-guarded registry of connected socket clients, keyed by socket id."""

    # Maps a client's ``_sockid`` to the client itself; only touched under _lock.
    _client_dict: Dict[str, SockClient]
    # Serializes every read/write of _client_dict across server threads.
    _lock: threading.Lock

    def __init__(self) -> None:
        self._client_dict = {}
        self._lock = threading.Lock()

    def get_client(self, client_id: str) -> Optional[SockClient]:
        """Return the client registered under ``client_id``, or None if absent."""
        with self._lock:
            return self._client_dict.get(client_id)

    def add_client(self, client: SockClient) -> None:
        """Register ``client`` under its socket id, replacing any earlier entry."""
        with self._lock:
            key = client._sockid
            self._client_dict[key] = client

    def del_client(self, client: SockClient) -> None:
        """Unregister ``client``.

        Raises:
            KeyError: if no client is registered under ``client._sockid``.
        """
        with self._lock:
            self._client_dict.pop(client._sockid)


class SockServerInterfaceReaderThread(threading.Thread):
    """Relay results from a stream interface back to the requesting socket clients.

    Results are taken from ``iface.relay_q``. Each one names the socket that
    originated it in ``control.relay_id``, which is looked up in the shared
    ``ClientDict`` to find where the wrapped ``ServerResponse`` should go.
    """

    _socket_client: SockClient
    _stopped: "Event"

    def __init__(
        self,
        clients: ClientDict,
        iface: "InterfaceRelay",
        stopped: "Event",
    ) -> None:
        self._iface = iface
        self._clients = clients
        super().__init__()
        self.name = "SockSrvIntRdThr"
        self._stopped = stopped

    def run(self) -> None:
        """Forward relayed results until ``stopped`` is set or the queue goes away."""
        while not self._stopped.is_set():
            try:
                # Bounded wait so the stop event is re-checked about once a second.
                relayed = self._iface.relay_q.get(timeout=1)
            except queue.Empty:
                continue
            except (OSError, ValueError):
                # OSError: the queue's handle is closed.
                # ValueError: the queue itself is closed.
                break

            origin_id = relayed.control.relay_id
            assert origin_id
            destination = self._clients.get_client(origin_id)
            assert destination

            # Reuse the mailbox slot as request id so the client can match the reply.
            reply = spb.ServerResponse()
            reply.request_id = relayed.control.mailbox_slot
            reply.result_communicate.CopyFrom(relayed)
            destination.send_server_response(reply)


class SockServerReadThread(threading.Thread):
    """Serve the requests arriving on a single accepted client connection.

    Each ``spb.ServerRequest`` read from the socket is dispatched to the method
    named ``server_<request type>`` on this class (e.g. ``server_record_publish``).
    The loop ends when the mux's stop event is set or the socket is closed.
    """

    _sock_client: SockClient
    _mux: StreamMux
    _stopped: "Event"
    _clients: ClientDict

    def __init__(
        self, conn: socket.socket, mux: StreamMux, clients: ClientDict
    ) -> None:
        """Wrap ``conn`` in a SockClient and share the mux's stop event.

        Args:
            conn: Connection obtained from ``socket.accept()``.
            mux: Stream multiplexer that requests are forwarded to.
            clients: Registry of connected clients (the accept thread passes the
                same instance to every connection) used to route results back.
        """
        self._mux = mux
        threading.Thread.__init__(self)
        self.name = "SockSrvRdThr"
        sock_client = SockClient()
        sock_client.set_socket(conn)
        self._sock_client = sock_client
        self._stopped = mux._get_stopped_event()
        self._clients = clients

    def run(self) -> None:
        """Read and dispatch requests until stopped or the socket is closed."""
        while not self._stopped.is_set():
            try:
                sreq = self._sock_client.read_server_request()
            except SockClientClosedError:
                # socket has been closed
                # TODO: shut down other threads serving this socket?
                break
            assert sreq, "read_server_request should never timeout"
            sreq_type = sreq.WhichOneof("server_request_type")
            # WhichOneof returns None when no request field is set. Formatting
            # (rather than concatenating) turns that case into a failed
            # "unknown handle" assertion instead of an opaque TypeError.
            shandler_str = f"server_{sreq_type}"
            shandler: Optional[Callable[[spb.ServerRequest], None]] = getattr(
                self, shandler_str, None
            )
            assert shandler, f"unknown handle: {shandler_str}"
            shandler(sreq)

    def stop(self) -> None:
        """Shut down and close this connection's socket.

        An OSError from ``shutdown`` is ignored (e.g. the peer may already be
        gone); the socket is closed regardless.
        """
        try:
            # See shutdown notes in class SocketServer for a discussion about this mechanism
            self._sock_client.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass
        self._sock_client.close()

    def server_inform_init(self, sreq: "spb.ServerRequest") -> None:
        """Add a stream to the mux and start relaying its results to clients.

        Registers this connection in the client registry and starts a new
        ``SockServerInterfaceReaderThread`` on the stream's interface; that
        thread shares this thread's stop event.
        """
        request = sreq.inform_init
        stream_id = request._info.stream_id
        settings = SettingsStatic(request.settings)
        self._mux.add_stream(stream_id, settings=settings)

        iface = self._mux.get_stream(stream_id).interface
        self._clients.add_client(self._sock_client)
        iface_reader_thread = SockServerInterfaceReaderThread(
            clients=self._clients,
            iface=iface,
            stopped=self._stopped,
        )
        iface_reader_thread.start()

    def server_inform_start(self, sreq: "spb.ServerRequest") -> None:
        """Update the stream's settings in the mux, then start the stream."""
        request = sreq.inform_start
        stream_id = request._info.stream_id
        settings = SettingsStatic(request.settings)
        self._mux.update_stream(stream_id, settings=settings)
        self._mux.start_stream(stream_id)

    def server_inform_attach(self, sreq: "spb.ServerRequest") -> None:
        """Attach this connection to an existing stream and reply with its settings.

        Raises:
            KeyError: presumably, when no stream is registered under the id —
                the same condition ``_put_record`` catches from ``get_stream``.
        """
        request = sreq.inform_attach
        stream_id = request._info.stream_id

        self._clients.add_client(self._sock_client)
        # Use the mux accessor, like the other handlers, rather than reaching
        # into its private ``_streams`` dict.
        stream = self._mux.get_stream(stream_id)
        inform_attach_response = spb.ServerInformAttachResponse()
        inform_attach_response.settings.CopyFrom(
            stream._settings._proto,
        )
        response = spb.ServerResponse(
            request_id=sreq.request_id,
            inform_attach_response=inform_attach_response,
        )
        self._sock_client.send_server_response(response)

    def server_record_communicate(self, sreq: "spb.ServerRequest") -> None:
        """Forward a record whose result is relayed back to this socket."""
        self._put_record(sreq.record_communicate)

    def server_record_publish(self, sreq: "spb.ServerRequest") -> None:
        """Forward a fire-and-forget record to its stream."""
        self._put_record(sreq.record_publish)

    def _put_record(self, record: "pb.Record") -> None:
        """Tag ``record`` with this socket's id and enqueue it on its stream.

        Records addressed to an unknown stream are silently dropped.
        """
        # encode relay information so the right socket picks up the data
        record.control.relay_id = self._sock_client._sockid
        stream_id = record._info.stream_id

        try:
            iface = self._mux.get_stream(stream_id).interface

        except KeyError:
            # We should log the error but cannot because it may print to console
            # due to how logging is set up. This error usually happens if
            # a record is sent when no run is active, but during this time the
            # logger prints to the console.
            pass

        else:
            assert iface.record_q
            iface.record_q.put(record)

    def server_inform_finish(self, sreq: "spb.ServerRequest") -> None:
        """Ask the mux to drop the stream named in the request."""
        request = sreq.inform_finish
        stream_id = request._info.stream_id
        self._mux.drop_stream(stream_id)

    def server_inform_teardown(self, sreq: "spb.ServerRequest") -> None:
        """Ask the mux to tear down, passing along the requested exit code."""
        request = sreq.inform_teardown
        exit_code = request.exit_code
        self._mux.teardown(exit_code)


class SockAcceptThread(threading.Thread):
    """Accept connections and serve each on its own SockServerReadThread.

    Every read thread spawned here shares this thread's ClientDict. When the
    accept loop ends, each spawned read thread is asked to stop.
    """

    _sock: socket.socket
    _mux: StreamMux
    _stopped: "Event"
    _clients: ClientDict

    def __init__(self, sock: socket.socket, mux: StreamMux) -> None:
        self._sock = sock
        self._mux = mux
        self._stopped = mux._get_stopped_event()
        super().__init__()
        self.name = "SockAcceptThr"
        self._clients = ClientDict()

    def run(self) -> None:
        """Accept until stopped or the listening socket errors, then stop readers."""
        spawned = []

        while not self._stopped.is_set():
            try:
                conn, _peer = self._sock.accept()
            except OSError:
                # Covers ConnectionAbortedError (an OSError subclass) as well as
                # the error raised when the listening socket is shut down.
                break
            reader = SockServerReadThread(
                conn=conn, mux=self._mux, clients=self._clients
            )
            reader.start()
            spawned.append(reader)

        for reader in spawned:
            reader.stop()


class DebugThread(threading.Thread):
    """Daemon thread that periodically warns with the name of every live thread.

    A debugging aid for spotting threads that fail to exit. ``mux`` is accepted
    for call-site compatibility but is not used.
    """

    def __init__(self, mux: "StreamMux") -> None:
        super().__init__()
        self.daemon = True
        self.name = "DebugThr"

    def run(self) -> None:
        # Loops forever; as a daemon thread it does not block interpreter exit.
        while True:
            time.sleep(30)
            for live in threading.enumerate():
                wandb.termwarn(f"DEBUG: {live.name}")


class SocketServer:
    """TCP server whose accepted connections are served against a StreamMux.

    ``start`` binds and listens, then runs a ``SockAcceptThread`` in the
    background; ``stop`` shuts the listening socket down to end that thread.
    """

    _mux: StreamMux
    _address: str
    _port: int
    _sock: socket.socket

    def __init__(self, mux: Any, address: str, port: int) -> None:
        self._mux = mux
        self._address = address
        self._port = port
        # Listening socket; new client connections are accepted from it.
        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    def _bind(self) -> None:
        """Bind the listening socket and record the port actually assigned.

        With a requested port of 0 the OS chooses a free port; reading it back
        from ``getsockname()`` exposes the real value through ``port``.
        """
        self._sock.bind((self._address, self._port))
        _host, self._port = self._sock.getsockname()

    @property
    def port(self) -> int:
        """Port the server listens on (the OS-assigned one once bound)."""
        return self._port

    def start(self) -> None:
        """Bind, listen with a backlog of 5, and launch the accept thread."""
        self._bind()
        self._sock.listen(5)
        self._thread = SockAcceptThread(sock=self._sock, mux=self._mux)
        self._thread.start()
        # Note: Uncomment to figure out what thread is not exiting properly
        # self._dbg_thread = DebugThread(mux=self._mux)
        # self._dbg_thread.start()

    def stop(self) -> None:
        """Shut down and close the listening socket to end the accept thread."""
        if not self._sock:
            return
        try:
            # Shutdown notes / TODO(jhr): consider a more graceful shutdown.
            # socket.shutdown() is a heavy-handed way to interrupt socket.accept().
            # A gentler alternative would set a threading Event and then make one
            # last connection purely to wake the accept thread so it can exit.
            # The advantage of the heavy-handed approach is that it does not rely
            # on the threads functioning properly: if something has gone wrong,
            # this hammer still shuts things down.
            self._sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass
        self._sock.close()