|
import queue |
|
import socket |
|
import threading |
|
import time |
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional |
|
|
|
import wandb |
|
from wandb.proto import wandb_internal_pb2 as pb |
|
from wandb.proto import wandb_server_pb2 as spb |
|
from wandb.sdk.internal.settings_static import SettingsStatic |
|
|
|
from ..lib.sock_client import SockClient, SockClientClosedError |
|
from .streams import StreamMux |
|
|
|
if TYPE_CHECKING: |
|
from threading import Event |
|
|
|
from ..interface.interface_relay import InterfaceRelay |
|
|
|
|
|
class ClientDict:
    """Lock-protected registry of ``SockClient`` objects keyed by socket id."""

    _client_dict: Dict[str, SockClient]
    _lock: threading.Lock

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._client_dict = {}

    def get_client(self, client_id: str) -> Optional[SockClient]:
        """Return the client registered under ``client_id``, or ``None``."""
        with self._lock:
            return self._client_dict.get(client_id)

    def add_client(self, client: SockClient) -> None:
        """Register ``client`` under its ``_sockid``, replacing any prior entry."""
        sockid = client._sockid
        with self._lock:
            self._client_dict[sockid] = client

    def del_client(self, client: SockClient) -> None:
        """Unregister ``client``.

        Raises:
            KeyError: if nothing is registered under ``client._sockid``.
        """
        with self._lock:
            self._client_dict.pop(client._sockid)
|
|
|
|
|
class SockServerInterfaceReaderThread(threading.Thread):
    """Forward results from one stream's relay queue back to the requesting clients.

    Each result read from ``iface.relay_q`` is wrapped in a ``ServerResponse``
    that answers mailbox slot ``control.mailbox_slot``. The response is sent to
    the socket client registered in ``clients`` under ``control.relay_id``.
    """

    _clients: ClientDict
    _iface: "InterfaceRelay"
    _stopped: "Event"

    def __init__(
        self,
        clients: ClientDict,
        iface: "InterfaceRelay",
        stopped: "Event",
    ) -> None:
        """Prepare the relay thread; call ``start()`` to begin relaying.

        Args:
            clients: registry used to find the socket client for each result.
            iface: stream interface whose ``relay_q`` is drained.
            stopped: event that ends the relay loop once it is set.
        """
        self._iface = iface
        self._clients = clients
        threading.Thread.__init__(self)
        self.name = "SockSrvIntRdThr"
        self._stopped = stopped

    def run(self) -> None:
        """Relay results until ``stopped`` is set or the queue is closed.

        This single thread serves every client of the stream, so a result
        whose client is unknown or has disconnected is dropped. That lets one
        departed client no longer stop relaying for all the others.
        """
        while not self._stopped.is_set():
            try:
                result = self._iface.relay_q.get(timeout=1)
            except queue.Empty:
                continue
            except OSError:
                # Presumably the queue's underlying handle has been closed.
                break
            except ValueError:
                # Presumably the queue itself has been closed.
                break

            sockid = result.control.relay_id
            if not sockid:
                # Untagged result: there is no client to answer.
                continue
            sock_client = self._clients.get_client(sockid)
            if sock_client is None:
                # The originating socket never registered (no inform_init /
                # inform_attach on it) or was removed; nobody to answer.
                continue

            sresp = spb.ServerResponse()
            sresp.request_id = result.control.mailbox_slot
            sresp.result_communicate.CopyFrom(result)
            try:
                sock_client.send_server_response(sresp)
            except (SockClientClosedError, OSError):
                # The client disconnected before its answer was ready. Drop
                # this response rather than ending the relay for the stream.
                continue
