File size: 2,545 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
"""Artifact manifest."""
from __future__ import annotations
from typing import TYPE_CHECKING, Mapping
from wandb.sdk.internal.internal_api import Api as InternalApi
from wandb.sdk.lib.hashutil import HexMD5
if TYPE_CHECKING:
from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
from wandb.sdk.artifacts.storage_policy import StoragePolicy
class ArtifactManifest:
    """Base class for artifact manifests.

    A manifest holds a storage policy plus a dictionary of entries keyed by
    entry path. `version`, `to_manifest_json` and `digest` raise
    `NotImplementedError` here, so subclasses are expected to provide them;
    `from_manifest_json` dispatches to the direct subclass whose `version()`
    matches the serialized manifest.
    """

    # Maps an entry's path to the entry itself.
    entries: dict[str, ArtifactManifestEntry]

    @classmethod
    def from_manifest_json(
        cls, manifest_json: dict, api: InternalApi | None = None
    ) -> ArtifactManifest:
        """Build a manifest from its JSON form by dispatching on its version.

        Args:
            manifest_json: Serialized manifest; must contain a "version" key.
            api: Optional API object, forwarded unchanged to the subclass.

        Returns:
            Whatever the matching subclass's `from_manifest_json` returns.

        Raises:
            ValueError: If "version" is missing, or no direct subclass
                reports that version.
        """
        if "version" not in manifest_json:
            raise ValueError("Invalid manifest format. Must contain version field.")

        version = manifest_json["version"]
        # Only direct subclasses are consulted (`__subclasses__` is not recursive).
        for sub in cls.__subclasses__():
            if sub.version() == version:
                return sub.from_manifest_json(manifest_json, api=api)
        raise ValueError("Invalid manifest version.")

    @classmethod
    def version(cls) -> int:
        """Return the manifest format version; must be overridden."""
        raise NotImplementedError

    def __init__(
        self,
        storage_policy: StoragePolicy,
        entries: Mapping[str, ArtifactManifestEntry] | None = None,
    ) -> None:
        """Store the storage policy and a private copy of `entries`.

        Args:
            storage_policy: Kept as-is on `self.storage_policy`.
            entries: Optional initial path-to-entry mapping. It is copied
                into a new dict; `None` or an empty mapping yields `{}`.
        """
        self.storage_policy = storage_policy
        self.entries = dict(entries) if entries else {}

    def __len__(self) -> int:
        """Return the number of entries in the manifest."""
        return len(self.entries)

    def to_manifest_json(self) -> dict:
        """Serialize the manifest to a dict; must be overridden."""
        raise NotImplementedError

    def digest(self) -> HexMD5:
        """Return the manifest digest; must be overridden."""
        raise NotImplementedError

    def add_entry(self, entry: ArtifactManifestEntry, overwrite: bool = False) -> None:
        """Add `entry` under its path.

        Args:
            entry: Entry to store, keyed by `entry.path`.
            overwrite: When true, replace any existing entry at that path
                without comparing digests.

        Raises:
            ValueError: If `overwrite` is false and a different entry (by
                digest) already occupies the same path.
        """
        path = entry.path
        if not overwrite:
            prev_entry = self.entries.get(path)
            # Re-adding an entry with an identical digest is allowed.
            if prev_entry is not None and entry.digest != prev_entry.digest:
                raise ValueError(f"Cannot add the same path twice: {path!r}")
        self.entries[path] = entry

    def remove_entry(self, entry: ArtifactManifestEntry) -> None:
        """Remove the entry stored at `entry.path`.

        Raises:
            FileNotFoundError: If no entry exists at that path. The
                underlying lookup error is attached as `__cause__`.
        """
        try:
            del self.entries[entry.path]
        except LookupError as err:
            raise FileNotFoundError(
                f"Cannot remove missing entry: '{entry.path}'"
            ) from err

    def get_entry_by_path(self, path: str) -> ArtifactManifestEntry | None:
        """Return the entry stored at `path`, or `None` if there is none."""
        return self.entries.get(path)

    def get_entries_in_directory(self, directory: str) -> list[ArtifactManifestEntry]:
        """Return every entry whose path starts with `directory + "/"`.

        The match is a plain string-prefix test, so entries in nested
        subdirectories are included, in the dict's iteration order.
        """
        # entries use forward slash even for windows
        prefix = directory + "/"
        return [
            entry for path, entry in self.entries.items() if path.startswith(prefix)
        ]
|