import queue
import socket
import threading
import time
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional

import wandb
from wandb.proto import wandb_internal_pb2 as pb
from wandb.proto import wandb_server_pb2 as spb
from wandb.sdk.internal.settings_static import SettingsStatic

from ..lib.sock_client import SockClient, SockClientClosedError
from .streams import StreamMux

if TYPE_CHECKING:
    from threading import Event

    from ..interface.interface_relay import InterfaceRelay


class ClientDict:
    """Registry of connected socket clients, keyed by ``SockClient._sockid``.

    Every operation runs while holding an internal lock.
    """

    _client_dict: Dict[str, SockClient]
    _lock: threading.Lock

    def __init__(self) -> None:
        self._client_dict = {}
        self._lock = threading.Lock()

    def get_client(self, client_id: str) -> Optional[SockClient]:
        """Return the client registered under ``client_id``, or ``None``."""
        with self._lock:
            return self._client_dict.get(client_id)

    def add_client(self, client: SockClient) -> None:
        """Register ``client`` under its ``_sockid`` (replaces any existing entry)."""
        with self._lock:
            self._client_dict[client._sockid] = client

    def del_client(self, client: SockClient) -> None:
        """Remove ``client`` from the registry.

        Raises:
            KeyError: if no entry exists for ``client._sockid``.
        """
        with self._lock:
            del self._client_dict[client._sockid]


class SockServerInterfaceReaderThread(threading.Thread):
    """Forward results from an interface's relay queue to the right client.

    Each result taken from ``iface.relay_q`` is wrapped in a
    ``ServerResponse`` and sent to the client whose socket id equals
    ``result.control.relay_id``.
    """

    _socket_client: SockClient
    _stopped: "Event"

    def __init__(
        self,
        clients: ClientDict,
        iface: "InterfaceRelay",
        stopped: "Event",
    ) -> None:
        """Create the thread (not started).

        Args:
            clients: registry used to look up the destination client.
            iface: interface whose ``relay_q`` is drained.
            stopped: event that ends the loop once set.
        """
        self._iface = iface
        self._clients = clients
        super().__init__(name="SockSrvIntRdThr")
        self._stopped = stopped

    def run(self) -> None:
        """Relay results until ``stopped`` is set or the queue becomes unusable."""
        while not self._stopped.is_set():
            try:
                # Bounded wait so the stop event is re-checked every second.
                result = self._iface.relay_q.get(timeout=1)
            except queue.Empty:
                continue
            except OSError:
                # handle is closed
                break
            except ValueError:
                # queue is closed
                break
            sockid = result.control.relay_id
            assert sockid
            sock_client = self._clients.get_client(sockid)
            assert sock_client
            sresp = spb.ServerResponse()
            sresp.request_id = result.control.mailbox_slot
            sresp.result_communicate.CopyFrom(result)
            sock_client.send_server_response(sresp)


