|
"""Public API: jobs.""" |
|
|
|
import json |
|
import os |
|
import shutil |
|
import time |
|
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional |
|
|
|
from wandb_gql import gql |
|
|
|
import wandb |
|
from wandb import util |
|
from wandb.apis import public |
|
from wandb.apis.normalize import normalize_exceptions |
|
from wandb.errors import CommError |
|
from wandb.sdk.artifacts.artifact_state import ArtifactState |
|
from wandb.sdk.data_types._dtypes import InvalidType, Type, TypeRegistry |
|
from wandb.sdk.launch.errors import LaunchError |
|
from wandb.sdk.launch.utils import ( |
|
LAUNCH_DEFAULT_PROJECT, |
|
_fetch_git_repo, |
|
apply_patch, |
|
convert_jupyter_notebook_to_script, |
|
) |
|
|
|
if TYPE_CHECKING: |
|
from wandb.apis.public import Api, RetryingClient |
|
|
|
|
|
class Job:
    """A launch job loaded from a downloaded ``job`` artifact.

    Constructing a ``Job`` fetches the artifact, reads its ``wandb-job.json``
    and, depending on the ``source_type`` recorded there (``"artifact"``,
    ``"repo"`` or ``"image"``), binds ``configure_launch_project`` to the
    matching configuration method. For any other ``source_type`` the
    attribute is left unset.
    """

    _name: str
    _input_types: Type
    _output_types: Type
    _entity: str
    _project: str
    _entrypoint: List[str]
    _notebook_job: bool
    _partial: bool

    def __init__(self, api: "Api", name, path: Optional[str] = None) -> None:
        """Download the job artifact and parse its ``wandb-job.json``.

        Args:
            api: API object used to look up artifacts. Its
                ``default_entity`` becomes this job's entity.
            name: Name of the job artifact to fetch.
            path: Optional directory to download the artifact into. When
                falsy, the location returned by the artifact's own
                ``download()`` call is used.

        Raises:
            CommError: If looking up the job artifact fails.
        """
        try:
            self._job_artifact = api._artifact(name, type="job")
        except CommError as e:
            # Chain the original error so the underlying cause is not lost.
            raise CommError(f"Job artifact {name} not found") from e
        if path:
            self._fpath = path
            self._job_artifact.download(root=path)
        else:
            self._fpath = self._job_artifact.download()
        self._name = name
        self._api = api
        self._entity = api.default_entity

        with open(os.path.join(self._fpath, "wandb-job.json")) as f:
            self._job_info: Mapping[str, Any] = json.load(f)
        source_info = self._job_info.get("source", {})

        self._notebook_job = source_info.get("notebook", False)
        self._entrypoint = source_info.get("entrypoint")
        self._dockerfile = source_info.get("dockerfile")
        self._build_context = source_info.get("build_context")
        self._base_image = source_info.get("base_image")
        self._args = source_info.get("args")
        self._partial = self._job_info.get("_partial", False)
        self._requirements_file = os.path.join(self._fpath, "requirements.frozen.txt")
        self._input_types = TypeRegistry.type_from_dict(
            self._job_info.get("input_types")
        )
        self._output_types = TypeRegistry.type_from_dict(
            self._job_info.get("output_types")
        )

        # Pick the configuration strategy that matches the job's source.
        source_type = self._job_info.get("source_type")
        if source_type == "artifact":
            self._set_configure_launch_project(self._configure_launch_project_artifact)
        elif source_type == "repo":
            self._set_configure_launch_project(self._configure_launch_project_repo)
        elif source_type == "image":
            self._set_configure_launch_project(self._configure_launch_project_container)

    @property
    def name(self):
        """The job artifact name this job was created with."""
        return self._name

    def _set_configure_launch_project(self, func):
        """Bind ``func`` as this job's ``configure_launch_project`` callable."""
        self.configure_launch_project = func

    def _get_code_artifact(self, artifact_string):
        """Resolve the code artifact referenced by ``artifact_string``.

        Raises:
            LaunchError: If no artifact is found or the artifact is deleted.
        """
        artifact_string, _base_url, is_id = util.parse_artifact_string(artifact_string)
        if is_id:
            code_artifact = wandb.Artifact._from_id(artifact_string, self._api._client)
        else:
            code_artifact = self._api._artifact(name=artifact_string, type="code")
        if code_artifact is None:
            raise LaunchError("No code artifact found")
        if code_artifact.state == ArtifactState.DELETED:
            raise LaunchError(
                f"Job {self.name} references deleted code artifact {code_artifact.name}"
            )
        return code_artifact

    def _configure_launch_project_notebook(self, launch_project):
        """Convert the notebook entry point to a script and set it on the project."""
        new_fname = convert_jupyter_notebook_to_script(
            self._entrypoint[-1], launch_project.project_dir
        )
        # Work on a copy: mutating self._entrypoint in place would make a
        # second configuration pass see the script name instead of the notebook.
        new_entrypoint = list(self._entrypoint)
        new_entrypoint[-1] = new_fname
        launch_project.set_job_entry_point(new_entrypoint)

    def _apply_entrypoint_and_build_settings(self, launch_project):
        """Set the entry point and any recorded docker build settings."""
        if self._notebook_job:
            self._configure_launch_project_notebook(launch_project)
        else:
            launch_project.set_job_entry_point(self._entrypoint)

        if self._dockerfile:
            launch_project.set_job_dockerfile(self._dockerfile)
        if self._build_context:
            launch_project.set_job_build_context(self._build_context)
        if self._base_image:
            launch_project.set_job_base_image(self._base_image)

    def _configure_launch_project_repo(self, launch_project):
        """Configure ``launch_project`` from a git-sourced job.

        Fetches the recorded remote at the recorded commit, applies
        ``diff.patch`` from the job directory when present, copies the frozen
        requirements file and sets the runtime, entry point and build settings.
        """
        git_info = self._job_info.get("source", {}).get("git", {})
        _fetch_git_repo(
            launch_project.project_dir,
            git_info["remote"],
            git_info["commit"],
        )
        patch_path = os.path.join(self._fpath, "diff.patch")
        if os.path.exists(patch_path):
            with open(patch_path) as f:
                apply_patch(f.read(), launch_project.project_dir)
        shutil.copy(self._requirements_file, launch_project.project_dir)
        launch_project.python_version = self._job_info.get("runtime")
        self._apply_entrypoint_and_build_settings(launch_project)

    def _configure_launch_project_artifact(self, launch_project):
        """Configure ``launch_project`` from a code-artifact-sourced job.

        Raises:
            LaunchError: If the job has no source artifact, or the artifact
                cannot be resolved or is deleted.
        """
        artifact_string = self._job_info.get("source", {}).get("artifact")
        if artifact_string is None:
            raise LaunchError(f"Job {self.name} had no source artifact")

        code_artifact = self._get_code_artifact(artifact_string)
        launch_project.python_version = self._job_info.get("runtime")
        shutil.copy(self._requirements_file, launch_project.project_dir)

        code_artifact.download(launch_project.project_dir)

        self._apply_entrypoint_and_build_settings(launch_project)

    def _configure_launch_project_container(self, launch_project):
        """Configure ``launch_project`` from an image-sourced job.

        Raises:
            LaunchError: If the job's source dictionary has no ``image`` key.
        """
        launch_project.docker_image = self._job_info.get("source", {}).get("image")
        if launch_project.docker_image is None:
            raise LaunchError(
                "Job had malformed source dictionary without an image key"
            )
        if self._entrypoint:
            launch_project.set_job_entry_point(self._entrypoint)

    def set_entrypoint(self, entrypoint: List[str]):
        """Replace the entry point recorded for this job."""
        self._entrypoint = entrypoint

    def call(
        self,
        config,
        project=None,
        entity=None,
        queue=None,
        resource="local-container",
        resource_args=None,
        template_variables=None,
        project_queue=None,
        priority=None,
    ):
        """Queue a run of this job via ``launch_add`` and return its result.

        Args:
            config: Mapping of run config values. Draft (unlogged) artifacts
                are rejected.
            project: Project passed to ``launch_add``. Falls back to
                ``self._project`` when set, otherwise ``None``.
            entity: Entity passed to ``launch_add``. Falls back to the
                entity taken from the API object at construction.
            queue: Passed to ``launch_add`` as ``queue_name``.
            resource: Passed through to ``launch_add``.
            resource_args: Passed through to ``launch_add``.
            template_variables: Passed through to ``launch_add``.
            project_queue: Passed through to ``launch_add``.
            priority: Passed through to ``launch_add``.

        Raises:
            ValueError: If ``config`` contains a draft artifact.
            TypeError: If the job is not partial and ``config`` does not
                match the job's input types.
        """
        from wandb.sdk.launch import _launch_add

        run_config = {}
        for key, item in config.items():
            if util._is_artifact_object(item):
                if isinstance(item, wandb.Artifact) and item.is_draft():
                    raise ValueError("Cannot queue jobs with unlogged artifacts")
                run_config[key] = util.artifact_to_json(item)

        # NOTE(review): this update re-inserts the original value for every
        # key, including the artifact keys converted just above. Confirm
        # whether the JSON conversion is meant to survive.
        run_config.update(config)

        assigned_config_type = self._input_types.assign(run_config)
        if self._partial:
            wandb.termwarn(
                "Launching manually created job for the first time, can't verify types"
            )
        elif isinstance(assigned_config_type, InvalidType):
            raise TypeError(self._input_types.explain(run_config))

        queued_run = _launch_add.launch_add(
            job=self._name,
            config={"overrides": {"run_config": run_config}},
            template_variables=template_variables,
            # ``_project`` is only declared on the class and never assigned
            # here; fall back to None rather than raising AttributeError.
            project=project or getattr(self, "_project", None),
            entity=entity or self._entity,
            queue_name=queue,
            resource=resource,
            project_queue=project_queue,
            resource_args=resource_args,
            priority=priority,
        )
        return queued_run