|
|
|
|
|
class SockServerReadThread(threading.Thread):
    """Serve one accepted connection by reading and dispatching ``ServerRequest``s.

    A request is routed to the method named ``server_<type>``. Here ``<type>``
    is the name of the field set in the request's ``server_request_type`` oneof.
    """

    _sock_client: SockClient
    _mux: StreamMux
    _stopped: "Event"
    _clients: ClientDict

    def __init__(
        self, conn: socket.socket, mux: StreamMux, clients: ClientDict
    ) -> None:
        """Wrap ``conn`` in a ``SockClient``; call ``start()`` to begin serving.

        Args:
            conn: socket returned by ``accept()`` for this connection.
            mux: stream multiplexer that requests are routed to.
            clients: registry of socket clients; this connection's client is
                added to it when it initializes or attaches to a stream.
        """
        self._mux = mux
        threading.Thread.__init__(self)
        self.name = "SockSrvRdThr"
        sock_client = SockClient()
        sock_client.set_socket(conn)
        self._sock_client = sock_client
        self._stopped = mux._get_stopped_event()
        self._clients = clients

    def run(self) -> None:
        """Read and dispatch requests until stopped or the connection closes."""
        while not self._stopped.is_set():
            try:
                sreq = self._sock_client.read_server_request()
            except SockClientClosedError:
                # The connection has been closed; nothing more to read.
                break
            assert sreq, "read_server_request should never timeout"
            sreq_type = sreq.WhichOneof("server_request_type")
            # WhichOneof returns None when no field of the oneof is set. An
            # f-string turns that into the "unknown handle" failure below,
            # whereas "server_" + None would raise an unrelated TypeError.
            shandler_str = f"server_{sreq_type}"
            shandler: Optional[Callable[[spb.ServerRequest], None]] = getattr(
                self, shandler_str, None
            )
            assert shandler, f"unknown handle: {shandler_str}"
            shandler(sreq)

    def stop(self) -> None:
        """Shut down and close this connection's socket."""
        try:
            # shutdown() can fail if the peer has already disconnected;
            # close() below is still performed.
            self._sock_client.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass
        self._sock_client.close()

    def server_inform_init(self, sreq: "spb.ServerRequest") -> None:
        """Create a stream, register this client, and start relaying its results."""
        request = sreq.inform_init
        stream_id = request._info.stream_id
        settings = SettingsStatic(request.settings)
        self._mux.add_stream(stream_id, settings=settings)

        iface = self._mux.get_stream(stream_id).interface
        self._clients.add_client(self._sock_client)
        iface_reader_thread = SockServerInterfaceReaderThread(
            clients=self._clients,
            iface=iface,
            stopped=self._stopped,
        )
        iface_reader_thread.start()

    def server_inform_start(self, sreq: "spb.ServerRequest") -> None:
        """Update the stream's settings from the request, then start the stream."""
        request = sreq.inform_start
        stream_id = request._info.stream_id
        settings = SettingsStatic(request.settings)
        self._mux.update_stream(stream_id, settings=settings)
        self._mux.start_stream(stream_id)

    def server_inform_attach(self, sreq: "spb.ServerRequest") -> None:
        """Attach this connection to an existing stream and reply with its settings."""
        request = sreq.inform_attach
        stream_id = request._info.stream_id

        self._clients.add_client(self._sock_client)
        # Use the mux's accessor, as the other handlers do, rather than
        # indexing its private ``_streams`` dict directly.
        stream = self._mux.get_stream(stream_id)
        inform_attach_response = spb.ServerInformAttachResponse()
        inform_attach_response.settings.CopyFrom(
            stream._settings._proto,
        )
        response = spb.ServerResponse(
            request_id=sreq.request_id,
            inform_attach_response=inform_attach_response,
        )
        self._sock_client.send_server_response(response)

    def server_record_communicate(self, sreq: "spb.ServerRequest") -> None:
        """Enqueue a record for which the client expects a relayed result."""
        self._put_record(sreq.record_communicate)

    def server_record_publish(self, sreq: "spb.ServerRequest") -> None:
        """Enqueue a record published by the client."""
        self._put_record(sreq.record_publish)

    def _put_record(self, record: "pb.Record") -> None:
        """Tag ``record`` with this connection's socket id and enqueue it.

        The record goes onto the ``record_q`` of the stream named by
        ``record._info.stream_id``. The ``relay_id`` tag is presumably carried
        over to the resulting ``Result`` so the relay thread can route the
        answer back to this socket.
        """
        record.control.relay_id = self._sock_client._sockid
        stream_id = record._info.stream_id

        try:
            iface = self._mux.get_stream(stream_id).interface
        except KeyError:
            # No stream is registered under this id (presumably the record
            # arrived before inform_init or after inform_finish dropped the
            # stream). The record is deliberately discarded: raising here
            # would end this connection's read thread.
            return

        assert iface.record_q
        iface.record_q.put(record)

    def server_inform_finish(self, sreq: "spb.ServerRequest") -> None:
        """Drop the stream named in the request from the mux."""
        request = sreq.inform_finish
        stream_id = request._info.stream_id
        self._mux.drop_stream(stream_id)

    def server_inform_teardown(self, sreq: "spb.ServerRequest") -> None:
        """Ask the mux to tear down using the request's exit code."""
        request = sreq.inform_teardown
        exit_code = request.exit_code
        self._mux.teardown(exit_code)
