"""Interface base class - Used to send messages to the internal process.
InterfaceBase: The abstract class
InterfaceShared: Common routines for socket and queue based implementations
InterfaceQueue: Use multiprocessing queues to send and receive messages
InterfaceSock: Use socket to send and receive messages
"""
import gzip
import logging
import time
from abc import abstractmethod
from pathlib import Path
from secrets import token_hex
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Literal,
NewType,
Optional,
Tuple,
TypedDict,
Union,
)
from wandb import termwarn
from wandb.proto import wandb_internal_pb2 as pb
from wandb.proto import wandb_telemetry_pb2 as tpb
from wandb.sdk.artifacts.artifact import Artifact
from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
from wandb.sdk.artifacts.staging import get_staging_dir
from wandb.sdk.lib import json_util as json
from wandb.sdk.mailbox import HandleAbandonedError, MailboxHandle
from wandb.util import (
WandBJSONEncoderOld,
get_h5_typename,
json_dumps_safer,
json_dumps_safer_history,
json_friendly,
json_friendly_val,
maybe_compress_summary,
)
from ..data_types.utils import history_dict_to_json, val_to_json
from . import summary_record as sr
MANIFEST_FILE_SIZE_THRESHOLD = 100_000
GlobStr = NewType("GlobStr", str)
PolicyName = Literal["now", "live", "end"]
class FilesDict(TypedDict):
files: Iterable[Tuple[GlobStr, PolicyName]]
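# Illustrative FilesDict value (comment sketch): glob patterns paired with
# upload policies, where "now" uploads immediately, "live" re-uploads the
# file as it changes, and "end" uploads once when the run finishes:
#
#     files_dict: FilesDict = {
#         "files": [
#             (GlobStr("model.h5"), "end"),
#             (GlobStr("logs/*.log"), "live"),
#             (GlobStr("config.yaml"), "now"),
#         ]
#     }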
if TYPE_CHECKING:
from ..wandb_run import Run
logger = logging.getLogger("wandb")
def file_policy_to_enum(policy: "PolicyName") -> "pb.FilesItem.PolicyType.V":
    if policy == "now":
        enum = pb.FilesItem.PolicyType.NOW
    elif policy == "end":
        enum = pb.FilesItem.PolicyType.END
    elif policy == "live":
        enum = pb.FilesItem.PolicyType.LIVE
    else:
        # Guard against an unbound local if an unknown policy slips through.
        raise ValueError(f"Unknown file policy: {policy!r}")
    return enum
def file_enum_to_policy(enum: "pb.FilesItem.PolicyType.V") -> "PolicyName":
    if enum == pb.FilesItem.PolicyType.NOW:
        policy: PolicyName = "now"
    elif enum == pb.FilesItem.PolicyType.END:
        policy = "end"
    elif enum == pb.FilesItem.PolicyType.LIVE:
        policy = "live"
    else:
        # Guard against an unbound local if an unknown enum value slips through.
        raise ValueError(f"Unknown file policy enum: {enum!r}")
    return policy
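# Round-trip sanity check (illustrative comment only):
#
#     assert file_enum_to_policy(file_policy_to_enum("live")) == "live"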
class InterfaceBase:
_drop: bool
def __init__(self) -> None:
self._drop = False
def publish_header(self) -> None:
header = pb.HeaderRecord()
self._publish_header(header)
@abstractmethod
def _publish_header(self, header: pb.HeaderRecord) -> None:
raise NotImplementedError
def deliver_status(self) -> MailboxHandle[pb.Result]:
return self._deliver_status(pb.StatusRequest())
@abstractmethod
def _deliver_status(
self,
status: pb.StatusRequest,
) -> MailboxHandle[pb.Result]:
raise NotImplementedError
def _make_config(
self,
data: Optional[dict] = None,
key: Optional[Union[Tuple[str, ...], str]] = None,
val: Optional[Any] = None,
obj: Optional[pb.ConfigRecord] = None,
) -> pb.ConfigRecord:
config = obj or pb.ConfigRecord()
if data:
for k, v in data.items():
update = config.update.add()
update.key = k
update.value_json = json_dumps_safer(json_friendly(v)[0])
if key:
update = config.update.add()
if isinstance(key, tuple):
for k in key:
update.nested_key.append(k)
else:
update.key = key
update.value_json = json_dumps_safer(json_friendly(val)[0])
return config
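    # Example (comment sketch): _make_config(data={"lr": 0.01}) yields a
    # ConfigRecord with a single update whose key is "lr" and whose
    # value_json is "0.01"; _make_config(key=("optimizer", "name"),
    # val="adam") populates nested_key ("optimizer", "name") instead.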
def _make_run(self, run: "Run") -> pb.RunRecord: # noqa: C901
proto_run = pb.RunRecord()
if run._settings.entity is not None:
proto_run.entity = run._settings.entity
if run._settings.project is not None:
proto_run.project = run._settings.project
if run._settings.run_group is not None:
proto_run.run_group = run._settings.run_group
if run._settings.run_job_type is not None:
proto_run.job_type = run._settings.run_job_type
if run._settings.run_id is not None:
proto_run.run_id = run._settings.run_id
if run._settings.run_name is not None:
proto_run.display_name = run._settings.run_name
if run._settings.run_notes is not None:
proto_run.notes = run._settings.run_notes
if run._settings.run_tags is not None:
for tag in run._settings.run_tags:
proto_run.tags.append(tag)
if run._start_time is not None:
proto_run.start_time.FromMicroseconds(int(run._start_time * 1e6))
if run._starting_step is not None:
proto_run.starting_step = run._starting_step
if run._settings.git_remote_url is not None:
proto_run.git.remote_url = run._settings.git_remote_url
if run._settings.git_commit is not None:
proto_run.git.commit = run._settings.git_commit
if run._settings.sweep_id is not None:
proto_run.sweep_id = run._settings.sweep_id
if run._settings.host:
proto_run.host = run._settings.host
if run._settings.resumed:
proto_run.resumed = run._settings.resumed
if run._settings.fork_from:
run_moment = run._settings.fork_from
proto_run.branch_point.run = run_moment.run
proto_run.branch_point.metric = run_moment.metric
proto_run.branch_point.value = run_moment.value
if run._settings.resume_from:
run_moment = run._settings.resume_from
proto_run.branch_point.run = run_moment.run
proto_run.branch_point.metric = run_moment.metric
proto_run.branch_point.value = run_moment.value
if run._forked:
proto_run.forked = run._forked
if run._config is not None:
config_dict = run._config._as_dict() # type: ignore
self._make_config(data=config_dict, obj=proto_run.config)
if run._telemetry_obj:
proto_run.telemetry.MergeFrom(run._telemetry_obj)
if run._start_runtime:
proto_run.runtime = run._start_runtime
return proto_run
def publish_run(self, run: "Run") -> None:
run_record = self._make_run(run)
self._publish_run(run_record)
@abstractmethod
def _publish_run(self, run: pb.RunRecord) -> None:
raise NotImplementedError
def publish_cancel(self, cancel_slot: str) -> None:
cancel = pb.CancelRequest(cancel_slot=cancel_slot)
self._publish_cancel(cancel)
@abstractmethod
def _publish_cancel(self, cancel: pb.CancelRequest) -> None:
raise NotImplementedError
def publish_config(
self,
data: Optional[dict] = None,
key: Optional[Union[Tuple[str, ...], str]] = None,
val: Optional[Any] = None,
) -> None:
cfg = self._make_config(data=data, key=key, val=val)
self._publish_config(cfg)
@abstractmethod
def _publish_config(self, cfg: pb.ConfigRecord) -> None:
raise NotImplementedError
def publish_metadata(self, metadata: pb.MetadataRequest) -> None:
self._publish_metadata(metadata)
@abstractmethod
def _publish_metadata(self, metadata: pb.MetadataRequest) -> None:
raise NotImplementedError
@abstractmethod
def _publish_metric(self, metric: pb.MetricRecord) -> None:
raise NotImplementedError
def _make_summary_from_dict(self, summary_dict: dict) -> pb.SummaryRecord:
summary = pb.SummaryRecord()
for k, v in summary_dict.items():
update = summary.update.add()
update.key = k
update.value_json = json.dumps(v)
return summary
def _summary_encode(
self,
value: Any,
path_from_root: str,
run: "Run",
) -> dict:
"""Normalize, compress, and encode sub-objects for backend storage.
value: Object to encode.
path_from_root: `str` dot separated string from the top-level summary to the
current `value`.
Returns:
A new tree of dict's with large objects replaced with dictionaries
with "_type" entries that say which type the original data was.
"""
# Constructs a new `dict` tree in `json_value` that discards and/or
# encodes objects that aren't JSON serializable.
if isinstance(value, dict):
json_value = {}
for key, value in value.items(): # noqa: B020
json_value[key] = self._summary_encode(
value,
path_from_root + "." + key,
run=run,
)
return json_value
else:
friendly_value, converted = json_friendly(
val_to_json(run, path_from_root, value, namespace="summary")
)
json_value, compressed = maybe_compress_summary(
friendly_value, get_h5_typename(value)
)
if compressed:
                # TODO(jhr): implement me
pass
# self.write_h5(path_from_root, friendly_value)
return json_value
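    # Example (comment sketch): given {"acc": 0.9, "img": wandb.Image(...)},
    # _summary_encode walks the dict recursively and returns
    # {"acc": 0.9, "img": {"_type": "image-file", ...}}, replacing values
    # that aren't JSON-serializable with "_type"-tagged dictionaries.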
def _make_summary(
self,
summary_record: sr.SummaryRecord,
run: "Run",
) -> pb.SummaryRecord:
pb_summary_record = pb.SummaryRecord()
for item in summary_record.update:
pb_summary_item = pb_summary_record.update.add()
key_length = len(item.key)
assert key_length > 0
if key_length > 1:
pb_summary_item.nested_key.extend(item.key)
else:
pb_summary_item.key = item.key[0]
path_from_root = ".".join(item.key)
json_value = self._summary_encode(
item.value,
path_from_root,
run=run,
)
json_value, _ = json_friendly(json_value) # type: ignore
pb_summary_item.value_json = json.dumps(
json_value,
cls=WandBJSONEncoderOld,
)
for item in summary_record.remove:
pb_summary_item = pb_summary_record.remove.add()
key_length = len(item.key)
assert key_length > 0
if key_length > 1:
pb_summary_item.nested_key.extend(item.key)
else:
pb_summary_item.key = item.key[0]
return pb_summary_record
def publish_summary(
self,
run: "Run",
summary_record: sr.SummaryRecord,
) -> None:
pb_summary_record = self._make_summary(summary_record, run=run)
self._publish_summary(pb_summary_record)
@abstractmethod
def _publish_summary(self, summary: pb.SummaryRecord) -> None:
raise NotImplementedError
def _make_files(self, files_dict: "FilesDict") -> pb.FilesRecord:
files = pb.FilesRecord()
for path, policy in files_dict["files"]:
f = files.files.add()
f.path = path
f.policy = file_policy_to_enum(policy)
return files
def publish_files(self, files_dict: "FilesDict") -> None:
files = self._make_files(files_dict)
self._publish_files(files)
@abstractmethod
def _publish_files(self, files: pb.FilesRecord) -> None:
raise NotImplementedError
def publish_python_packages(self, working_set) -> None:
python_packages = pb.PythonPackagesRequest()
for pkg in working_set:
python_packages.package.add(name=pkg.key, version=pkg.version)
self._publish_python_packages(python_packages)
@abstractmethod
def _publish_python_packages(
self, python_packages: pb.PythonPackagesRequest
) -> None:
raise NotImplementedError
def _make_artifact(self, artifact: "Artifact") -> pb.ArtifactRecord:
proto_artifact = pb.ArtifactRecord()
proto_artifact.type = artifact.type
proto_artifact.name = artifact.name
proto_artifact.client_id = artifact._client_id
proto_artifact.sequence_client_id = artifact._sequence_client_id
proto_artifact.digest = artifact.digest
if artifact.distributed_id:
proto_artifact.distributed_id = artifact.distributed_id
if artifact.description:
proto_artifact.description = artifact.description
if artifact.metadata:
proto_artifact.metadata = json.dumps(json_friendly_val(artifact.metadata))
if artifact._base_id:
proto_artifact.base_id = artifact._base_id
ttl_duration_input = artifact._ttl_duration_seconds_to_gql()
if ttl_duration_input:
proto_artifact.ttl_duration_seconds = ttl_duration_input
proto_artifact.incremental_beta1 = artifact.incremental
self._make_artifact_manifest(artifact.manifest, obj=proto_artifact.manifest)
return proto_artifact
def _make_artifact_manifest(
self,
artifact_manifest: ArtifactManifest,
obj: Optional[pb.ArtifactManifest] = None,
) -> pb.ArtifactManifest:
proto_manifest = obj or pb.ArtifactManifest()
proto_manifest.version = artifact_manifest.version()
proto_manifest.storage_policy = artifact_manifest.storage_policy.name()
# Very large manifests need to be written to file to avoid protobuf size limits.
if len(artifact_manifest) > MANIFEST_FILE_SIZE_THRESHOLD:
path = self._write_artifact_manifest_file(artifact_manifest)
proto_manifest.manifest_file_path = path
return proto_manifest
        for k, v in (artifact_manifest.storage_policy.config() or {}).items():
cfg = proto_manifest.storage_policy_config.add()
cfg.key = k
cfg.value_json = json.dumps(v)
for entry in sorted(artifact_manifest.entries.values(), key=lambda k: k.path):
proto_entry = proto_manifest.contents.add()
proto_entry.path = entry.path
proto_entry.digest = entry.digest
if entry.size:
proto_entry.size = entry.size
if entry.birth_artifact_id:
proto_entry.birth_artifact_id = entry.birth_artifact_id
if entry.ref:
proto_entry.ref = entry.ref
if entry.local_path:
proto_entry.local_path = entry.local_path
proto_entry.skip_cache = entry.skip_cache
for k, v in entry.extra.items():
proto_extra = proto_entry.extra.add()
proto_extra.key = k
proto_extra.value_json = json.dumps(v)
return proto_manifest
def _write_artifact_manifest_file(self, manifest: ArtifactManifest) -> str:
manifest_dir = Path(get_staging_dir()) / "artifact_manifests"
manifest_dir.mkdir(parents=True, exist_ok=True)
# It would be simpler to use `manifest.to_json()`, but that gets very slow for
# large manifests since it encodes the whole thing as a single JSON object.
filename = f"{time.time()}_{token_hex(8)}.manifest_contents.jl.gz"
manifest_file_path = manifest_dir / filename
with gzip.open(manifest_file_path, mode="wt", compresslevel=1) as f:
for entry in manifest.entries.values():
f.write(f"{json.dumps(entry.to_json())}\n")
return str(manifest_file_path)
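    # The file written above is gzipped JSON-lines: one manifest entry per
    # line. Illustrative contents (actual fields come from
    # ArtifactManifestEntry.to_json()):
    #
    #     {"path": "data/train.csv", "digest": "aGFzaDE=", "size": 1048576}
    #     {"path": "data/val.csv", "digest": "aGFzaDI=", "size": 262144}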
def deliver_link_artifact(
self,
artifact: "Artifact",
portfolio_name: str,
aliases: Iterable[str],
entity: Optional[str] = None,
project: Optional[str] = None,
organization: Optional[str] = None,
) -> MailboxHandle[pb.Result]:
link_artifact = pb.LinkArtifactRequest()
if artifact.is_draft():
link_artifact.client_id = artifact._client_id
else:
link_artifact.server_id = artifact.id if artifact.id else ""
link_artifact.portfolio_name = portfolio_name
link_artifact.portfolio_entity = entity or ""
link_artifact.portfolio_organization = organization or ""
link_artifact.portfolio_project = project or ""
link_artifact.portfolio_aliases.extend(aliases)
return self._deliver_link_artifact(link_artifact)
@abstractmethod
def _deliver_link_artifact(
self, link_artifact: pb.LinkArtifactRequest
) -> MailboxHandle[pb.Result]:
raise NotImplementedError
@staticmethod
def _make_partial_source_str(
source: Any, job_info: Dict[str, Any], metadata: Dict[str, Any]
    ) -> bytes:
        """Construct use_artifact.partial.source_info.source as serialized bytes."""
source_type = job_info.get("source_type", "").strip()
if source_type == "artifact":
info_source = job_info.get("source", {})
source.artifact.artifact = info_source.get("artifact", "")
source.artifact.entrypoint.extend(info_source.get("entrypoint", []))
source.artifact.notebook = info_source.get("notebook", False)
build_context = info_source.get("build_context")
if build_context:
source.artifact.build_context = build_context
dockerfile = info_source.get("dockerfile")
if dockerfile:
source.artifact.dockerfile = dockerfile
elif source_type == "repo":
source.git.git_info.remote = metadata.get("git", {}).get("remote", "")
source.git.git_info.commit = metadata.get("git", {}).get("commit", "")
source.git.entrypoint.extend(metadata.get("entrypoint", []))
source.git.notebook = metadata.get("notebook", False)
build_context = metadata.get("build_context")
if build_context:
source.git.build_context = build_context
dockerfile = metadata.get("dockerfile")
if dockerfile:
source.git.dockerfile = dockerfile
elif source_type == "image":
source.image.image = metadata.get("docker", "")
else:
            raise ValueError(f"Invalid source type: {source_type!r}")
        source_str: bytes = source.SerializeToString()
        return source_str
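    # Example (comment sketch) of the job_info dict consumed by the
    # "artifact" branch above (all values illustrative):
    #
    #     {
    #         "source_type": "artifact",
    #         "source": {
    #             "artifact": "entity/project/job-source:v0",
    #             "entrypoint": ["python", "train.py"],
    #             "notebook": False,
    #         },
    #     }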
def _make_proto_use_artifact(
self,
use_artifact: pb.UseArtifactRecord,
job_name: str,
job_info: Dict[str, Any],
metadata: Dict[str, Any],
) -> pb.UseArtifactRecord:
use_artifact.partial.job_name = job_name
use_artifact.partial.source_info._version = job_info.get("_version", "")
use_artifact.partial.source_info.source_type = job_info.get("source_type", "")
use_artifact.partial.source_info.runtime = job_info.get("runtime", "")
src_str = self._make_partial_source_str(
source=use_artifact.partial.source_info.source,
job_info=job_info,
metadata=metadata,
)
        use_artifact.partial.source_info.source.ParseFromString(src_str)
return use_artifact
def publish_use_artifact(
self,
artifact: "Artifact",
) -> None:
assert artifact.id is not None, "Artifact must have an id"
use_artifact = pb.UseArtifactRecord(
id=artifact.id,
type=artifact.type,
name=artifact.name,
)
# TODO(gst): move to internal process
if "_partial" in artifact.metadata:
# Download source info from logged partial job artifact
job_info = {}
try:
path = artifact.get_entry("wandb-job.json").download()
with open(path) as f:
job_info = json.load(f)
            except Exception as e:
                logger.warning(
                    f"Failed to download partial job info from artifact {artifact}: {e}"
                )
                termwarn(
                    f"Failed to download partial job info from artifact {artifact}: {e}"
                )
return
try:
use_artifact = self._make_proto_use_artifact(
use_artifact=use_artifact,
job_name=artifact.name,
job_info=job_info,
metadata=artifact.metadata,
)
except Exception as e:
logger.warning(f"Failed to construct use artifact proto: {e}")
termwarn(f"Failed to construct use artifact proto: {e}")
return
self._publish_use_artifact(use_artifact)
@abstractmethod
def _publish_use_artifact(self, proto_artifact: pb.UseArtifactRecord) -> None:
raise NotImplementedError
def deliver_artifact(
self,
run: "Run",
artifact: "Artifact",
aliases: Iterable[str],
tags: Optional[Iterable[str]] = None,
history_step: Optional[int] = None,
is_user_created: bool = False,
use_after_commit: bool = False,
finalize: bool = True,
) -> MailboxHandle[pb.Result]:
proto_run = self._make_run(run)
proto_artifact = self._make_artifact(artifact)
proto_artifact.run_id = proto_run.run_id
proto_artifact.project = proto_run.project
proto_artifact.entity = proto_run.entity
proto_artifact.user_created = is_user_created
proto_artifact.use_after_commit = use_after_commit
proto_artifact.finalize = finalize
proto_artifact.aliases.extend(aliases or [])
proto_artifact.tags.extend(tags or [])
log_artifact = pb.LogArtifactRequest()
log_artifact.artifact.CopyFrom(proto_artifact)
if history_step is not None:
log_artifact.history_step = history_step
log_artifact.staging_dir = get_staging_dir()
resp = self._deliver_artifact(log_artifact)
return resp
@abstractmethod
def _deliver_artifact(
self,
log_artifact: pb.LogArtifactRequest,
) -> MailboxHandle[pb.Result]:
raise NotImplementedError
def deliver_download_artifact(
self,
artifact_id: str,
download_root: str,
allow_missing_references: bool,
skip_cache: bool,
path_prefix: Optional[str],
) -> MailboxHandle[pb.Result]:
download_artifact = pb.DownloadArtifactRequest()
download_artifact.artifact_id = artifact_id
download_artifact.download_root = download_root
download_artifact.allow_missing_references = allow_missing_references
download_artifact.skip_cache = skip_cache
download_artifact.path_prefix = path_prefix or ""
resp = self._deliver_download_artifact(download_artifact)
return resp
@abstractmethod
def _deliver_download_artifact(
self, download_artifact: pb.DownloadArtifactRequest
) -> MailboxHandle[pb.Result]:
raise NotImplementedError
def publish_artifact(
self,
run: "Run",
artifact: "Artifact",
aliases: Iterable[str],
tags: Optional[Iterable[str]] = None,
is_user_created: bool = False,
use_after_commit: bool = False,
finalize: bool = True,
) -> None:
proto_run = self._make_run(run)
proto_artifact = self._make_artifact(artifact)
proto_artifact.run_id = proto_run.run_id
proto_artifact.project = proto_run.project
proto_artifact.entity = proto_run.entity
proto_artifact.user_created = is_user_created
proto_artifact.use_after_commit = use_after_commit
proto_artifact.finalize = finalize
proto_artifact.aliases.extend(aliases or [])
proto_artifact.tags.extend(tags or [])
self._publish_artifact(proto_artifact)
@abstractmethod
def _publish_artifact(self, proto_artifact: pb.ArtifactRecord) -> None:
raise NotImplementedError
def publish_tbdata(self, log_dir: str, save: bool, root_logdir: str = "") -> None:
tbrecord = pb.TBRecord()
tbrecord.log_dir = log_dir
tbrecord.save = save
tbrecord.root_dir = root_logdir
self._publish_tbdata(tbrecord)
@abstractmethod
def _publish_tbdata(self, tbrecord: pb.TBRecord) -> None:
raise NotImplementedError
@abstractmethod
def _publish_telemetry(self, telem: tpb.TelemetryRecord) -> None:
raise NotImplementedError
def publish_partial_history(
self,
run: "Run",
data: dict,
user_step: int,
step: Optional[int] = None,
flush: Optional[bool] = None,
publish_step: bool = True,
) -> None:
data = history_dict_to_json(run, data, step=user_step, ignore_copy_err=True)
data.pop("_step", None)
        # Add a timestamp to the history request if one is not already
        # present; a timestamp may already have been set by the TensorBoard
        # log logic.
if "_timestamp" not in data:
data["_timestamp"] = time.time()
partial_history = pb.PartialHistoryRequest()
for k, v in data.items():
item = partial_history.item.add()
item.key = k
item.value_json = json_dumps_safer_history(v)
if publish_step and step is not None:
partial_history.step.num = step
if flush is not None:
partial_history.action.flush = flush
self._publish_partial_history(partial_history)
@abstractmethod
def _publish_partial_history(self, history: pb.PartialHistoryRequest) -> None:
raise NotImplementedError
def publish_history(
self,
run: "Run",
data: dict,
step: Optional[int] = None,
publish_step: bool = True,
) -> None:
data = history_dict_to_json(run, data, step=step)
history = pb.HistoryRecord()
if publish_step:
assert step is not None
history.step.num = step
data.pop("_step", None)
for k, v in data.items():
item = history.item.add()
item.key = k
item.value_json = json_dumps_safer_history(v)
self._publish_history(history)
@abstractmethod
def _publish_history(self, history: pb.HistoryRecord) -> None:
raise NotImplementedError
def publish_preempting(self) -> None:
preempt_rec = pb.RunPreemptingRecord()
self._publish_preempting(preempt_rec)
@abstractmethod
def _publish_preempting(self, preempt_rec: pb.RunPreemptingRecord) -> None:
raise NotImplementedError
def publish_output(self, name: str, data: str) -> None:
        if name == "stdout":
            otype = pb.OutputRecord.OutputType.STDOUT
        elif name == "stderr":
            otype = pb.OutputRecord.OutputType.STDERR
        else:
            # otype would be unbound below; warn and drop the line instead.
            termwarn(f"unknown output stream: {name}")
            return
o = pb.OutputRecord(output_type=otype, line=data)
o.timestamp.GetCurrentTime()
self._publish_output(o)
@abstractmethod
def _publish_output(self, outdata: pb.OutputRecord) -> None:
raise NotImplementedError
def publish_output_raw(self, name: str, data: str) -> None:
        if name == "stdout":
            otype = pb.OutputRawRecord.OutputType.STDOUT
        elif name == "stderr":
            otype = pb.OutputRawRecord.OutputType.STDERR
        else:
            # otype would be unbound below; warn and drop the line instead.
            termwarn(f"unknown output stream: {name}")
            return
o = pb.OutputRawRecord(output_type=otype, line=data)
o.timestamp.GetCurrentTime()
self._publish_output_raw(o)
@abstractmethod
def _publish_output_raw(self, outdata: pb.OutputRawRecord) -> None:
raise NotImplementedError
def publish_pause(self) -> None:
pause = pb.PauseRequest()
self._publish_pause(pause)
@abstractmethod
def _publish_pause(self, pause: pb.PauseRequest) -> None:
raise NotImplementedError
def publish_resume(self) -> None:
resume = pb.ResumeRequest()
self._publish_resume(resume)
@abstractmethod
def _publish_resume(self, resume: pb.ResumeRequest) -> None:
raise NotImplementedError
def publish_alert(
self, title: str, text: str, level: str, wait_duration: int
) -> None:
proto_alert = pb.AlertRecord()
proto_alert.title = title
proto_alert.text = text
proto_alert.level = level
proto_alert.wait_duration = wait_duration
self._publish_alert(proto_alert)
@abstractmethod
def _publish_alert(self, alert: pb.AlertRecord) -> None:
raise NotImplementedError
def _make_exit(self, exit_code: Optional[int]) -> pb.RunExitRecord:
exit = pb.RunExitRecord()
if exit_code is not None:
exit.exit_code = exit_code
return exit
def publish_exit(self, exit_code: Optional[int]) -> None:
exit_data = self._make_exit(exit_code)
self._publish_exit(exit_data)
@abstractmethod
def _publish_exit(self, exit_data: pb.RunExitRecord) -> None:
raise NotImplementedError
def publish_keepalive(self) -> None:
keepalive = pb.KeepaliveRequest()
self._publish_keepalive(keepalive)
@abstractmethod
def _publish_keepalive(self, keepalive: pb.KeepaliveRequest) -> None:
raise NotImplementedError
def publish_job_input(
self,
include_paths: List[List[str]],
exclude_paths: List[List[str]],
input_schema: Optional[dict],
run_config: bool = False,
file_path: str = "",
    ) -> MailboxHandle[pb.Result]:
"""Publishes a request to add inputs to the job.
If run_config is True, the wandb.config will be added as a job input.
If file_path is provided, the file at file_path will be added as a job
input.
The paths provided as arguments are sequences of dictionary keys that
specify a path within the wandb.config. If a path is included, the
corresponding field will be treated as a job input. If a path is
excluded, the corresponding field will not be treated as a job input.
Args:
include_paths: paths within config to include as job inputs.
exclude_paths: paths within config to exclude as job inputs.
input_schema: A JSON Schema describing which attributes will be
editable from the Launch drawer.
run_config: bool indicating whether wandb.config is the input source.
file_path: path to file to include as a job input.
"""
if run_config and file_path:
raise ValueError(
"run_config and file_path are mutually exclusive arguments."
)
request = pb.JobInputRequest()
include_records = [pb.JobInputPath(path=path) for path in include_paths]
exclude_records = [pb.JobInputPath(path=path) for path in exclude_paths]
request.include_paths.extend(include_records)
request.exclude_paths.extend(exclude_records)
source = pb.JobInputSource(
run_config=pb.JobInputSource.RunConfigSource(),
)
if run_config:
source.run_config.CopyFrom(pb.JobInputSource.RunConfigSource())
else:
source.file.CopyFrom(
pb.JobInputSource.ConfigFileSource(path=file_path),
)
request.input_source.CopyFrom(source)
if input_schema:
request.input_schema = json_dumps_safer(input_schema)
return self._publish_job_input(request)
@abstractmethod
def _publish_job_input(
self, request: pb.JobInputRequest
) -> MailboxHandle[pb.Result]:
raise NotImplementedError
def join(self) -> None:
# Drop indicates that the internal process has already been shutdown
if self._drop:
return
handle = self._deliver_shutdown()
try:
handle.wait_or(timeout=30)
except TimeoutError:
# This can happen if the server fails to respond due to a bug
# or due to being very busy.
logger.warning("timed out communicating shutdown")
except HandleAbandonedError:
# This can happen if the connection to the server is closed
# before a response is read.
logger.warning("handle abandoned while communicating shutdown")
@abstractmethod
def _deliver_shutdown(self) -> MailboxHandle[pb.Result]:
raise NotImplementedError
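    # Delivery pattern (illustrative comment): deliver_* methods return a
    # MailboxHandle immediately rather than blocking; callers wait on the
    # handle explicitly, e.g.
    #
    #     handle = interface.deliver_poll_exit()
    #     handle.wait_or(timeout=30)  # raises TimeoutError if no response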
def deliver_run(self, run: "Run") -> MailboxHandle[pb.Result]:
run_record = self._make_run(run)
return self._deliver_run(run_record)
def deliver_finish_sync(
self,
) -> MailboxHandle[pb.Result]:
sync = pb.SyncFinishRequest()
return self._deliver_finish_sync(sync)
@abstractmethod
def _deliver_finish_sync(
self, sync: pb.SyncFinishRequest
) -> MailboxHandle[pb.Result]:
raise NotImplementedError
@abstractmethod
def _deliver_run(self, run: pb.RunRecord) -> MailboxHandle[pb.Result]:
raise NotImplementedError
def deliver_run_start(self, run: "Run") -> MailboxHandle[pb.Result]:
run_start = pb.RunStartRequest(run=self._make_run(run))
return self._deliver_run_start(run_start)
@abstractmethod
def _deliver_run_start(
self, run_start: pb.RunStartRequest
) -> MailboxHandle[pb.Result]:
raise NotImplementedError
def deliver_attach(self, attach_id: str) -> MailboxHandle[pb.Result]:
attach = pb.AttachRequest(attach_id=attach_id)
return self._deliver_attach(attach)
@abstractmethod
def _deliver_attach(
self,
status: pb.AttachRequest,
) -> MailboxHandle[pb.Result]:
raise NotImplementedError
def deliver_stop_status(self) -> MailboxHandle[pb.Result]:
status = pb.StopStatusRequest()
return self._deliver_stop_status(status)
@abstractmethod
def _deliver_stop_status(
self,
status: pb.StopStatusRequest,
) -> MailboxHandle[pb.Result]:
raise NotImplementedError
def deliver_network_status(self) -> MailboxHandle[pb.Result]:
status = pb.NetworkStatusRequest()
return self._deliver_network_status(status)
@abstractmethod
def _deliver_network_status(
self,
status: pb.NetworkStatusRequest,
) -> MailboxHandle[pb.Result]:
raise NotImplementedError
def deliver_internal_messages(self) -> MailboxHandle[pb.Result]:
internal_message = pb.InternalMessagesRequest()
return self._deliver_internal_messages(internal_message)
@abstractmethod
def _deliver_internal_messages(
self, internal_message: pb.InternalMessagesRequest
) -> MailboxHandle[pb.Result]:
raise NotImplementedError
def deliver_get_summary(self) -> MailboxHandle[pb.Result]:
get_summary = pb.GetSummaryRequest()
return self._deliver_get_summary(get_summary)
@abstractmethod
def _deliver_get_summary(
self,
get_summary: pb.GetSummaryRequest,
) -> MailboxHandle[pb.Result]:
raise NotImplementedError
def deliver_get_system_metrics(self) -> MailboxHandle[pb.Result]:
get_system_metrics = pb.GetSystemMetricsRequest()
return self._deliver_get_system_metrics(get_system_metrics)
@abstractmethod
def _deliver_get_system_metrics(
self, get_summary: pb.GetSystemMetricsRequest
) -> MailboxHandle[pb.Result]:
raise NotImplementedError
def deliver_get_system_metadata(self) -> MailboxHandle[pb.Result]:
get_system_metadata = pb.GetSystemMetadataRequest()
return self._deliver_get_system_metadata(get_system_metadata)
@abstractmethod
def _deliver_get_system_metadata(
self, get_system_metadata: pb.GetSystemMetadataRequest
) -> MailboxHandle[pb.Result]:
raise NotImplementedError
def deliver_exit(self, exit_code: Optional[int]) -> MailboxHandle[pb.Result]:
exit_data = self._make_exit(exit_code)
return self._deliver_exit(exit_data)
@abstractmethod
def _deliver_exit(
self,
exit_data: pb.RunExitRecord,
) -> MailboxHandle[pb.Result]:
raise NotImplementedError
@abstractmethod
def deliver_operation_stats(self) -> MailboxHandle[pb.Result]:
raise NotImplementedError
def deliver_poll_exit(self) -> MailboxHandle[pb.Result]:
poll_exit = pb.PollExitRequest()
return self._deliver_poll_exit(poll_exit)
@abstractmethod
def _deliver_poll_exit(
self,
poll_exit: pb.PollExitRequest,
) -> MailboxHandle[pb.Result]:
raise NotImplementedError
def deliver_finish_without_exit(self) -> MailboxHandle[pb.Result]:
run_finish_without_exit = pb.RunFinishWithoutExitRequest()
return self._deliver_finish_without_exit(run_finish_without_exit)
@abstractmethod
def _deliver_finish_without_exit(
self, run_finish_without_exit: pb.RunFinishWithoutExitRequest
) -> MailboxHandle[pb.Result]:
raise NotImplementedError
def deliver_request_sampled_history(self) -> MailboxHandle[pb.Result]:
sampled_history = pb.SampledHistoryRequest()
return self._deliver_request_sampled_history(sampled_history)
@abstractmethod
def _deliver_request_sampled_history(
self, sampled_history: pb.SampledHistoryRequest
) -> MailboxHandle[pb.Result]:
raise NotImplementedError
def deliver_request_run_status(self) -> MailboxHandle[pb.Result]:
run_status = pb.RunStatusRequest()
return self._deliver_request_run_status(run_status)
@abstractmethod
def _deliver_request_run_status(
self, run_status: pb.RunStatusRequest
) -> MailboxHandle[pb.Result]:
raise NotImplementedError