class SockServerReadThread(threading.Thread):
    """Read server requests from one connection and dispatch them.

    A request whose ``server_request_type`` oneof is ``<name>`` is handled by
    the method ``server_<name>`` on this object.
    """

    _sock_client: SockClient
    _mux: StreamMux
    _stopped: "Event"
    _clients: ClientDict

    def __init__(
        self, conn: socket.socket, mux: StreamMux, clients: ClientDict
    ) -> None:
        """Create the thread (not started) for an accepted connection.

        Args:
            conn: the accepted socket; wrapped in a new ``SockClient``.
            mux: stream multiplexer that requests are applied to; it also
                supplies the stop event for this thread.
            clients: registry this connection's client is added to on
                ``inform_init`` / ``inform_attach``.
        """
        self._mux = mux
        super().__init__(name="SockSrvRdThr")
        sock_client = SockClient()
        sock_client.set_socket(conn)
        self._sock_client = sock_client
        self._stopped = mux._get_stopped_event()
        self._clients = clients

    def run(self) -> None:
        """Read and dispatch requests until stopped or the socket is closed."""
        while not self._stopped.is_set():
            try:
                sreq = self._sock_client.read_server_request()
            except SockClientClosedError:
                # socket has been closed
                # TODO: shut down other threads serving this socket?
                break
            assert sreq, "read_server_request should never timeout"
            sreq_type = sreq.WhichOneof("server_request_type")
            shandler_str = "server_" + sreq_type  # type: ignore
            shandler: Callable[[spb.ServerRequest], None] = getattr(  # type: ignore
                self, shandler_str, None
            )
            assert shandler, f"unknown handle: {shandler_str}"  # type: ignore
            shandler(sreq)

    def stop(self) -> None:
        """Shut down and close this connection's socket client."""
        try:
            # See shutdown notes in class SocketServer for a discussion about this mechanism
            self._sock_client.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass
        self._sock_client.close()

    def server_inform_init(self, sreq: "spb.ServerRequest") -> None:
        """Add a stream to the mux, register this client, start its relay reader."""
        request = sreq.inform_init
        stream_id = request._info.stream_id
        settings = SettingsStatic(request.settings)
        self._mux.add_stream(stream_id, settings=settings)

        iface = self._mux.get_stream(stream_id).interface
        self._clients.add_client(self._sock_client)
        iface_reader_thread = SockServerInterfaceReaderThread(
            clients=self._clients,
            iface=iface,
            stopped=self._stopped,
        )
        iface_reader_thread.start()

    def server_inform_start(self, sreq: "spb.ServerRequest") -> None:
        """Update the stream's settings and start the stream."""
        request = sreq.inform_start
        stream_id = request._info.stream_id
        settings = SettingsStatic(request.settings)
        self._mux.update_stream(stream_id, settings=settings)
        self._mux.start_stream(stream_id)

    def server_inform_attach(self, sreq: "spb.ServerRequest") -> None:
        """Register this client and reply with the stream's settings.

        Raises:
            KeyError: if the mux has no stream for the requested id.
        """
        request = sreq.inform_attach
        stream_id = request._info.stream_id

        self._clients.add_client(self._sock_client)
        inform_attach_response = spb.ServerInformAttachResponse()
        inform_attach_response.settings.CopyFrom(
            self._mux._streams[stream_id]._settings._proto,
        )
        response = spb.ServerResponse(
            request_id=sreq.request_id,
            inform_attach_response=inform_attach_response,
        )
        self._sock_client.send_server_response(response)

    def server_record_communicate(self, sreq: "spb.ServerRequest") -> None:
        """Queue the request's ``record_communicate`` record on its stream."""
        self._put_record(sreq.record_communicate)

    def server_record_publish(self, sreq: "spb.ServerRequest") -> None:
        """Queue the request's ``record_publish`` record on its stream."""
        self._put_record(sreq.record_publish)

    def _put_record(self, record: "pb.Record") -> None:
        """Tag ``record`` with this socket's id and put it on the stream's queue.

        Records addressed to a stream the mux does not know are dropped.
        """
        # encode relay information so the right socket picks up the data
        record.control.relay_id = self._sock_client._sockid
        stream_id = record._info.stream_id
        try:
            iface = self._mux.get_stream(stream_id).interface
        except KeyError:
            # We should log the error but cannot because it may print to console
            # due to how logging is set up. This error usually happens if
            # a record is sent when no run is active, but during this time the
            # logger prints to the console.
            return
        assert iface.record_q
        iface.record_q.put(record)

    def server_inform_finish(self, sreq: "spb.ServerRequest") -> None:
        """Drop the stream named in the request from the mux."""
        request = sreq.inform_finish
        stream_id = request._info.stream_id
        self._mux.drop_stream(stream_id)

    def server_inform_teardown(self, sreq: "spb.ServerRequest") -> None:
        """Tear down the mux with the exit code carried by the request."""
        request = sreq.inform_teardown
        exit_code = request.exit_code
        self._mux.teardown(exit_code)


