"""Batching file prepare requests to our API.""" import concurrent.futures import logging import queue import sys import threading from typing import ( TYPE_CHECKING, Callable, MutableMapping, MutableSequence, MutableSet, NamedTuple, Optional, Union, ) from wandb.errors.term import termerror from wandb.filesync import upload_job from wandb.sdk.lib.paths import LogicalPath if TYPE_CHECKING: from typing import TypedDict from wandb.filesync import stats from wandb.sdk.internal import file_stream, internal_api, progress from wandb.sdk.internal.settings_static import SettingsStatic class ArtifactStatus(TypedDict): finalize: bool pending_count: int commit_requested: bool pre_commit_callbacks: MutableSet["PreCommitFn"] result_futures: MutableSet["concurrent.futures.Future[None]"] PreCommitFn = Callable[[], None] OnRequestFinishFn = Callable[[], None] SaveFn = Callable[["progress.ProgressFn"], bool] logger = logging.getLogger(__name__) class RequestUpload(NamedTuple): path: str save_name: LogicalPath artifact_id: Optional[str] md5: Optional[str] copied: bool save_fn: Optional[SaveFn] digest: Optional[str] class RequestCommitArtifact(NamedTuple): artifact_id: str finalize: bool before_commit: PreCommitFn result_future: "concurrent.futures.Future[None]" class RequestFinish(NamedTuple): callback: Optional[OnRequestFinishFn] class EventJobDone(NamedTuple): job: RequestUpload exc: Optional[BaseException] Event = Union[RequestUpload, RequestCommitArtifact, RequestFinish, EventJobDone] class StepUpload: def __init__( self, api: "internal_api.Api", stats: "stats.Stats", event_queue: "queue.Queue[Event]", max_threads: int, file_stream: "file_stream.FileStreamApi", settings: Optional["SettingsStatic"] = None, ) -> None: self._api = api self._stats = stats self._event_queue = event_queue self._file_stream = file_stream self._thread = threading.Thread(target=self._thread_body) self._thread.daemon = True self._pool = concurrent.futures.ThreadPoolExecutor( thread_name_prefix="wandb-upload", max_workers=max_threads, ) # Indexed by files' `save_name`'s, which are their ID's in the Run. self._running_jobs: MutableMapping[LogicalPath, RequestUpload] = {} self._pending_jobs: MutableSequence[RequestUpload] = [] self._artifacts: MutableMapping[str, ArtifactStatus] = {} self.silent = bool(settings.silent) if settings else False def _thread_body(self) -> None: event: Optional[Event] # Wait for event in the queue, and process one by one until a # finish event is received finish_callback = None while True: event = self._event_queue.get() if isinstance(event, RequestFinish): finish_callback = event.callback break self._handle_event(event) # We've received a finish event. At this point, further Upload requests # are invalid. # After a finish event is received, iterate through the event queue # one by one and process all remaining events. while True: try: event = self._event_queue.get(True, 0.2) except queue.Empty: event = None if event: self._handle_event(event) elif not self._running_jobs: # Queue was empty and no jobs left. self._pool.shutdown(wait=False) if finish_callback: finish_callback() break def _handle_event(self, event: Event) -> None: if isinstance(event, EventJobDone): job = event.job if event.exc is not None: logger.exception( "Failed to upload file: %s", job.path, exc_info=event.exc ) if job.artifact_id: if event.exc is None: self._artifacts[job.artifact_id]["pending_count"] -= 1 self._maybe_commit_artifact(job.artifact_id) else: if not self.silent: termerror( "Uploading artifact file failed. 
                            "Artifact won't be committed."
                        )
                    self._fail_artifact_futures(job.artifact_id, event.exc)
            self._running_jobs.pop(job.save_name)
            # If we have any pending jobs, start one now.
            if self._pending_jobs:
                event = self._pending_jobs.pop(0)
                self._start_upload_job(event)
        elif isinstance(event, RequestCommitArtifact):
            if event.artifact_id not in self._artifacts:
                self._init_artifact(event.artifact_id)
            self._artifacts[event.artifact_id]["commit_requested"] = True
            self._artifacts[event.artifact_id]["finalize"] = event.finalize
            self._artifacts[event.artifact_id]["pre_commit_callbacks"].add(
                event.before_commit
            )
            self._artifacts[event.artifact_id]["result_futures"].add(
                event.result_future
            )
            self._maybe_commit_artifact(event.artifact_id)
        elif isinstance(event, RequestUpload):
            if event.artifact_id is not None:
                if event.artifact_id not in self._artifacts:
                    self._init_artifact(event.artifact_id)
                self._artifacts[event.artifact_id]["pending_count"] += 1
            self._start_upload_job(event)
        else:
            raise TypeError(f"Event has unexpected type: {event!s}")

    def _start_upload_job(self, event: RequestUpload) -> None:
        # Operations on a single backend file must be serialized. If we're
        # already uploading this file, put the event at the end of the queue.
        if event.save_name in self._running_jobs:
            self._pending_jobs.append(event)
            return

        self._spawn_upload(event)

    def _spawn_upload(self, event: RequestUpload) -> None:
        """Spawn an upload job and handle the bookkeeping of `self._running_jobs`.

        Context: it's important that, whenever we add an entry to
        `self._running_jobs`, we ensure that a corresponding `EventJobDone`
        message will eventually get handled; otherwise, the `_running_jobs`
        entry will never get removed, and the StepUpload will never shut down.

        The sole purpose of this function is to make sure that the code that
        adds an entry to `self._running_jobs` is textually right next to the
        code that eventually enqueues the `EventJobDone` message. This should
        help keep them in sync.
        """
        # Adding the entry to `self._running_jobs` MUST happen here, on the
        # StepUpload thread, NOT in the job that gets submitted to the
        # thread-pool, to guard against this sequence of events:
        # - StepUpload receives a RequestUpload
        #   ...and therefore spawns a thread to do the upload
        # - StepUpload receives a RequestFinish
        #   ...and checks `self._running_jobs` to see if there are any tasks to
        #   wait for...
        #   ...and there are none, because the addition to `self._running_jobs`
        #   happens in the background thread, which the scheduler hasn't yet
        #   run...
        #   ...so the StepUpload shuts down, even though we haven't uploaded
        #   the file!
        #
        # This would be very bad! So, this line has to happen _outside_ the
        # `pool.submit()`.
        self._running_jobs[event.save_name] = event

        def run_and_notify() -> None:
            try:
                self._do_upload(event)
            finally:
                self._event_queue.put(EventJobDone(event, exc=sys.exc_info()[1]))

        self._pool.submit(run_and_notify)

    def _do_upload(self, event: RequestUpload) -> None:
        job = upload_job.UploadJob(
            self._stats,
            self._api,
            self._file_stream,
            self.silent,
            event.save_name,
            event.path,
            event.artifact_id,
            event.md5,
            event.copied,
            event.save_fn,
            event.digest,
        )
        job.run()

    def _init_artifact(self, artifact_id: str) -> None:
        self._artifacts[artifact_id] = {
            "finalize": False,
            "pending_count": 0,
            "commit_requested": False,
            "pre_commit_callbacks": set(),
            "result_futures": set(),
        }

    def _maybe_commit_artifact(self, artifact_id: str) -> None:
        artifact_status = self._artifacts[artifact_id]
        if (
            artifact_status["pending_count"] == 0
            and artifact_status["commit_requested"]
        ):
            try:
                for pre_callback in artifact_status["pre_commit_callbacks"]:
                    pre_callback()
                if artifact_status["finalize"]:
                    self._api.commit_artifact(artifact_id)
            except Exception as exc:
                termerror(
                    f"Committing artifact failed. Artifact {artifact_id} won't be finalized."
                )
                termerror(str(exc))
                self._fail_artifact_futures(artifact_id, exc)
            else:
                self._resolve_artifact_futures(artifact_id)

    def _fail_artifact_futures(self, artifact_id: str, exc: BaseException) -> None:
        futures = self._artifacts[artifact_id]["result_futures"]
        for result_future in futures:
            result_future.set_exception(exc)
        futures.clear()

    def _resolve_artifact_futures(self, artifact_id: str) -> None:
        futures = self._artifacts[artifact_id]["result_futures"]
        for result_future in futures:
            result_future.set_result(None)
        futures.clear()

    def start(self) -> None:
        self._thread.start()

    def is_alive(self) -> bool:
        return self._thread.is_alive()
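

# A minimal usage sketch (illustrative only, not part of the module): the
# `api`, `stats`, and `fs` objects below stand in for the internal_api.Api,
# stats.Stats, and file_stream.FileStreamApi instances the SDK normally
# constructs, and the file path and artifact ID are hypothetical.
#
#     send_queue: "queue.Queue[Event]" = queue.Queue()
#     uploader = StepUpload(api, stats, send_queue, max_threads=4, file_stream=fs)
#     uploader.start()
#
#     # Enqueue a plain run-file upload.
#     send_queue.put(
#         RequestUpload(
#             path="./output.log",
#             save_name=LogicalPath("output.log"),
#             artifact_id=None,
#             md5=None,
#             copied=False,
#             save_fn=None,
#             digest=None,
#         )
#     )
#
#     # For artifact files, a RequestCommitArtifact commits the artifact once
#     # all of its pending uploads have finished; the future resolves (or
#     # fails) when the commit does.
#     done: "concurrent.futures.Future[None]" = concurrent.futures.Future()
#     send_queue.put(
#         RequestCommitArtifact(
#             artifact_id="abc123",
#             finalize=True,
#             before_commit=lambda: None,
#             result_future=done,
#         )
#     )
#
#     # A RequestFinish drains the queue, waits for running jobs, then shuts
#     # the pool down and invokes the callback, if any.
#     send_queue.put(RequestFinish(callback=None))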