File size: 4,672 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
"""Batching file prepare requests to our API."""
import concurrent.futures
import functools
import os
import queue
import shutil
import threading
from typing import TYPE_CHECKING, NamedTuple, Optional, Union, cast
from wandb.filesync import step_upload
from wandb.sdk.lib import filesystem, runid
from wandb.sdk.lib.paths import LogicalPath
if TYPE_CHECKING:
import tempfile
from wandb.filesync import stats
from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
from wandb.sdk.artifacts.artifact_saver import SaveFn
from wandb.sdk.internal import internal_api
class RequestUpload(NamedTuple):
    """Request to upload a single file.

    Consumed by ``StepChecksum._thread_body``, which registers the file's size
    with the stats object and forwards a ``step_upload.RequestUpload`` event.
    """

    # Path of the file on disk. It is the copy source when `copy` is set, and
    # is forwarded unchanged otherwise.
    path: str
    # Name the file is registered under in stats and forwarded downstream; it
    # also forms the suffix of the temporary copy's file name.
    save_name: LogicalPath
    # When true, the file is first copied into the temp directory and the copy
    # (not the original) is what gets forwarded.
    copy: bool
class RequestStoreManifestFiles(NamedTuple):
    """Request to forward the files listed in an artifact manifest.

    Consumed by ``StepChecksum._thread_body``: each manifest entry that has a
    ``local_path`` is registered with stats and forwarded as a
    ``step_upload.RequestUpload`` event.
    """

    # Manifest whose `entries` mapping values are iterated.
    manifest: "ArtifactManifest"
    # Identifier forwarded with every upload event produced for this manifest.
    artifact_id: str
    # Callable bound to each entry with `functools.partial` and forwarded with
    # that entry's upload event.
    save_fn: "SaveFn"
class RequestCommitArtifact(NamedTuple):
    """Request to commit an artifact.

    ``StepChecksum._thread_body`` forwards all four fields, unchanged and in
    this order, to ``step_upload.RequestCommitArtifact``.
    """

    # Identifier of the artifact the request refers to.
    artifact_id: str
    # NOTE(review): presumably whether the artifact is finalized on commit;
    # the semantics live in step_upload, so confirm there.
    finalize: bool
    # NOTE(review): presumably a hook run before the commit; not invoked in
    # this module, so confirm against step_upload.
    before_commit: step_upload.PreCommitFn
    # Future handed through to step_upload; this module never resolves it.
    result_future: "concurrent.futures.Future[None]"
class RequestFinish(NamedTuple):
    """Request that ends the ``StepChecksum`` worker loop.

    On receipt the loop stops and a ``step_upload.RequestFinish`` carrying the
    same callback is placed on the output queue.
    """

    # Forwarded as-is to step_upload.RequestFinish; may be None.
    callback: Optional[step_upload.OnRequestFinishFn]