class SockAcceptThread(threading.Thread):
    """Accept connections and start one ``SockServerReadThread`` per connection."""

    _sock: socket.socket
    _mux: StreamMux
    _stopped: "Event"
    _clients: ClientDict

    def __init__(self, sock: socket.socket, mux: StreamMux) -> None:
        """Create the thread (not started).

        Args:
            sock: listening socket to accept connections from.
            mux: stream multiplexer handed to every read thread; it also
                supplies the stop event for this thread.
        """
        self._sock = sock
        self._mux = mux
        self._stopped = mux._get_stopped_event()
        super().__init__(name="SockAcceptThr")
        self._clients = ClientDict()

    def run(self) -> None:
        """Accept connections until stopped, then stop every read thread."""
        read_threads = []
        try:
            while not self._stopped.is_set():
                try:
                    conn, _addr = self._sock.accept()
                except OSError:
                    # Raised on shutdown of the listening socket; this also
                    # covers ConnectionAbortedError, an OSError subclass.
                    break
                sr = SockServerReadThread(
                    conn=conn, mux=self._mux, clients=self._clients
                )
                # Track the thread before starting it so its connection is
                # still closed below if start() fails.
                read_threads.append(sr)
                sr.start()
        finally:
            # Always release the accepted connections, including when the
            # loop is left through an unexpected exception.
            for rt in read_threads:
                rt.stop()


class DebugThread(threading.Thread):
    """Daemon thread that periodically reports the names of all live threads."""

    def __init__(self, mux: "StreamMux") -> None:
        """Create the thread (not started); ``mux`` is accepted but unused."""
        super().__init__(name="DebugThr", daemon=True)

    def run(self) -> None:
        """Every 30 seconds, emit a ``wandb.termwarn`` line per live thread."""
        while True:
            time.sleep(30)
            for thread in threading.enumerate():
                wandb.termwarn(f"DEBUG: {thread.name}")


class SocketServer:
    """Listening socket server (AF_INET / SOCK_STREAM) driving a ``SockAcceptThread``."""

    _mux: StreamMux
    _address: str
    _port: int
    _sock: socket.socket

    def __init__(self, mux: Any, address: str, port: int) -> None:
        """Create the server socket; nothing is bound until ``start()``.

        Args:
            mux: stream multiplexer passed on to the accept thread.
            address: address to bind to.
            port: port to bind to.
        """
        self._mux = mux
        self._address = address
        self._port = port
        # This is the server socket that we accept new connections from
        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    def _bind(self) -> None:
        """Bind the socket and record the port reported by ``getsockname()``."""
        self._sock.bind((self._address, self._port))
        self._port = self._sock.getsockname()[1]

    @property
    def port(self) -> int:
        """Port given at construction, or the bound port once ``start()`` has run."""
        return self._port

    def start(self) -> None:
        """Bind, listen with a backlog of 5, and start the accept thread."""
        self._bind()
        self._sock.listen(5)
        self._thread = SockAcceptThread(sock=self._sock, mux=self._mux)
        self._thread.start()
        # Note: Uncomment to figure out what thread is not exiting properly
        # self._dbg_thread = DebugThread(mux=self._mux)
        # self._dbg_thread.start()

    def stop(self) -> None:
        """Shut down and close the listening socket to interrupt the accept thread."""
        if self._sock:
            # we need to stop the SockAcceptThread
            try:
                # TODO(jhr): consider a more graceful shutdown in the future
                # socket.shutdown() is a more heavy handed approach to interrupting socket.accept()
                # in the future we might want to consider a more graceful shutdown which would involve setting
                # a threading Event and then initiating one last connection just to close down the thread
                # The advantage of the heavy handed approach is that it does not depend on the threads functioning
                # properly, that is, if something has gone wrong, we probably want to use this hammer to shut things down
                self._sock.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass
            self._sock.close()