|
"""Artifact saver.""" |
|
|
|
from __future__ import annotations |
|
|
|
import concurrent.futures |
|
import json |
|
import os |
|
import tempfile |
|
from typing import TYPE_CHECKING, Awaitable, Sequence |
|
|
|
import wandb |
|
import wandb.filesync.step_prepare |
|
from wandb import util |
|
from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest |
|
from wandb.sdk.lib.hashutil import B64MD5, b64_to_hex_id, md5_file_b64 |
|
from wandb.sdk.lib.paths import URIStr |
|
|
|
if TYPE_CHECKING:
    # Everything in this branch exists for static type checkers only; none of
    # these names are bound at runtime.
    from typing import Protocol

    from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
    from wandb.sdk.internal.file_pusher import FilePusher
    from wandb.sdk.internal.internal_api import Api as InternalApi
    from wandb.sdk.internal.progress import ProgressFn

    class SaveFn(Protocol):
        """Callable taking a manifest entry and a progress callback, returning a bool."""

        def __call__(
            self, entry: ArtifactManifestEntry, progress_callback: ProgressFn
        ) -> bool: ...

    class SaveFnAsync(Protocol):
        """Same call shape as ``SaveFn``, but the result is an awaitable bool."""

        def __call__(
            self, entry: ArtifactManifestEntry, progress_callback: ProgressFn
        ) -> Awaitable[bool]: ...