# Union of every request type that StepChecksum's request queue is typed to
# carry (see the `request_queue` annotation on StepChecksum.__init__).
Event = Union[
    RequestUpload, RequestStoreManifestFiles, RequestCommitArtifact, RequestFinish
]
class StepChecksum:
    """Worker that converts queued requests into ``step_upload`` events.

    A daemon thread drains ``request_queue``; for each request it places the
    matching ``step_upload`` event on ``output_queue``. Files requested with
    ``copy=True`` are duplicated into ``tempdir`` first, and every file that
    is forwarded for upload is registered with ``stats`` beforehand.
    """

    def __init__(
        self,
        api: "internal_api.Api",
        tempdir: "tempfile.TemporaryDirectory",
        request_queue: "queue.Queue[Event]",
        output_queue: "queue.Queue[step_upload.Event]",
        stats: "stats.Stats",
    ) -> None:
        """Store the collaborators and create (but do not start) the worker.

        Args:
            api: API handle; stored on the instance.
            tempdir: Temporary directory that receives copies of files
                requested with ``copy=True``.
            request_queue: Queue the worker reads requests from.
            output_queue: Queue the worker writes ``step_upload`` events to.
            stats: Object whose ``init_file`` is called for each file that
                is forwarded for upload.
        """
        self._stats = stats
        self._output_queue = output_queue
        self._request_queue = request_queue
        self._tempdir = tempdir
        self._api = api
        self._thread = threading.Thread(target=self._thread_body, daemon=True)

    def _thread_body(self) -> None:
        """Process requests until a ``RequestFinish`` is received.

        Raises:
            TypeError: If the queue yields an object that is none of the
                known request types.
        """
        while True:
            request = self._request_queue.get()

            if isinstance(request, RequestUpload):
                self._forward_upload(request)
                continue
            if isinstance(request, RequestStoreManifestFiles):
                self._forward_manifest_files(request)
                continue
            if isinstance(request, RequestCommitArtifact):
                self._forward_commit(request)
                continue
            if not isinstance(request, RequestFinish):
                raise TypeError

            # Tell the downstream consumer to finish too, handing over the
            # caller's callback, and then end this thread.
            self._output_queue.put(step_upload.RequestFinish(request.callback))
            return

    def _stage_copy(self, request: "RequestUpload") -> str:
        """Copy the requested file into the temp directory; return the copy's path."""
        staging_root = self._tempdir.name
        unique_name = f"{runid.generate_id()}-{request.save_name}"
        staged_path = os.path.join(staging_root, unique_name)
        # save_name may contain directory components, so make sure the
        # parent directory of the copy exists.
        filesystem.mkdir_exists_ok(os.path.dirname(staged_path))
        try:
            # Some Linux distros raise when copying large files:
            # https://bugs.python.org/issue43743
            shutil.copy2(request.path, staged_path)
        except OSError:
            # Turn off shutil's sendfile fast path (a module-level flag, so
            # this affects later copies as well) and retry once.
            shutil._USE_CP_SENDFILE = False  # type: ignore[attr-defined]
            shutil.copy2(request.path, staged_path)
        return staged_path

    def _forward_upload(self, request: "RequestUpload") -> None:
        """Register one file with stats and emit its upload event."""
        upload_path = self._stage_copy(request) if request.copy else request.path
        size_in_bytes = os.path.getsize(upload_path)
        self._stats.init_file(request.save_name, size_in_bytes)
        upload_event = step_upload.RequestUpload(
            upload_path,
            request.save_name,
            None,
            None,
            request.copy,
            None,
            None,
        )
        self._output_queue.put(upload_event)

    def _forward_manifest_files(self, request: "RequestStoreManifestFiles") -> None:
        """Emit an upload event for each manifest entry that has a local path."""
        for manifest_entry in request.manifest.entries.values():
            if not manifest_entry.local_path:
                # Entries without a local file are not forwarded.
                continue
            self._stats.init_file(
                manifest_entry.local_path,
                cast(int, manifest_entry.size),
                is_artifact_file=True,
            )
            save_entry = functools.partial(request.save_fn, manifest_entry)
            upload_event = step_upload.RequestUpload(
                manifest_entry.local_path,
                manifest_entry.path,
                request.artifact_id,
                manifest_entry.digest,
                False,
                save_entry,
                manifest_entry.digest,
            )
            self._output_queue.put(upload_event)

    def _forward_commit(self, request: "RequestCommitArtifact") -> None:
        """Pass a commit request through to the output queue unchanged."""
        commit_event = step_upload.RequestCommitArtifact(
            request.artifact_id,
            request.finalize,
            request.before_commit,
            request.result_future,
        )
        self._output_queue.put(commit_event)

    def start(self) -> None:
        """Start the worker thread."""
        self._thread.start()

    def is_alive(self) -> bool:
        """Return whether the worker thread is currently running."""
        return self._thread.is_alive()

    def finish(self) -> None:
        """Enqueue a ``RequestFinish`` with no callback to stop the worker."""
        stop_request = RequestFinish(None)
        self._request_queue.put(stop_request)
|