|
|
|
|
|
class SockAcceptThread(threading.Thread):
    """Accept connections and serve each on its own ``SockServerReadThread``.

    All read threads share one ``ClientDict`` owned by this thread.
    """

    _sock: socket.socket
    _mux: StreamMux
    _stopped: "Event"
    _clients: ClientDict

    def __init__(self, sock: socket.socket, mux: StreamMux) -> None:
        self._sock = sock
        self._mux = mux
        self._stopped = mux._get_stopped_event()
        threading.Thread.__init__(self)
        self.name = "SockAcceptThr"
        self._clients = ClientDict()

    def run(self) -> None:
        """Accept until stopped or accept() fails, then stop every read thread."""
        readers = []
        while not self._stopped.is_set():
            try:
                conn, _peer = self._sock.accept()
            except OSError:
                # Also covers ConnectionAbortedError (an OSError subclass),
                # as well as the listening socket having been shut down/closed.
                break
            reader = SockServerReadThread(
                conn=conn, mux=self._mux, clients=self._clients
            )
            reader.start()
            readers.append(reader)

        # Close every accepted connection, including ones whose read thread
        # has already finished.
        for reader in readers:
            reader.stop()
|
|
|
|
|
class DebugThread(threading.Thread):
    """Daemon thread that logs the names of all live threads every 30 seconds.

    ``mux`` is accepted for interface compatibility but not used.
    """

    def __init__(self, mux: "StreamMux") -> None:
        threading.Thread.__init__(self)
        self.daemon = True
        self.name = "DebugThr"

    def run(self) -> None:
        # Loops forever; as a daemon thread it does not keep the process alive.
        while True:
            time.sleep(30)
            live_names = [t.name for t in threading.enumerate()]
            for thread_name in live_names:
                wandb.termwarn(f"DEBUG: {thread_name}")
|
|
|
|
|
class SocketServer:
    """IPv4 TCP server whose connections are handled by a ``SockAcceptThread``."""

    _mux: StreamMux
    _address: str
    _port: int
    _sock: socket.socket

    def __init__(self, mux: Any, address: str, port: int) -> None:
        """Create, but do not bind, the listening socket.

        Args:
            mux: stream multiplexer handed to the accept thread.
            address: address to bind to.
            port: port to bind to. It may be 0 to let the OS choose; the
                actual port is read back when binding.
        """
        self._mux = mux
        self._address = address
        self._port = port
        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    def _bind(self) -> None:
        """Bind the socket and record the port that was actually bound."""
        self._sock.bind((self._address, self._port))
        bound_name = self._sock.getsockname()
        self._port = bound_name[1]

    @property
    def port(self) -> int:
        """Port the server is bound to (after ``start()``), else the requested one."""
        return self._port

    def start(self) -> None:
        """Bind, listen with a backlog of 5, and start the accept thread."""
        self._bind()
        self._sock.listen(5)
        accept_thread = SockAcceptThread(sock=self._sock, mux=self._mux)
        self._thread = accept_thread
        accept_thread.start()

    def stop(self) -> None:
        """Shut down and close the listening socket.

        This is presumably what makes the accept thread's blocking
        ``accept()`` fail with ``OSError`` so that thread exits.
        """
        sock = self._sock
        if not sock:
            return
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            # shutdown() may fail (e.g. on a socket that is not connected or
            # is already closed); close() below is still performed.
            pass
        sock.close()
|
|