|
|
|
|
|
class ArtifactSaver:
    """Create an artifact on the server, upload its files and manifest, and commit it.

    The saver is built from a manifest JSON payload and drives the whole flow
    through the internal API (``api``) and the file pusher (``file_pusher``).
    """

    # Last artifact record returned by ``create_artifact``; ``None`` until a
    # save has been attempted.
    _server_artifact: dict | None

    def __init__(
        self,
        api: InternalApi,
        digest: str,
        manifest_json: dict,
        file_pusher: FilePusher,
        is_user_created: bool = False,
    ) -> None:
        """Store collaborators and parse ``manifest_json`` into an ``ArtifactManifest``.

        Args:
            api: Internal API client used for every server call.
            digest: Artifact digest forwarded to ``create_artifact``.
            manifest_json: Manifest payload handed to
                ``ArtifactManifest.from_manifest_json``.
            file_pusher: Used to store manifest files and to queue the commit.
            is_user_created: Forwarded unchanged to ``create_artifact``.
        """
        self._api = api
        self._file_pusher = file_pusher
        self._digest = digest
        self._manifest = ArtifactManifest.from_manifest_json(
            manifest_json,
            api=self._api,
        )
        self._is_user_created = is_user_created
        self._server_artifact = None

    def save(
        self,
        entity: str,
        project: str,
        type: str,
        name: str,
        client_id: str,
        sequence_client_id: str,
        distributed_id: str | None = None,
        finalize: bool = True,
        metadata: dict | None = None,
        ttl_duration_seconds: int | None = None,
        description: str | None = None,
        aliases: Sequence[str] | None = None,
        tags: Sequence[str] | None = None,
        use_after_commit: bool = False,
        incremental: bool = False,
        history_step: int | None = None,
        base_id: str | None = None,
    ) -> dict | None:
        """Save the artifact; thin public wrapper around ``_save_internal``.

        All arguments are forwarded unchanged; see ``_save_internal`` for how
        each one is used.

        Returns:
            The server artifact record returned by ``create_artifact``.
        """
        # Forward by keyword so a future reordering of either signature cannot
        # silently mis-route an argument.
        return self._save_internal(
            entity=entity,
            project=project,
            type=type,
            name=name,
            client_id=client_id,
            sequence_client_id=sequence_client_id,
            distributed_id=distributed_id,
            finalize=finalize,
            metadata=metadata,
            ttl_duration_seconds=ttl_duration_seconds,
            description=description,
            aliases=aliases,
            tags=tags,
            use_after_commit=use_after_commit,
            incremental=incremental,
            history_step=history_step,
            base_id=base_id,
        )

    def _save_internal(
        self,
        entity: str,
        project: str,
        type: str,
        name: str,
        client_id: str,
        sequence_client_id: str,
        distributed_id: str | None = None,
        finalize: bool = True,
        metadata: dict | None = None,
        ttl_duration_seconds: int | None = None,
        description: str | None = None,
        aliases: Sequence[str] | None = None,
        tags: Sequence[str] | None = None,
        use_after_commit: bool = False,
        incremental: bool = False,
        history_step: int | None = None,
        base_id: str | None = None,
    ) -> dict | None:
        """Returns the server artifact.

        Creates the artifact record, and unless the server already reports it
        as ``COMMITTED``, stores the manifest's files, uploads the manifest
        itself and blocks until the commit has completed.

        Args:
            entity: Passed to ``use_artifact`` as ``artifact_entity_name``.
            project: Passed to ``use_artifact`` as ``artifact_project_name``.
            type: Artifact type passed to ``create_artifact``.
            name: Artifact name; also used as the collection name of each alias.
            client_id: Forwarded to ``create_artifact``.
            sequence_client_id: Forwarded to ``create_artifact``.
            distributed_id: Forwarded to ``create_artifact``. When truthy (and
                ``incremental`` is false) a ``PATCH`` manifest is used.
            finalize: Forwarded to ``commit_artifact``; also gates the
                post-commit ``use_artifact`` call.
            metadata: Forwarded to ``create_artifact``.
            ttl_duration_seconds: Forwarded to ``create_artifact``.
            description: Forwarded to ``create_artifact``.
            aliases: Alias names, each wrapped into an alias spec dict.
            tags: Tag names, each wrapped into a ``{"tagName": ...}`` dict.
            use_after_commit: When true, ``use_artifact`` is called once the
                artifact is committed (or was already committed).
            incremental: When true an ``INCREMENTAL`` manifest is used.
            history_step: Forwarded to ``create_artifact``.
            base_id: Base artifact id for the manifest. Defaults to the id of
                the ``latest`` artifact returned by ``create_artifact``, if any.

        Returns:
            The server artifact record returned by ``create_artifact``.

        Raises:
            Exception: If the server reports a state other than ``COMMITTED``,
                ``PENDING`` or ``DELETED``.
        """
        alias_specs = [
            {"artifactCollectionName": name, "alias": alias} for alias in aliases or []
        ]
        tag_specs = [{"tagName": tag} for tag in tags or []]

        self._server_artifact, latest = self._api.create_artifact(
            type,
            name,
            self._digest,
            metadata=metadata,
            ttl_duration_seconds=ttl_duration_seconds,
            aliases=alias_specs,
            tags=tag_specs,
            description=description,
            is_user_created=self._is_user_created,
            distributed_id=distributed_id,
            client_id=client_id,
            sequence_client_id=sequence_client_id,
            history_step=history_step,
        )

        # Narrows the Optional attribute for type checkers.
        assert self._server_artifact is not None
        artifact_id = self._server_artifact["id"]
        if base_id is None and latest:
            base_id = latest["id"]

        state = self._server_artifact["state"]
        if state == "COMMITTED":
            # Nothing to upload; optionally record usage and return early.
            if use_after_commit:
                self._api.use_artifact(
                    artifact_id,
                    artifact_entity_name=entity,
                    artifact_project_name=project,
                )
            return self._server_artifact
        if state not in ("PENDING", "DELETED"):
            raise Exception(f'Unknown artifact state "{state}"')

        # ``incremental`` takes precedence over ``distributed_id``.
        if incremental:
            manifest_type = "INCREMENTAL"
            manifest_filename = "wandb_manifest.incremental.json"
        elif distributed_id:
            manifest_type = "PATCH"
            manifest_filename = "wandb_manifest.patch.json"
        else:
            manifest_type = "FULL"
            manifest_filename = "wandb_manifest.json"

        # Placeholder manifest: empty digest and no upload requested. The real
        # digest is supplied in ``before_commit`` below.
        artifact_manifest_id, _ = self._api.create_artifact_manifest(
            manifest_filename,
            "",
            artifact_id,
            base_artifact_id=base_id,
            include_upload=False,
            type=manifest_type,
        )

        # NOTE(review): the numeric arguments are StepPrepare tuning values;
        # confirm their meaning against StepPrepare's signature.
        step_prepare = wandb.filesync.step_prepare.StepPrepare(
            self._api, 0.1, 0.01, 1000
        )
        step_prepare.start()

        self._file_pusher.store_manifest_files(
            self._manifest,
            artifact_id,
            lambda entry, progress_callback: self._manifest.storage_policy.store_file(
                artifact_id,
                artifact_manifest_id,
                entry,
                step_prepare,
                progress_callback=progress_callback,
            ),
        )

        def before_commit() -> None:
            """Serialize the manifest to a temp file, register its digest, upload it."""
            self._resolve_client_id_manifest_references()
            # delete=False so the file can be reopened by path after closing.
            # We therefore own its removal, handled in the ``finally`` below.
            fp = tempfile.NamedTemporaryFile("w+", suffix=".json", delete=False)
            path = os.path.abspath(fp.name)
            try:
                with fp:
                    json.dump(self._manifest.to_manifest_json(), fp, indent=4)
                digest = md5_file_b64(path)
                if distributed_id or incremental:
                    # Update the placeholder manifest with the final digest.
                    _, resp = self._api.update_artifact_manifest(
                        artifact_manifest_id,
                        digest=digest,
                    )
                else:
                    # Regular flow: create the manifest again, this time with
                    # the real digest.
                    _, resp = self._api.create_artifact_manifest(
                        manifest_filename,
                        digest,
                        artifact_id,
                        base_artifact_id=base_id,
                    )

                upload_url = resp["uploadUrl"]
                # Each header arrives as "Key:Value"; split on the first colon only.
                extra_headers = {
                    key: val
                    for key, val in (
                        header.split(":", 1) for header in resp["uploadHeaders"]
                    )
                }
                with open(path, "rb") as fp2:
                    self._api.upload_file_retry(
                        upload_url,
                        fp2,
                        extra_headers=extra_headers,
                    )
            finally:
                # Best-effort cleanup: failing to delete a temp file must not
                # fail the commit or mask an upload error.
                try:
                    os.remove(path)
                except OSError:
                    pass

        commit_result: concurrent.futures.Future[None] = concurrent.futures.Future()

        # Hand the commit to the file pusher; ``before_commit`` is passed along
        # for it to invoke, and ``commit_result`` carries the outcome back.
        self._file_pusher.commit_artifact(
            artifact_id,
            finalize=finalize,
            before_commit=before_commit,
            result_future=commit_result,
        )

        # Block until the commit finishes (re-raising any failure), and always
        # stop the step-prepare worker afterwards.
        try:
            commit_result.result()
        finally:
            step_prepare.shutdown()

        if finalize and use_after_commit:
            self._api.use_artifact(
                artifact_id,
                artifact_entity_name=entity,
                artifact_project_name=project,
            )

        return self._server_artifact

    def _resolve_client_id_manifest_references(self) -> None:
        """Rewrite ``wandb-client-artifact:`` refs in the manifest to ``wandb-artifact://`` refs.

        Entries without a ref, or whose ref uses another scheme, are left
        untouched.

        Raises:
            RuntimeError: If the API cannot resolve a referenced client id.
        """
        for entry in self._manifest.entries.values():
            ref = entry.ref
            if ref is None or not ref.startswith("wandb-client-artifact:"):
                continue
            client_id = util.host_from_path(ref)
            artifact_file_path = util.uri_from_path(ref)
            artifact_id = self._api._resolve_client_id(client_id)
            if artifact_id is None:
                raise RuntimeError(f"Could not resolve client id {client_id}")
            entry.ref = URIStr(
                f"wandb-artifact://{b64_to_hex_id(B64MD5(artifact_id))}/{artifact_file_path}"
            )
|
|