"""Artifact saver.""" from __future__ import annotations import concurrent.futures import json import os import tempfile from typing import TYPE_CHECKING, Awaitable, Sequence import wandb import wandb.filesync.step_prepare from wandb import util from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest from wandb.sdk.lib.hashutil import B64MD5, b64_to_hex_id, md5_file_b64 from wandb.sdk.lib.paths import URIStr if TYPE_CHECKING: from typing import Protocol from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry from wandb.sdk.internal.file_pusher import FilePusher from wandb.sdk.internal.internal_api import Api as InternalApi from wandb.sdk.internal.progress import ProgressFn class SaveFn(Protocol): def __call__( self, entry: ArtifactManifestEntry, progress_callback: ProgressFn ) -> bool: pass class SaveFnAsync(Protocol): def __call__( self, entry: ArtifactManifestEntry, progress_callback: ProgressFn ) -> Awaitable[bool]: pass class ArtifactSaver: _server_artifact: dict | None # TODO better define this dict def __init__( self, api: InternalApi, digest: str, manifest_json: dict, file_pusher: FilePusher, is_user_created: bool = False, ) -> None: self._api = api self._file_pusher = file_pusher self._digest = digest self._manifest = ArtifactManifest.from_manifest_json( manifest_json, api=self._api, ) self._is_user_created = is_user_created self._server_artifact = None def save( self, entity: str, project: str, type: str, name: str, client_id: str, sequence_client_id: str, distributed_id: str | None = None, finalize: bool = True, metadata: dict | None = None, ttl_duration_seconds: int | None = None, description: str | None = None, aliases: Sequence[str] | None = None, tags: Sequence[str] | None = None, use_after_commit: bool = False, incremental: bool = False, history_step: int | None = None, base_id: str | None = None, ) -> dict | None: return self._save_internal( entity, project, type, name, client_id, sequence_client_id, distributed_id, finalize, metadata, ttl_duration_seconds, description, aliases, tags, use_after_commit, incremental, history_step, base_id, ) def _save_internal( self, entity: str, project: str, type: str, name: str, client_id: str, sequence_client_id: str, distributed_id: str | None = None, finalize: bool = True, metadata: dict | None = None, ttl_duration_seconds: int | None = None, description: str | None = None, aliases: Sequence[str] | None = None, tags: Sequence[str] | None = None, use_after_commit: bool = False, incremental: bool = False, history_step: int | None = None, base_id: str | None = None, ) -> dict | None: alias_specs = [] for alias in aliases or []: alias_specs.append({"artifactCollectionName": name, "alias": alias}) tag_specs = [{"tagName": tag} for tag in tags or []] """Returns the server artifact.""" self._server_artifact, latest = self._api.create_artifact( type, name, self._digest, metadata=metadata, ttl_duration_seconds=ttl_duration_seconds, aliases=alias_specs, tags=tag_specs, description=description, is_user_created=self._is_user_created, distributed_id=distributed_id, client_id=client_id, sequence_client_id=sequence_client_id, history_step=history_step, ) assert self._server_artifact is not None # mypy optionality unwrapper artifact_id = self._server_artifact["id"] if base_id is None and latest: base_id = latest["id"] if self._server_artifact["state"] == "COMMITTED": if use_after_commit: self._api.use_artifact( artifact_id, artifact_entity_name=entity, artifact_project_name=project, ) return self._server_artifact if ( self._server_artifact["state"] != "PENDING" # For old servers, see https://github.com/wandb/wandb/pull/6190 and self._server_artifact["state"] != "DELETED" ): raise Exception( 'Unknown artifact state "{}"'.format(self._server_artifact["state"]) ) manifest_type = "FULL" manifest_filename = "wandb_manifest.json" if incremental: manifest_type = "INCREMENTAL" manifest_filename = "wandb_manifest.incremental.json" elif distributed_id: manifest_type = "PATCH" manifest_filename = "wandb_manifest.patch.json" artifact_manifest_id, _ = self._api.create_artifact_manifest( manifest_filename, "", artifact_id, base_artifact_id=base_id, include_upload=False, type=manifest_type, ) step_prepare = wandb.filesync.step_prepare.StepPrepare( self._api, 0.1, 0.01, 1000 ) # TODO: params step_prepare.start() # Upload Artifact "L1" files, the actual artifact contents self._file_pusher.store_manifest_files( self._manifest, artifact_id, lambda entry, progress_callback: self._manifest.storage_policy.store_file( artifact_id, artifact_manifest_id, entry, step_prepare, progress_callback=progress_callback, ), ) def before_commit() -> None: self._resolve_client_id_manifest_references() with tempfile.NamedTemporaryFile("w+", suffix=".json", delete=False) as fp: path = os.path.abspath(fp.name) json.dump(self._manifest.to_manifest_json(), fp, indent=4) digest = md5_file_b64(path) if distributed_id or incremental: # If we're in the distributed flow, we want to update the # patch manifest we created with our finalized digest. _, resp = self._api.update_artifact_manifest( artifact_manifest_id, digest=digest, ) else: # In the regular flow, we can recreate the full manifest with the # updated digest. # # NOTE: We do this for backwards compatibility with older backends # that don't support the 'updateArtifactManifest' API. _, resp = self._api.create_artifact_manifest( manifest_filename, digest, artifact_id, base_artifact_id=base_id, ) # We're duplicating the file upload logic a little, which isn't great. upload_url = resp["uploadUrl"] upload_headers = resp["uploadHeaders"] extra_headers = {} for upload_header in upload_headers: key, val = upload_header.split(":", 1) extra_headers[key] = val with open(path, "rb") as fp2: self._api.upload_file_retry( upload_url, fp2, extra_headers=extra_headers, ) commit_result: concurrent.futures.Future[None] = concurrent.futures.Future() # This will queue the commit. It will only happen after all the file uploads are done self._file_pusher.commit_artifact( artifact_id, finalize=finalize, before_commit=before_commit, result_future=commit_result, ) # Block until all artifact files are uploaded and the # artifact is committed. try: commit_result.result() finally: step_prepare.shutdown() if finalize and use_after_commit: self._api.use_artifact( artifact_id, artifact_entity_name=entity, artifact_project_name=project, ) return self._server_artifact def _resolve_client_id_manifest_references(self) -> None: for entry_path in self._manifest.entries: entry = self._manifest.entries[entry_path] if entry.ref is not None: if entry.ref.startswith("wandb-client-artifact:"): client_id = util.host_from_path(entry.ref) artifact_file_path = util.uri_from_path(entry.ref) artifact_id = self._api._resolve_client_id(client_id) if artifact_id is None: raise RuntimeError(f"Could not resolve client id {client_id}") entry.ref = URIStr( f"wandb-artifact://{b64_to_hex_id(B64MD5(artifact_id))}/{artifact_file_path}" )