"""streams: class that manages internal threads for each run.

StreamThread: Thread that runs internal.wandb_internal()
StreamRecord: All the external state for the internal thread (queues, etc)
StreamAction: Lightweight record for stream ops for thread safety
StreamMux: Container for dictionary of stream threads per runid
"""

from __future__ import annotations

import asyncio
import functools
import queue
import threading
import time
from threading import Event
from typing import Any, Callable, NoReturn

import psutil

from wandb.proto import wandb_internal_pb2 as pb
from wandb.sdk.interface.interface_relay import InterfaceRelay
from wandb.sdk.interface.router_relay import MessageRelayRouter
from wandb.sdk.internal.internal import wandb_internal
from wandb.sdk.internal.settings_static import SettingsStatic
from wandb.sdk.lib import asyncio_compat, progress
from wandb.sdk.lib import printer as printerlib
from wandb.sdk.mailbox import Mailbox, MailboxHandle, wait_all_with_progress
from wandb.sdk.wandb_run import Run


class StreamThread(threading.Thread):
    """Daemon thread that runs the internal process entry point.

    Args:
        target: Callable invoked from `run()`.
        kwargs: Keyword arguments passed to `target`.
    """

    def __init__(self, target: Callable, kwargs: dict[str, Any]) -> None:
        super().__init__()
        self.name = "StreamThr"
        # Assigned after Thread.__init__ so these values are the ones run() sees.
        self._target = target
        self._kwargs = kwargs
        self.daemon = True

    def run(self) -> None:
        """Invoke the target with the stored keyword arguments."""
        # TODO: catch exceptions and report errors to scheduler
        self._target(**self._kwargs)


class StreamRecord:
    """Queues, router, interface and thread belonging to a single stream."""

    _record_q: queue.Queue[pb.Record]
    _result_q: queue.Queue[pb.Result]
    _relay_q: queue.Queue[pb.Result]
    _iface: InterfaceRelay
    # None until start_thread() has been called.
    _thread: StreamThread | None
    _settings: SettingsStatic
    _started: bool

    def __init__(self, settings: SettingsStatic) -> None:
        """Create the queues, mailbox, router and interface for a stream.

        Args:
            settings: Settings stored on the record.
        """
        self._started = False
        # Initialised here so join() is safe even if start_thread() never ran.
        self._thread = None
        self._mailbox = Mailbox()
        self._record_q = queue.Queue()
        self._result_q = queue.Queue()
        self._relay_q = queue.Queue()
        self._router = MessageRelayRouter(
            request_queue=self._record_q,
            response_queue=self._result_q,
            relay_queue=self._relay_q,
            mailbox=self._mailbox,
        )
        self._iface = InterfaceRelay(
            record_q=self._record_q,
            result_q=self._result_q,
            relay_q=self._relay_q,
            mailbox=self._mailbox,
        )
        self._settings = settings

    def start_thread(self, thread: StreamThread) -> None:
        """Remember and start `thread`, then wait for a status reply."""
        self._thread = thread
        thread.start()
        self._wait_thread_active()

    def _wait_thread_active(self) -> None:
        """Deliver a status request and wait, without timeout, for its result."""
        self._iface.deliver_status().wait_or(timeout=None)

    def join(self) -> None:
        """Join the interface, the router and, if one was started, the thread."""
        self._iface.join()
        self._router.join()
        if self._thread is not None:
            self._thread.join()

    def drop(self) -> None:
        """Set the interface's `_drop` flag."""
        self._iface._drop = True

    @property
    def interface(self) -> InterfaceRelay:
        """The interface relay of this stream."""
        return self._iface

    def mark_started(self) -> None:
        """Flag the stream as started."""
        self._started = True

    def update(self, settings: SettingsStatic) -> None:
        """Replace the stored settings with `settings`."""
        # Note: Currently just overriding the _settings attribute
        # once we use Settings Class we might want to properly update it
        self._settings = settings