|
|
|
|
|
class QueuedRun:
    """A single queued run associated with an entity and project.

    Call `run = queued_run.wait_until_running()` or
    `run = queued_run.wait_until_finished()` to access the run.
    """

    def __init__(
        self,
        client,
        entity,
        project,
        queue_name,
        run_queue_item_id,
        project_queue=LAUNCH_DEFAULT_PROJECT,
        priority=None,
    ):
        """Store the identifiers needed to look up this queue item.

        Args:
            client: Client used to execute GraphQL requests.
            entity: Entity name used in queue and run lookups.
            project: Project used when constructing the associated run.
            queue_name: Name of the run queue holding this item.
            run_queue_item_id: Id of this item within the queue.
            project_queue: Project name used when querying the queue.
            priority: Stored on the instance as ``priority``.
        """
        self.client = client
        self._entity = entity
        self._project = project
        self._queue_name = queue_name
        self._run_queue_item_id = run_queue_item_id
        self.sweep = None
        self._run = None
        self.project_queue = project_queue
        self.priority = priority

    @property
    def queue_name(self):
        """Name of the queue this item belongs to."""
        return self._queue_name

    @property
    def id(self):
        """Id of this run queue item."""
        return self._run_queue_item_id

    @property
    def project(self):
        """Project the queued run is associated with."""
        return self._project

    @property
    def entity(self):
        """Entity the queued run is associated with."""
        return self._entity

    @property
    def state(self):
        """Lower-cased state of the queue item.

        Raises:
            ValueError: If the queue item cannot be found.
        """
        item = self._get_item()
        if item:
            return item["state"].lower()

        raise ValueError(
            f"Could not find QueuedRunItem associated with id: {self.id} on queue {self.queue_name} at itemId: {self.id}"
        )

    @normalize_exceptions
    def _get_run_queue_item_legacy(self) -> Optional[Dict]:
        """Find this item by listing the queue's items.

        Returns:
            The matching item node, or ``None`` when no listed item has
            this id.
        """
        query = gql(
            """
            query GetRunQueueItem($projectName: String!, $entityName: String!, $runQueue: String!) {
                project(name: $projectName, entityName: $entityName) {
                    runQueue(name:$runQueue) {
                        runQueueItems {
                            edges {
                                node {
                                    id
                                    state
                                    associatedRunId
                                }
                            }
                        }
                    }
                }
            }
            """
        )
        variable_values = {
            "projectName": self.project_queue,
            "entityName": self._entity,
            "runQueue": self.queue_name,
        }
        res = self.client.execute(query, variable_values)

        for item in res["project"]["runQueue"]["runQueueItems"]["edges"]:
            if str(item["node"]["id"]) == str(self.id):
                return item["node"]
        return None

    @normalize_exceptions
    def _get_item(self):
        """Fetch this queue item, falling back to the legacy listing query.

        The direct ``runQueueItem`` lookup is tried first. If it yields
        nothing, or fails with a "Cannot query field" error, the legacy
        query that lists all items is used instead.

        Raises:
            LaunchError: If the direct lookup fails with any other error.
        """
        query = gql(
            """
            query GetRunQueueItem($projectName: String!, $entityName: String!, $runQueue: String!, $itemId: ID!) {
                project(name: $projectName, entityName: $entityName) {
                    runQueue(name: $runQueue) {
                        runQueueItem(id: $itemId) {
                            id
                            state
                            associatedRunId
                        }
                    }
                }
            }
            """
        )
        variable_values = {
            "projectName": self.project_queue,
            "entityName": self._entity,
            "runQueue": self.queue_name,
            "itemId": self.id,
        }
        try:
            res = self.client.execute(query, variable_values)
            if res["project"]["runQueue"].get("runQueueItem") is not None:
                return res["project"]["runQueue"]["runQueueItem"]
        except Exception as e:
            if "Cannot query field" not in str(e):
                # Keep the original exception as the cause for debugging.
                raise LaunchError(f"Unknown exception: {e}") from e

        return self._get_run_queue_item_legacy()

    @normalize_exceptions
    def wait_until_finished(self):
        """Block until the associated run has finished, then return it.

        Waits for the run to start first when no run is attached yet, and
        force-reloads the run after it finishes.
        """
        if not self._run:
            self.wait_until_running()

        self._run.wait_until_finished()

        self._run.load(force=True)
        return self._run

    @normalize_exceptions
    def delete(self, delete_artifacts=False):
        """Delete the given queued run from the wandb backend.

        Args:
            delete_artifacts: Accepted but not used by this method.

        Raises:
            CommError: If the run queue cannot be found.
        """
        query = gql(
            """
            query fetchRunQueuesFromProject($entityName: String!, $projectName: String!, $runQueueName: String!) {
                project(name: $projectName, entityName: $entityName) {
                    runQueue(name: $runQueueName) {
                        id
                    }
                }
            }
            """
        )

        res = self.client.execute(
            query,
            variable_values={
                "entityName": self.entity,
                "projectName": self.project_queue,
                "runQueueName": self.queue_name,
            },
        )

        run_queue = res["project"].get("runQueue")
        if run_queue is None:
            # Without a queue id the delete mutation cannot be issued; fail
            # with an explicit message instead of an unbound local.
            raise CommError(
                f"Could not find run queue {self.queue_name} to delete item {self.id} from"
            )
        queue_id = run_queue["id"]

        mutation = gql(
            """
            mutation DeleteFromRunQueue(
                $queueID: ID!,
                $runQueueItemId: ID!
            ) {
                deleteFromRunQueue(input: {
                    queueID: $queueID
                    runQueueItemId: $runQueueItemId
                }) {
                    success
                    clientMutationId
                }
            }
            """
        )
        self.client.execute(
            mutation,
            variable_values={
                "queueID": queue_id,
                "runQueueItemId": self._run_queue_item_id,
            },
        )

    @normalize_exceptions
    def wait_until_running(self):
        """Poll the queue item until it has an associated run; return that run.

        Returns the cached run immediately when one is already attached.
        Otherwise this loops indefinitely, sleeping between polls, until a
        run can be constructed from the item's ``associatedRunId``.
        """
        if self._run is not None:
            return self._run

        while True:
            time.sleep(2)
            item = self._get_item()
            if item and item["associatedRunId"] is not None:
                try:
                    self._run = public.Run(
                        self.client,
                        self._entity,
                        self.project,
                        item["associatedRunId"],
                        None,
                    )
                    self._run_id = item["associatedRunId"]
                except ValueError as e:
                    # Constructing the run failed; warn and poll again.
                    wandb.termwarn(str(e))
                else:
                    return self._run
            elif item:
                wandb.termlog("Waiting for run to start")

            time.sleep(3)

    def __repr__(self):
        # Closing ">" added so the repr is balanced.
        return f"<QueuedRun {self.queue_name} ({self.id})>"
|
|
|
|
|
# Resource values a run queue may be created for.
RunQueueResourceType = Literal[
    "local-container",
    "local-process",
    "kubernetes",
    "sagemaker",
    "gcp-vertex",
]
# Access values for a run queue.
RunQueueAccessType = Literal["project", "user"]
# Prioritization mode values for a run queue.
RunQueuePrioritizationMode = Literal["DISABLED", "V0"]
|
|
|
|
|
class RunQueue:
    """A run queue belonging to an entity.

    Most properties are loaded lazily: the first access to a value that is
    still ``None`` triggers a GraphQL query and caches the result on the
    instance.
    """

    def __init__(
        self,
        client: "RetryingClient",
        name: str,
        entity: str,
        prioritization_mode: Optional[RunQueuePrioritizationMode] = None,
        _access: Optional[RunQueueAccessType] = None,
        _default_resource_config_id: Optional[int] = None,
        _default_resource_config: Optional[dict] = None,
    ) -> None:
        """Store identifiers and any already-known metadata.

        Args:
            client: Client used to execute GraphQL requests.
            name: Name of the run queue.
            entity: Entity that owns the queue.
            prioritization_mode: Known prioritization mode, if any.
            _access: Known access value, if any.
            _default_resource_config_id: Known default resource config id.
            _default_resource_config: Known default resource config.
        """
        self._name: str = name
        self._client = client
        self._entity = entity
        self._prioritization_mode = prioritization_mode
        self._access = _access
        self._default_resource_config_id = _default_resource_config_id
        self._default_resource_config = _default_resource_config
        self._template_variables = None
        self._type = None
        self._items = None
        self._id = None
        # Must exist before ``external_links`` is read; that property checks
        # this attribute for None before fetching metadata.
        self._external_links = None

    @property
    def name(self):
        """Name of the run queue."""
        return self._name

    @property
    def entity(self):
        """Entity that owns the run queue."""
        return self._entity

    @property
    def prioritization_mode(self) -> RunQueuePrioritizationMode:
        """Prioritization mode, fetched with the queue metadata on first use."""
        if self._prioritization_mode is None:
            self._get_metadata()
        return self._prioritization_mode

    @property
    def access(self) -> RunQueueAccessType:
        """Access value, fetched with the queue metadata on first use."""
        if self._access is None:
            self._get_metadata()
        return self._access

    @property
    def external_links(self) -> Dict[str, str]:
        """External links, fetched with the queue metadata on first use."""
        if self._external_links is None:
            self._get_metadata()
        return self._external_links

    @property
    def type(self) -> RunQueueResourceType:
        """Resource type taken from the queue's default resource config."""
        if self._type is None:
            if self._default_resource_config_id is None:
                self._get_metadata()
            self._get_default_resource_config()
        return self._type

    @property
    def default_resource_config(self):
        """Default resource config, fetched on first use."""
        if self._default_resource_config is None:
            if self._default_resource_config_id is None:
                self._get_metadata()
            self._get_default_resource_config()
        return self._default_resource_config

    @property
    def template_variables(self):
        """Template variables of the default resource config."""
        if self._template_variables is None:
            if self._default_resource_config_id is None:
                self._get_metadata()
            self._get_default_resource_config()
        return self._template_variables

    @property
    def id(self) -> str:
        """Id of the run queue, fetched with the queue metadata on first use."""
        if self._id is None:
            self._get_metadata()
        return self._id

    @property
    def items(self) -> List[QueuedRun]:
        """Up to the first 100 queued runs. Modifying this list will not modify the queue or any enqueued items!"""
        if self._items is None:
            self._get_items()
        return self._items

    @normalize_exceptions
    def delete(self):
        """Delete the run queue from the wandb backend.

        On success the cached id, access, default resource config and items
        are cleared.

        Raises:
            CommError: If the backend does not report success.
        """
        query = gql(
            """
            mutation DeleteRunQueue($id: ID!) {
                deleteRunQueues(input: {queueIDs: [$id]}) {
                    success
                    clientMutationId
                }
            }
            """
        )
        variable_values = {"id": self.id}
        res = self._client.execute(query, variable_values)
        if not res["deleteRunQueues"]["success"]:
            raise CommError(f"Failed to delete run queue {self.name}")

        self._id = None
        self._access = None
        self._default_resource_config_id = None
        self._default_resource_config = None
        self._items = None

    def __repr__(self):
        return f"<RunQueue {self._entity}/{self._name}>"

    @normalize_exceptions
    def _get_metadata(self):
        """Fetch and cache id, access, config id, links and prioritization mode."""
        query = gql(
            """
            query GetRunQueueMetadata($projectName: String!, $entityName: String!, $runQueue: String!) {
                project(name: $projectName, entityName: $entityName) {
                    runQueue(name: $runQueue) {
                        id
                        access
                        defaultResourceConfigID
                        prioritizationMode
                        externalLinks
                    }
                }
            }
            """
        )
        variable_values = {
            "projectName": LAUNCH_DEFAULT_PROJECT,
            "entityName": self._entity,
            "runQueue": self._name,
        }
        res = self._client.execute(query, variable_values)
        run_queue = res["project"]["runQueue"]
        self._id = run_queue["id"]
        self._access = run_queue["access"]
        self._default_resource_config_id = run_queue["defaultResourceConfigID"]
        self._external_links = run_queue["externalLinks"]
        if self._default_resource_config_id is None:
            # No default config recorded for the queue: cache an empty one.
            self._default_resource_config = {}
        self._prioritization_mode = run_queue["prioritizationMode"]

    @normalize_exceptions
    def _get_default_resource_config(self):
        """Fetch and cache the resource type, config and template variables."""
        query = gql(
            """
            query GetDefaultResourceConfig($entityName: String!, $id: ID!) {
                entity(name: $entityName) {
                    defaultResourceConfig(id: $id) {
                        config
                        resource
                        templateVariables {
                            name
                            schema
                        }
                    }
                }
            }
            """
        )
        variable_values = {
            "entityName": self._entity,
            "id": self._default_resource_config_id,
        }
        res = self._client.execute(query, variable_values)
        resource_config = res["entity"]["defaultResourceConfig"]
        self._type = resource_config["resource"]
        self._default_resource_config = resource_config["config"]
        self._template_variables = resource_config["templateVariables"]

    @normalize_exceptions
    def _get_items(self):
        """Fetch up to 100 queue items and cache them as ``QueuedRun`` objects."""
        query = gql(
            """
            query GetRunQueueItems($projectName: String!, $entityName: String!, $runQueue: String!) {
                project(name: $projectName, entityName: $entityName) {
                    runQueue(name: $runQueue) {
                        runQueueItems(first: 100) {
                            edges {
                                node {
                                    id
                                }
                            }
                        }
                    }
                }
            }
            """
        )
        variable_values = {
            "projectName": LAUNCH_DEFAULT_PROJECT,
            "entityName": self._entity,
            "runQueue": self._name,
        }
        res = self._client.execute(query, variable_values)
        self._items = [
            QueuedRun(
                self._client,
                self._entity,
                LAUNCH_DEFAULT_PROJECT,
                self._name,
                item["node"]["id"],
            )
            for item in res["project"]["runQueue"]["runQueueItems"]["edges"]
        ]

    @classmethod
    def create(
        cls,
        name: str,
        resource: "RunQueueResourceType",
        entity: Optional[str] = None,
        prioritization_mode: Optional["RunQueuePrioritizationMode"] = None,
        config: Optional[dict] = None,
        template_variables: Optional[dict] = None,
    ) -> "RunQueue":
        """Create a run queue through a new public ``Api`` instance.

        All arguments are forwarded positionally to ``Api.create_run_queue``.
        """
        # ``Api`` is imported only under TYPE_CHECKING in this module, so it
        # is reached through the ``public`` module imported at runtime.
        public_api = public.Api()
        return public_api.create_run_queue(
            name, resource, entity, prioritization_mode, config, template_variables
        )
|
|