class StreamAction:
    """A queued stream operation plus an event signalling it was handled."""

    _action: str
    _stream_id: str
    _processed: Event
    _data: Any

    def __init__(self, action: str, stream_id: str, data: Any | None = None):
        """Store the operation name, target stream id and optional payload."""
        self._action = action
        self._stream_id = stream_id
        self._data = data
        self._processed = Event()

    def __repr__(self) -> str:
        return f"StreamAction({self._action},{self._stream_id})"

    def wait_handled(self) -> None:
        """Block until `set_handled()` has been called."""
        self._processed.wait()

    def set_handled(self) -> None:
        """Release every caller blocked in `wait_handled()`."""
        self._processed.set()

    @property
    def stream_id(self) -> str:
        """Id of the stream this action targets."""
        return self._stream_id


class StreamMux:
    """Container for the per-id StreamRecords, driven by an action queue.

    The public mutating methods enqueue a StreamAction and block until the
    thread running `loop()` has handled it.
    """

    _streams_lock: threading.Lock
    _streams: dict[str, StreamRecord]
    _port: int | None
    _pid: int | None
    _action_q: queue.Queue[StreamAction]
    _stopped: Event
    _pid_checked_ts: float | None

    def __init__(self) -> None:
        self._streams_lock = threading.Lock()
        self._streams = dict()
        self._port = None
        self._pid = None
        self._stopped = Event()
        self._action_q = queue.Queue()
        self._pid_checked_ts = None

    def _get_stopped_event(self) -> Event:
        """Return the event that is set once the mux has stopped."""
        # TODO: clean this up, there should be a better way to abstract this
        return self._stopped

    def set_port(self, port: int) -> None:
        """Store the port passed to stream threads created afterwards."""
        self._port = port

    def set_pid(self, pid: int) -> None:
        """Store the pid passed to stream threads and used by the orphan check."""
        self._pid = pid

    def _submit_and_wait(self, action: StreamAction) -> None:
        """Enqueue `action` and block until it has been marked handled."""
        self._action_q.put(action)
        action.wait_handled()

    def add_stream(self, stream_id: str, settings: SettingsStatic) -> None:
        """Queue an "add" action for `stream_id` and wait for it."""
        self._submit_and_wait(
            StreamAction(action="add", stream_id=stream_id, data=settings)
        )

    def start_stream(self, stream_id: str) -> None:
        """Queue a "start" action for `stream_id` and wait for it."""
        self._submit_and_wait(StreamAction(action="start", stream_id=stream_id))

    def update_stream(self, stream_id: str, settings: SettingsStatic) -> None:
        """Queue an "update" action for `stream_id` and wait for it."""
        self._submit_and_wait(
            StreamAction(action="update", stream_id=stream_id, data=settings)
        )

    def del_stream(self, stream_id: str) -> None:
        """Queue a "del" action for `stream_id` and wait for it."""
        self._submit_and_wait(StreamAction(action="del", stream_id=stream_id))

    def drop_stream(self, stream_id: str) -> None:
        """Queue a "drop" action for `stream_id` and wait for it."""
        self._submit_and_wait(StreamAction(action="drop", stream_id=stream_id))

    def teardown(self, exit_code: int) -> None:
        """Queue a "teardown" action carrying `exit_code` and wait for it."""
        self._submit_and_wait(
            StreamAction(action="teardown", stream_id="na", data=exit_code)
        )

    def stream_names(self) -> list[str]:
        """Return a snapshot list of the registered stream ids."""
        with self._streams_lock:
            return list(self._streams)

    def has_stream(self, stream_id: str) -> bool:
        """Return whether a stream is registered under `stream_id`."""
        with self._streams_lock:
            return stream_id in self._streams

    def get_stream(self, stream_id: str) -> StreamRecord:
        """Returns the StreamRecord for the ID.

        Raises:
            KeyError: If a corresponding StreamRecord does not exist.
        """
        with self._streams_lock:
            return self._streams[stream_id]

    def _process_add(self, action: StreamAction) -> None:
        """Create a record, start its thread, then register it under the id."""
        settings = action._data
        stream = StreamRecord(settings)
        # NOTE: the stream id is used as the run id; this needs revisiting if
        # a stream id can ever differ from the run id.
        thread = StreamThread(
            target=wandb_internal,
            kwargs=dict(
                settings=settings,
                record_q=stream._record_q,
                result_q=stream._result_q,
                port=self._port,
                user_pid=self._pid,
            ),
        )
        stream.start_thread(thread)
        with self._streams_lock:
            self._streams[action._stream_id] = stream

    def _process_start(self, action: StreamAction) -> None:
        """Mark the targeted stream as started (KeyError if it is unknown)."""
        with self._streams_lock:
            self._streams[action._stream_id].mark_started()

    def _process_update(self, action: StreamAction) -> None:
        """Replace the targeted stream's settings (KeyError if it is unknown)."""
        with self._streams_lock:
            self._streams[action._stream_id].update(action._data)

    def _process_del(self, action: StreamAction) -> None:
        """Unregister the targeted stream and join it (KeyError if unknown)."""
        with self._streams_lock:
            stream = self._streams.pop(action._stream_id)
            stream.join()
        # TODO: we assume stream has already been shutdown. should we verify?

    def _process_drop(self, action: StreamAction) -> None:
        """Unregister, drop and join the targeted stream; no-op if unknown."""
        with self._streams_lock:
            stream = self._streams.pop(action._stream_id, None)
            if stream is not None:
                stream.drop()
                stream.join()

    async def _finish_all_progress(
        self,
        progress_printer: progress.ProgressPrinter,
        streams_to_watch: dict[str, StreamRecord],
    ) -> None:
        """Poll the streams and display statistics about them.

        This never returns and must be cancelled.

        Args:
            progress_printer: Printer to use for displaying finish progress.
            streams_to_watch: Streams to poll for finish progress.
        """
        # Latest poll result per stream id, shared by the tasks below.
        results: dict[str, pb.Result | None] = {}

        async def loop_poll_stream(
            stream_id: str,
            stream: StreamRecord,
        ) -> NoReturn:
            while True:
                start_time = time.monotonic()

                handle = stream.interface.deliver_poll_exit()
                results[stream_id] = await handle.wait_async(timeout=None)

                # Space consecutive polls at least one second apart.
                elapsed_time = time.monotonic() - start_time
                if elapsed_time < 1:
                    await asyncio.sleep(1 - elapsed_time)

        async def loop_update_printer() -> NoReturn:
            while True:
                poll_exit_responses: list[pb.PollExitResponse] = []
                for result in results.values():
                    if not result or not result.response:
                        continue
                    if poll_exit_response := result.response.poll_exit_response:
                        poll_exit_responses.append(poll_exit_response)

                progress_printer.update(poll_exit_responses)
                await asyncio.sleep(1)

        async with asyncio_compat.open_task_group() as task_group:
            for stream_id, stream in streams_to_watch.items():
                task_group.start_soon(loop_poll_stream(stream_id, stream))
            task_group.start_soon(loop_update_printer())

    def _finish_all(self, streams: dict[str, StreamRecord], exit_code: int) -> None:
        """Deliver exit to started streams, print their footers and join all.

        Args:
            streams: Streams to finish, keyed by stream id.
            exit_code: Exit code delivered to every started stream.
        """
        if not streams:
            return

        printer = printerlib.new_printer()
        # fixme: for now we have a single printer for all streams,
        # and jupyter is disabled if at least single stream's setting set
        # `_jupyter` to false

        # only finish started streams, non started streams failed early
        started_streams: dict[str, StreamRecord] = {}
        not_started_streams: dict[str, StreamRecord] = {}
        for stream_id, stream in streams.items():
            if stream._started:
                started_streams[stream_id] = stream
            else:
                not_started_streams[stream_id] = stream

        exit_handles: list[MailboxHandle[pb.Result]] = [
            stream.interface.deliver_exit(exit_code)
            for stream in started_streams.values()
        ]

        with progress.progress_printer(
            printer,
            default_text="Finishing up...",
        ) as progress_printer:
            # todo: should we wait for the max timeout (?) of all exit handles
            # or just wait forever?
            # timeout = max(stream._settings._exit_timeout for stream in streams.values())
            wait_all_with_progress(
                exit_handles,
                timeout=None,
                progress_after=1,
                display_progress=functools.partial(
                    self._finish_all_progress,
                    progress_printer,
                    started_streams,
                ),
            )

        # These could be done in parallel in the future
        for stream in started_streams.values():
            # dispatch all our final requests
            poll_exit_handle = stream.interface.deliver_poll_exit()
            final_summary_handle = stream.interface.deliver_get_summary()
            sampled_history_handle = stream.interface.deliver_request_sampled_history()
            internal_messages_handle = stream.interface.deliver_internal_messages()

            result = internal_messages_handle.wait_or(timeout=None)
            internal_messages_response = result.response.internal_messages_response

            result = poll_exit_handle.wait_or(timeout=None)
            poll_exit_response = result.response.poll_exit_response

            result = sampled_history_handle.wait_or(timeout=None)
            sampled_history = result.response.sampled_history_response

            result = final_summary_handle.wait_or(timeout=None)
            final_summary = result.response.get_summary_response

            Run._footer(
                sampled_history=sampled_history,
                final_summary=final_summary,
                poll_exit_response=poll_exit_response,
                internal_messages_response=internal_messages_response,
                settings=stream._settings,  # type: ignore
                printer=printer,
            )
            stream.join()

        # not started streams need to be cleaned up
        for stream in not_started_streams.values():
            stream.join()

    def _process_teardown(self, action: StreamAction) -> None:
        """Finish every registered stream, clear the registry and stop."""
        exit_code: int = action._data
        with self._streams_lock:
            # TODO: mark streams to prevent new modifications?
            streams_copy = self._streams.copy()
        # Runs on the snapshot, outside the lock.
        self._finish_all(streams_copy, exit_code)
        with self._streams_lock:
            self._streams = dict()
        self._stopped.set()

    def _process_action(self, action: StreamAction) -> None:
        """Route `action` to the handler registered for its name.

        Raises:
            AssertionError: If the action name is not supported.
        """
        handlers: dict[str, Callable[[StreamAction], None]] = {
            "add": self._process_add,
            "update": self._process_update,
            "start": self._process_start,
            "del": self._process_del,
            "drop": self._process_drop,
            "teardown": self._process_teardown,
        }
        handler = handlers.get(action._action)
        if handler is None:
            raise AssertionError(f"Unsupported action: {action._action}")
        handler(action)

    def _check_orphaned(self) -> bool:
        """Return True if the stored pid no longer exists.

        Returns False when no pid is set, or when the previous check was made
        less than 2 seconds ago.
        """
        if not self._pid:
            return False
        time_now = time.time()
        last_checked = self._pid_checked_ts
        # if we have checked already and it was less than 2 seconds ago
        if last_checked is not None and time_now < last_checked + 2:
            return False
        self._pid_checked_ts = time_now
        return not psutil.pid_exists(self._pid)

    def _loop(self) -> None:
        """Handle queued actions until the stopped event is set.

        An exception raised by a handler still propagates, but the action is
        marked handled first so its submitter is not left blocked.
        """
        while not self._stopped.is_set():
            if self._check_orphaned():
                # parent process is gone, let other threads know we need to
                # shut down
                self._stopped.set()
            try:
                action = self._action_q.get(timeout=1)
            except queue.Empty:
                continue
            try:
                self._process_action(action)
            finally:
                # Always release the waiting caller and keep the queue's
                # unfinished-task count balanced for the join() below.
                action.set_handled()
                self._action_q.task_done()
        self._action_q.join()

    def loop(self) -> None:
        """Run the action-processing loop on the calling thread."""
        self._loop()

    def cleanup(self) -> None:
        """Do nothing; placeholder hook."""
        pass