|
"""Public API: runs.""" |
|
|
|
import json |
|
import os |
|
import tempfile |
|
import time |
|
import urllib |
|
from typing import ( |
|
TYPE_CHECKING, |
|
Any, |
|
Collection, |
|
Dict, |
|
List, |
|
Literal, |
|
Mapping, |
|
Optional, |
|
) |
|
|
|
from wandb_gql import gql |
|
|
|
import wandb |
|
from wandb import env, util |
|
from wandb.apis import public |
|
from wandb.apis.attrs import Attrs |
|
from wandb.apis.internal import Api as InternalApi |
|
from wandb.apis.normalize import normalize_exceptions |
|
from wandb.apis.paginator import SizedPaginator |
|
from wandb.apis.public.const import RETRY_TIMEDELTA |
|
from wandb.sdk.lib import ipython, json_util, runid |
|
from wandb.sdk.lib.paths import LogicalPath |
|
|
|
if TYPE_CHECKING: |
|
from wandb.apis.public import RetryingClient |
|
|
|
# Config keys that `Run.load` routes into `rawconfig` instead of the
# user-facing `config` mapping.
WANDB_INTERNAL_KEYS = {"wandb_version", "_wandb"}

# GraphQL fragment listing the run fields requested by the run queries and
# the update mutation in this module (each appends it to its own document).
RUN_FRAGMENT = """fragment RunFragment on Run {
    id
    tags
    name
    displayName
    sweepName
    state
    config
    group
    jobType
    commit
    readOnly
    createdAt
    heartbeatAt
    description
    notes
    systemMetrics
    summaryMetrics
    historyLineCount
    user {
        name
        username
    }
    historyKeys
}"""
|
|
|
|
|
@normalize_exceptions
def _server_provides_internal_id_for_project(client) -> bool:
    """Return True if the server's `Run` type exposes a `projectId` field.

    Despite the function's name, the probe does not inspect the Project
    type: it sends a GraphQL introspection query for the `Run` type and
    checks whether `projectId` is among the field names that come back.
    Callers in this module use the answer to decide whether `projectId` may
    be added to a run selection.

    Args:
        client: Object whose `execute(query)` returns the decoded response
            mapping.

    Returns:
        True if a field named `projectId` is listed; False otherwise,
        including when the response carries no type or no field list.
    """
    query = gql(
        """
        query ProbeRunInput {
            RunType: __type(name:"Run") {
                fields {
                    name
                }
            }
        }
        """
    )
    res = client.execute(query)

    # Either level may be absent *or* present-but-null, so `or` is used
    # instead of a `.get` default; an empty field list simply means "no".
    run_type = res.get("RunType") or {}
    fields = run_type.get("fields") or []
    return any(field.get("name") == "projectId" for field in fields)
|
|
|
|
|
class Runs(SizedPaginator["Run"]):
    """An iterable collection of runs associated with a project and optional filter.

    This is generally used indirectly via the `Api`.runs method.
    """

    def __init__(
        self,
        client: "RetryingClient",
        entity: str,
        project: str,
        filters: Optional[Dict[str, Any]] = None,
        order: Optional[str] = None,
        per_page: int = 50,
        include_sweeps: bool = True,
    ):
        """Build the paginated runs query and its variables.

        Args:
            client: Client used to execute GraphQL requests.
            entity: Name of the entity that owns the project.
            project: Name of the project whose runs are listed.
            filters: Run filter; sent JSON-encoded as the `filters`
                variable. `None` is treated as `{}`.
            order: Value for the query's `order` variable.
            per_page: Page size handed to the paginator base class.
            include_sweeps: Whether runs that carry a sweep name get their
                `Sweep` object attached.
        """
        # `Run.load` reads `projectId` from a run's attributes, so select it
        # on every node -- but only when the server's Run type has the field,
        # otherwise the whole query would be rejected.
        project_id_field = (
            "projectId" if _server_provides_internal_id_for_project(client) else ""
        )
        self.QUERY = gql(
            f"""#graphql
            query Runs($project: String!, $entity: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{
                project(name: $project, entityName: $entity) {{
                    internalId
                    runCount(filters: $filters)
                    readOnly
                    runs(filters: $filters, after: $cursor, first: $perPage, order: $order) {{
                        edges {{
                            node {{
                                {project_id_field}
                                ...RunFragment
                            }}
                            cursor
                        }}
                        pageInfo {{
                            endCursor
                            hasNextPage
                        }}
                    }}
                }}
            }}
            {RUN_FRAGMENT}
            """
        )

        self.entity = entity
        self.project = project
        self._project_internal_id = None
        self.filters = filters or {}
        self.order = order
        # Sweeps already fetched by convert_objects, keyed by sweep name.
        self._sweeps = {}
        self._include_sweeps = include_sweeps
        variables = {
            "project": self.project,
            "entity": self.entity,
            "order": self.order,
            "filters": json.dumps(self.filters),
        }
        super().__init__(client, variables, per_page)

    @property
    def length(self):
        """Run count reported by the last response, or None before any fetch."""
        if self.last_response:
            return self.last_response["project"]["runCount"]
        else:
            return None

    @property
    def more(self):
        """Whether another page is available; True before the first fetch."""
        if self.last_response:
            return self.last_response["project"]["runs"]["pageInfo"]["hasNextPage"]
        else:
            return True

    @property
    def cursor(self):
        """Cursor of the last edge in the last response, or None before any fetch."""
        if self.last_response:
            return self.last_response["project"]["runs"]["edges"][-1]["cursor"]
        else:
            return None

    def convert_objects(self):
        """Turn the edges of the last response into `Run` objects.

        Returns:
            A list with one `Run` per edge. When sweeps are included, runs
            with a sweep name also get `run.sweep` set, reusing sweeps
            already fetched by this paginator.

        Raises:
            ValueError: If there is no response yet or it has no project.
        """
        objs = []
        if self.last_response is None or self.last_response.get("project") is None:
            raise ValueError("Could not find project {}".format(self.project))
        for run_response in self.last_response["project"]["runs"]["edges"]:
            run = Run(
                self.client,
                self.entity,
                self.project,
                run_response["node"]["name"],
                run_response["node"],
                include_sweeps=self._include_sweeps,
            )
            objs.append(run)

            if self._include_sweeps and run.sweep_name:
                if run.sweep_name in self._sweeps:
                    sweep = self._sweeps[run.sweep_name]
                else:
                    sweep = public.Sweep.get(
                        self.client,
                        self.entity,
                        self.project,
                        run.sweep_name,
                        withRuns=False,
                    )
                    # A failed lookup (None) is cached too, so it is not retried.
                    self._sweeps[run.sweep_name] = sweep

                if sweep is None:
                    continue
                run.sweep = sweep

        return objs

    def _iter_histories(self, samples, keys, x_axis, stream):
        """Yield `(run, records)` for each run whose history is non-empty."""
        for run in self:
            records = run.history(
                samples=samples,
                keys=keys,
                x_axis=x_axis,
                pandas=False,
                stream=stream,
            )
            if records:
                yield run, records

    @normalize_exceptions
    def histories(
        self,
        samples: int = 500,
        keys: Optional[List[str]] = None,
        x_axis: str = "_step",
        format: Literal["default", "pandas", "polars"] = "default",
        stream: Literal["default", "system"] = "default",
    ):
        """Return sampled history metrics for all runs that fit the filters conditions.

        Args:
            samples : (int, optional) The number of samples to return per run
            keys : (list[str], optional) Only return metrics for specific keys
            x_axis : (str, optional) Use this metric as the xAxis defaults to _step
            format : (Literal, optional) Format to return data in, options are "default", "pandas", "polars"
            stream : (Literal, optional) "default" for metrics, "system" for machine metrics

        Returns:
            pandas.DataFrame: If format="pandas", returns a `pandas.DataFrame` of history metrics.
            polars.DataFrame: If format="polars", returns a `polars.DataFrame` of history metrics.
            list of dicts: If format="default", returns a list of dicts containing history metrics with a run_id key.

        Raises:
            ValueError: If `format` is not one of the supported options.
        """
        if format not in ("default", "pandas", "polars"):
            raise ValueError(
                f"Invalid format: {format}. Must be one of 'default', 'pandas', 'polars'"
            )

        histories = []

        if format == "default":
            for run, records in self._iter_histories(samples, keys, x_axis, stream):
                for entry in records:
                    entry["run_id"] = run.id
                histories.extend(records)
            return histories

        if format == "pandas":
            pd = util.get_module(
                "pandas", required="Exporting pandas DataFrame requires pandas"
            )
            for run, records in self._iter_histories(samples, keys, x_axis, stream):
                df = pd.DataFrame.from_records(records)
                df["run_id"] = run.id
                histories.append(df)
            if not histories:
                return pd.DataFrame()
            combined_df = pd.concat(histories)
            combined_df.reset_index(drop=True, inplace=True)
            # Sort the columns so the layout does not depend on run order.
            combined_df = combined_df[(sorted(combined_df.columns))]
            return combined_df

        # Only "polars" is left after the validation above.
        pl = util.get_module(
            "polars", required="Exporting polars DataFrame requires polars"
        )
        for run, records in self._iter_histories(samples, keys, x_axis, stream):
            df = pl.from_records(records)
            df = df.with_columns(pl.lit(run.id).alias("run_id"))
            histories.append(df)
        if not histories:
            return pl.DataFrame()
        combined_df = pl.concat(histories, how="vertical")
        # Sort the columns so the layout does not depend on run order.
        combined_df = combined_df.select(sorted(combined_df.columns))
        return combined_df

    def __repr__(self):
        return f"<Runs {self.entity}/{self.project}>"
|
|
|
|
|
class Run(Attrs):
    """A single run associated with an entity and project.

    Attributes:
        tags ([str]): a list of tags associated with the run
        url (str): the url of this run
        id (str): unique identifier for the run (defaults to eight characters)
        name (str): the name of the run
        state (str): one of: running, finished, crashed, killed, preempting, preempted
        config (dict): a dict of hyperparameters associated with the run
        created_at (str): ISO timestamp when the run was started
        system_metrics (dict): the latest system metrics recorded for the run
        summary (dict): A mutable dict-like property that holds the current summary.
                    Calling update will persist any changes.
        project (str): the project associated with the run
        entity (str): the name of the entity associated with the run
        project_internal_id (int): the internal id of the project
        user (User): the user who created the run
        path (list): Unique identifier, as the url-quoted parts
            [entity, project, run_id]
        notes (str): Notes about the run
        read_only (boolean): Whether the run is editable
        history_keys (str): Keys of the history metrics that have been logged
            with `wandb.log({key: value})`
        metadata (dict): Metadata about the run from wandb-metadata.json,
            or None if it could not be retrieved
    """

    def __init__(
        self,
        client: "RetryingClient",
        entity: str,
        project: str,
        run_id: str,
        attrs: Optional[Mapping] = None,
        include_sweeps: bool = True,
    ):
        """Initialize a Run object.

        Run is always initialized by calling api.runs() where api is an instance of
        wandb.Api.

        Args:
            client: Client used to execute GraphQL requests.
            entity: Name of the entity that owns the run.
            project: Name of the project that contains the run.
            run_id: Identifier of the run (stored as the `name` attribute).
            attrs: Already-fetched run attributes; when empty or omitted the
                run is fetched from the backend.
            include_sweeps: Whether to fetch the run's `Sweep` when the run
                is loaded from the backend.
        """
        _attrs = attrs or {}
        super().__init__(dict(_attrs))
        self.client = client
        self._entity = entity
        self.project = project
        self._files = {}
        self._base_dir = env.get_dir(tempfile.gettempdir())
        self.id = run_id
        self.sweep = None
        self._include_sweeps = include_sweeps
        self.dir = os.path.join(self._base_dir, *self.path)
        try:
            os.makedirs(self.dir)
        except OSError:
            # Best effort: the directory may already exist or be uncreatable.
            pass
        self._summary = None
        self._metadata: Optional[Dict[str, Any]] = None
        self._state = _attrs.get("state", "not found")
        self.server_provides_internal_id_field: Optional[bool] = None

        self.load(force=not _attrs)

    @property
    def state(self):
        """State captured at construction, on fetch, or by `wait_until_finished`."""
        return self._state

    @property
    def entity(self):
        """Name of the entity this run was created with."""
        return self._entity

    @property
    def username(self):
        """Deprecated alias of `entity`; emits a warning on every access."""
        wandb.termwarn("Run.username is deprecated. Please use Run.entity instead.")
        return self._entity

    @property
    def storage_id(self):
        """The server-side `id` attribute (the `id` property is the run's `name`)."""
        return self._attrs.get("id")

    @property
    def id(self):
        """Run identifier, stored under the `name` attribute."""
        return self._attrs.get("name")

    @id.setter
    def id(self, new_id):
        self._attrs["name"] = new_id

    @property
    def name(self):
        """Display name, stored under the `displayName` attribute."""
        return self._attrs.get("displayName")

    @name.setter
    def name(self, new_name):
        self._attrs["displayName"] = new_name

    @classmethod
    def create(
        cls,
        api,
        run_id=None,
        project=None,
        entity=None,
        state: Literal["running", "pending"] = "running",
    ):
        """Create a run for the given project.

        Args:
            api: Object providing `settings` and a GraphQL `client`.
            run_id: Identifier for the new run; generated when omitted.
            project: Project name; falls back to the api's `project`
                setting, then to "uncategorized".
            entity: Entity name forwarded to the mutation.
            state: Initial state sent to the backend.

        Returns:
            A `Run` built from the mutation result with empty config,
            metrics and tags.
        """
        run_id = run_id or runid.generate_id()
        project = project or api.settings.get("project") or "uncategorized"
        mutation = gql(
            """
        mutation UpsertBucket($project: String, $entity: String, $name: String!, $state: String) {
            upsertBucket(input: {modelName: $project, entityName: $entity, name: $name, state: $state}) {
                bucket {
                    project {
                        name
                        entity { name }
                    }
                    id
                    name
                }
                inserted
            }
        }
        """
        )
        variables = {
            "entity": entity,
            "project": project,
            "name": run_id,
            "state": state,
        }
        res = api.client.execute(mutation, variable_values=variables)
        res = res["upsertBucket"]["bucket"]
        return Run(
            api.client,
            res["project"]["entity"]["name"],
            res["project"]["name"],
            res["name"],
            {
                "id": res["id"],
                "config": "{}",
                "systemMetrics": "{}",
                "summaryMetrics": "{}",
                "tags": [],
                "description": None,
                "notes": None,
                "state": state,
            },
        )

    def load(self, force=False):
        """Fetch the run if needed, then decode its JSON-valued attributes.

        The backend is queried only when `force` is true or no attributes
        are held yet. Afterwards `summaryMetrics`, `systemMetrics` and
        `config` are decoded, and `config` is split into the user-facing
        `config` and a `rawconfig` that also keeps the internal keys.
        Values that an earlier call already decoded are left untouched, so
        calling `load()` again is safe.

        Args:
            force: Re-fetch even when attributes are already present.

        Returns:
            The run's attribute mapping.

        Raises:
            ValueError: If the backend does not return the run.
        """
        if force or not self._attrs:
            # Built only when a fetch really happens: the capability probe
            # below is itself a request to the server.
            query = gql(
                """
            query Run($project: String!, $entity: String!, $name: String!) {{
                project(name: $project, entityName: $entity) {{
                    run(name: $name) {{
                        {}
                        ...RunFragment
                    }}
                }}
            }}
            {}
            """.format(
                    "projectId"
                    if _server_provides_internal_id_for_project(self.client)
                    else "",
                    RUN_FRAGMENT,
                )
            )
            response = self._exec(query)
            if (
                response is None
                or response.get("project") is None
                or response["project"].get("run") is None
            ):
                raise ValueError("Could not find run {}".format(self))
            self._attrs = response["project"]["run"]
            self._state = self._attrs["state"]

            if self._include_sweeps and self.sweep_name and not self.sweep:
                # withRuns=False: only the sweep itself is wanted here.
                self.sweep = public.Sweep.get(
                    self.client,
                    self.entity,
                    self.project,
                    self.sweep_name,
                    withRuns=False,
                )

        project_id = self._attrs.get("projectId")
        self._project_internal_id = int(project_id) if project_id is not None else None

        summary = self._attrs.get("summaryMetrics")
        if not summary:
            summary = {}
        elif isinstance(summary, (str, bytes, bytearray)):
            try:
                summary = json.loads(summary)
            except json.decoder.JSONDecodeError:
                # Retry leniently so raw control characters inside strings
                # do not make the whole summary unreadable.
                summary = json.loads(summary, strict=False)
        self._attrs["summaryMetrics"] = summary

        system_metrics = self._attrs.get("systemMetrics")
        if not system_metrics:
            system_metrics = {}
        elif isinstance(system_metrics, (str, bytes, bytearray)):
            system_metrics = json.loads(system_metrics)
        self._attrs["systemMetrics"] = system_metrics

        if self._attrs.get("user"):
            self.user = public.User(self.client, self._attrs["user"])

        config_doc = self._attrs.get("config")
        already_split = isinstance(config_doc, dict) and "rawconfig" in self._attrs
        if not already_split:
            config_user, config_raw = {}, {}
            for key, value in json.loads(config_doc or "{}").items():
                config = config_raw if key in WANDB_INTERNAL_KEYS else config_user
                # Entries shaped like {"value": ...} are unwrapped to the value.
                if isinstance(value, dict) and "value" in value:
                    config[key] = value["value"]
                else:
                    config[key] = value
            config_raw.update(config_user)
            self._attrs["config"] = config_user
            self._attrs["rawconfig"] = config_raw
        return self._attrs

    @normalize_exceptions
    def wait_until_finished(self):
        """Poll the run's state every 5 seconds until it is terminal.

        Returns once the backend reports "finished", "crashed" or "failed",
        after storing that state on the object.
        """
        query = gql(
            """
            query RunState($project: String!, $entity: String!, $name: String!) {
                project(name: $project, entityName: $entity) {
                    run(name: $name) {
                        state
                    }
                }
            }
        """
        )
        # NOTE(review): the class docstring also lists "killed" as a state;
        # it is not treated as terminal here -- confirm whether it should be.
        while True:
            res = self._exec(query)
            state = res["project"]["run"]["state"]
            if state in ["finished", "crashed", "failed"]:
                self._attrs["state"] = state
                self._state = state
                return
            time.sleep(5)

    @normalize_exceptions
    def update(self):
        """Persist changes to the run object to the wandb backend."""
        mutation = gql(
            """
        mutation UpsertBucket($id: String!, $description: String, $display_name: String, $notes: String, $tags: [String!], $config: JSONString!, $groupName: String, $jobType: String) {{
            upsertBucket(input: {{id: $id, description: $description, displayName: $display_name, notes: $notes, tags: $tags, config: $config, groupName: $groupName, jobType: $jobType}}) {{
                bucket {{
                    ...RunFragment
                }}
            }}
        }}
        {}
        """.format(RUN_FRAGMENT)
        )
        _ = self._exec(
            mutation,
            id=self.storage_id,
            tags=self.tags,
            description=self.description,
            notes=self.notes,
            display_name=self.display_name,
            config=self.json_config,
            groupName=self.group,
            jobType=self.job_type,
        )
        # The summary is persisted through its own update call.
        self.summary.update()

    @normalize_exceptions
    def delete(self, delete_artifacts=False):
        """Delete the given run from the wandb backend.

        Args:
            delete_artifacts: When true, the `deleteArtifacts` input is
                declared in the mutation and sent as true.
        """
        # The variable and the input field are spliced in only when
        # requested; otherwise both placeholders are left empty.
        mutation = gql(
            """
            mutation DeleteRun(
                $id: ID!,
                {}
            ) {{
                deleteRun(input: {{
                    id: $id,
                    {}
                }}) {{
                    clientMutationId
                }}
            }}
        """.format(
                "$deleteArtifacts: Boolean" if delete_artifacts else "",
                "deleteArtifacts: $deleteArtifacts" if delete_artifacts else "",
            )
        )

        self.client.execute(
            mutation,
            variable_values={
                "id": self.storage_id,
                "deleteArtifacts": delete_artifacts,
            },
        )

    def save(self):
        """Alias for `update()`."""
        self.update()

    @property
    def json_config(self):
        """The config as a JSON string of `{key: {"value": ..., "desc": None}}`.

        The `_wandb` entry of `rawconfig`, when present, is included as well.
        """
        config = {}
        if "_wandb" in self.rawconfig:
            config["_wandb"] = {"value": self.rawconfig["_wandb"], "desc": None}
        for k, v in self.config.items():
            config[k] = {"value": v, "desc": None}
        return json.dumps(config)

    def _exec(self, query, **kwargs):
        """Execute a query against the cloud backend.

        `entity`, `project` and `name` (the run id) are always supplied as
        variables; `kwargs` add to or override them.
        """
        variables = {"entity": self.entity, "project": self.project, "name": self.id}
        variables.update(kwargs)
        return self.client.execute(query, variable_values=variables)

    def _sampled_history(self, keys, x_axis="_step", samples=500):
        """Return the sampled history for `keys`, with `x_axis` as first key."""
        spec = {"keys": [x_axis] + keys, "samples": samples}
        query = gql(
            """
        query RunSampledHistory($project: String!, $entity: String!, $name: String!, $specs: [JSONString!]!) {
            project(name: $project, entityName: $entity) {
                run(name: $name) { sampledHistory(specs: $specs) }
            }
        }
        """
        )

        response = self._exec(query, specs=[json.dumps(spec)])
        # One spec was sent, so only the first result is of interest.
        return response["project"]["run"]["sampledHistory"][0]

    def _full_history(self, samples=500, stream="default"):
        """Return decoded rows of `history`, or of `events` for other streams."""
        node = "history" if stream == "default" else "events"
        query = gql(
            """
        query RunFullHistory($project: String!, $entity: String!, $name: String!, $samples: Int) {{
            project(name: $project, entityName: $entity) {{
                run(name: $name) {{ {}(samples: $samples) }}
            }}
        }}
        """.format(node)
        )

        response = self._exec(query, samples=samples)
        return [json.loads(line) for line in response["project"]["run"][node]]

    @staticmethod
    def _keys_are_valid(keys) -> bool:
        """Check a `keys` argument, reporting problems via `wandb.termerror`.

        Returns:
            True when `keys` is None or a list containing only strings.
        """
        if keys is None:
            return True
        if not isinstance(keys, list):
            wandb.termerror("keys must be specified in a list")
            return False
        if not all(isinstance(key, str) for key in keys):
            wandb.termerror("keys argument must be a list of strings")
            return False
        return True

    @normalize_exceptions
    def files(self, names=None, per_page=50):
        """Return the files of this run, optionally restricted by name.

        Args:
            names (list): names of the requested files, if empty returns all files
            per_page (int): number of results per page.

        Returns:
            A `Files` object, which is an iterator over `File` objects.
        """
        return public.Files(self.client, self, names or [], per_page)

    @normalize_exceptions
    def file(self, name):
        """Return the file with a given name in the run.

        Args:
            name (str): name of requested file.

        Returns:
            A `File` matching the name argument.
        """
        return public.Files(self.client, self, [name])[0]

    @normalize_exceptions
    def upload_file(self, path, root="."):
        """Upload a file.

        Args:
            path (str): name of file to upload.
            root (str): the root path to save the file relative to.  i.e.
                If you want to have the file saved in the run as "my_dir/file.txt"
                and you're currently in "my_dir" you would set root to "../".

        Returns:
            A `File` matching the name argument.
        """
        api = InternalApi(
            default_settings={"entity": self.entity, "project": self.project},
            retry_timedelta=RETRY_TIMEDELTA,
        )
        api.set_current_run_id(self.id)
        root = os.path.abspath(root)
        # The name stored in the run is the path relative to `root`.
        name = os.path.relpath(path, root)
        with open(os.path.join(root, name), "rb") as f:
            api.push({LogicalPath(name): f})
        return public.Files(self.client, self, [name])[0]

    @normalize_exceptions
    def history(
        self, samples=500, keys=None, x_axis="_step", pandas=True, stream="default"
    ):
        """Return sampled history metrics for a run.

        This is simpler and faster if you are ok with the history records being sampled.

        Args:
            samples : (int, optional) The number of samples to return
            pandas : (bool, optional) Return a pandas dataframe
            keys : (list, optional) Only return metrics for specific keys
            x_axis : (str, optional) Use this metric as the xAxis defaults to _step
            stream : (str, optional) "default" for metrics, "system" for machine metrics

        Returns:
            pandas.DataFrame: If pandas=True returns a `pandas.DataFrame` of history
                metrics.
            list of dicts: If pandas=False returns a list of dicts of history metrics.
            An empty list is returned (after an error message) when `keys`
            is malformed or is combined with a non-default `stream`.
        """
        if not self._keys_are_valid(keys):
            return []
        if keys and stream != "default":
            wandb.termerror("stream must be default when specifying keys")
            return []

        if keys:
            lines = self._sampled_history(keys=keys, x_axis=x_axis, samples=samples)
        else:
            lines = self._full_history(samples=samples, stream=stream)

        if pandas:
            pd = util.get_module("pandas")
            if pd:
                lines = pd.DataFrame.from_records(lines)
            else:
                # Fall back to the plain records when pandas is unavailable.
                wandb.termwarn("Unable to load pandas, call history with pandas=False")
        return lines

    @normalize_exceptions
    def scan_history(self, keys=None, page_size=1000, min_step=None, max_step=None):
        """Returns an iterable collection of all history records for a run.

        Example:
            Export all the loss values for an example run

            ```python
            run = api.run("l2k2/examples-numpy-boston/i0wt6xua")
            history = run.scan_history(keys=["Loss"])
            losses = [row["Loss"] for row in history]
            ```

        Args:
            keys ([str], optional): only fetch these keys, and only fetch rows that have all of keys defined.
            page_size (int, optional): size of pages to fetch from the api.
            min_step (int, optional): the step to start scanning from;
                defaults to 0.
            max_step (int, optional): the upper bound on the steps to scan;
                defaults to the run's last history step + 1, and any value
                above the last step is reduced to that.

        Returns:
            An iterable collection over history records (dict).
        """
        if not self._keys_are_valid(keys):
            return []

        last_step = self.lastHistoryStep

        if min_step is None:
            min_step = 0
        # A missing bound and one past the last recorded step both end up
        # at last_step + 1.
        if max_step is None or max_step > last_step:
            max_step = last_step + 1

        if keys is None:
            return public.HistoryScan(
                run=self,
                client=self.client,
                page_size=page_size,
                min_step=min_step,
                max_step=max_step,
            )
        return public.SampledHistoryScan(
            run=self,
            client=self.client,
            keys=keys,
            page_size=page_size,
            min_step=min_step,
            max_step=max_step,
        )

    @normalize_exceptions
    def logged_artifacts(self, per_page: int = 100) -> public.RunArtifacts:
        """Fetches all artifacts logged by this run.

        Retrieves all output artifacts that were logged during the run. Returns a
        paginated result that can be iterated over or collected into a single list.

        Args:
            per_page: Number of artifacts to fetch per API request.

        Returns:
            An iterable collection of all Artifact objects logged as outputs during this run.

        Example:
            >>> import wandb
            >>> import tempfile
            >>> with tempfile.NamedTemporaryFile(
            ...     mode="w", delete=False, suffix=".txt"
            ... ) as tmp:
            ...     tmp.write("This is a test artifact")
            ...     tmp_path = tmp.name
            >>> run = wandb.init(project="artifact-example")
            >>> artifact = wandb.Artifact("test_artifact", type="dataset")
            >>> artifact.add_file(tmp_path)
            >>> run.log_artifact(artifact)
            >>> run.finish()
            >>> api = wandb.Api()
            >>> finished_run = api.run(f"{run.entity}/{run.project}/{run.id}")
            >>> for logged_artifact in finished_run.logged_artifacts():
            ...     print(logged_artifact.name)
            test_artifact
        """
        return public.RunArtifacts(self.client, self, mode="logged", per_page=per_page)

    @normalize_exceptions
    def used_artifacts(self, per_page: int = 100) -> public.RunArtifacts:
        """Fetches artifacts explicitly used by this run.

        Retrieves only the input artifacts that were explicitly declared as used
        during the run, typically via `run.use_artifact()`. Returns a paginated
        result that can be iterated over or collected into a single list.

        Args:
            per_page: Number of artifacts to fetch per API request.

        Returns:
            An iterable collection of Artifact objects explicitly used as inputs in this run.

        Example:
            >>> import wandb
            >>> run = wandb.init(project="artifact-example")
            >>> run.use_artifact("test_artifact:latest")
            >>> run.finish()
            >>> api = wandb.Api()
            >>> finished_run = api.run(f"{run.entity}/{run.project}/{run.id}")
            >>> for used_artifact in finished_run.used_artifacts():
            ...     print(used_artifact.name)
            test_artifact
        """
        return public.RunArtifacts(self.client, self, mode="used", per_page=per_page)

    @normalize_exceptions
    def use_artifact(self, artifact, use_as=None):
        """Declare an artifact as an input to a run.

        Args:
            artifact (`Artifact`): An artifact returned from
                `wandb.Api().artifact(name)`
            use_as (string, optional): A string identifying
                how the artifact is used in the script. Used
                to easily differentiate artifacts used in a
                run, when using the beta wandb launch
                feature's artifact swapping functionality.

        Returns:
            A `Artifact` object.

        Raises:
            ValueError: If `artifact` is not a `wandb.Artifact`, or is still
                a draft.
        """
        api = InternalApi(
            default_settings={"entity": self.entity, "project": self.project},
            retry_timedelta=RETRY_TIMEDELTA,
        )
        api.set_current_run_id(self.id)

        if not isinstance(artifact, wandb.Artifact):
            raise ValueError("You must pass a wandb.Api().artifact() to use_artifact")
        if artifact.is_draft():
            raise ValueError(
                "Only existing artifacts are accepted by this api. "
                "Manually create one with `wandb artifact put`"
            )

        api.use_artifact(
            artifact.id,
            use_as=use_as or artifact.name,
            artifact_entity_name=artifact.entity,
            artifact_project_name=artifact.project,
        )
        return artifact

    @normalize_exceptions
    def log_artifact(
        self,
        artifact: "wandb.Artifact",
        aliases: Optional[Collection[str]] = None,
        tags: Optional[Collection[str]] = None,
    ):
        """Declare an artifact as output of a run.

        Args:
            artifact (`Artifact`): An artifact returned from
                `wandb.Api().artifact(name)`.
            aliases (list, optional): Aliases to apply to this artifact.
            tags: (list, optional) Tags to apply to this artifact, if any.

        Returns:
            A `Artifact` object.

        Raises:
            TypeError: If `artifact` is not a `wandb.Artifact`.
            ValueError: If the artifact is still a draft, or its source
                entity/project differ from this run's.
        """
        api = InternalApi(
            default_settings={"entity": self.entity, "project": self.project},
            retry_timedelta=RETRY_TIMEDELTA,
        )
        api.set_current_run_id(self.id)

        if not isinstance(artifact, wandb.Artifact):
            raise TypeError("You must pass a wandb.Api().artifact() to log_artifact")
        if artifact.is_draft():
            raise ValueError(
                "Only existing artifacts are accepted by this api. "
                "Manually create one with `wandb artifact put`"
            )
        if (
            self.entity != artifact.source_entity
            or self.project != artifact.source_project
        ):
            raise ValueError("A run can't log an artifact to a different project.")

        # The collection name is the part of the source name before ":".
        artifact_collection_name = artifact.source_name.split(":")[0]
        api.create_artifact(
            artifact.type,
            artifact_collection_name,
            artifact.digest,
            aliases=aliases,
            tags=tags,
        )
        return artifact

    @property
    def summary(self):
        """The run's `HTTPSummary`, created on first access and then reused."""
        if self._summary is None:
            from wandb.old.summary import HTTPSummary

            self._summary = HTTPSummary(self, self.client, summary=self.summary_metrics)
        return self._summary

    @property
    def path(self):
        """A new list `[entity, project, run id]`, each part url-quoted."""
        # Imported here so the submodule is guaranteed to be loaded; a bare
        # `import urllib` does not load `urllib.parse` by itself.
        from urllib.parse import quote_plus

        return [
            quote_plus(str(self.entity)),
            quote_plus(str(self.project)),
            quote_plus(str(self.id)),
        ]

    @property
    def url(self):
        """The client's app url followed by `entity/project/runs/run_id`."""
        entity, project, run_id = self.path
        return self.client.app_url + "/".join([entity, project, "runs", run_id])

    @property
    def metadata(self):
        """Contents of the run's `wandb-metadata.json`, or None if unavailable.

        The file is downloaded on first access (5 second timeout) and the
        decoded result is reused afterwards; failed attempts are retried on
        the next access.
        """
        if self._metadata is None:
            try:
                f = self.file("wandb-metadata.json")
                session = self.client._client.transport.session
                response = session.get(f.url, timeout=5)
                response.raise_for_status()
                contents = response.content
                self._metadata = json_util.loads(contents)
            except Exception:
                # Deliberately best effort: a missing file, a failed
                # download or unparsable content all leave metadata as None.
                # (`Exception` rather than a bare except, so Ctrl-C and
                # interpreter shutdown are not swallowed.)
                pass
        return self._metadata

    @property
    def lastHistoryStep(self):
        """The `lastStep` of the run's history keys, or -1 when unknown.

        Queries the backend on every access.
        """
        query = gql(
            """
        query RunHistoryKeys($project: String!, $entity: String!, $name: String!) {
            project(name: $project, entityName: $entity) {
                run(name: $name) { historyKeys }
            }
        }
        """
        )
        response = self._exec(query)
        if (
            response is None
            or response.get("project") is None
            or response["project"].get("run") is None
            or response["project"]["run"].get("historyKeys") is None
        ):
            return -1
        history_keys = response["project"]["run"]["historyKeys"]
        return history_keys["lastStep"] if "lastStep" in history_keys else -1

    def to_html(self, height=420, hidden=False):
        """Generate HTML containing an iframe displaying this run.

        Args:
            height: Height of the iframe in pixels.
            hidden: When true the iframe is styled `display:none` and is
                preceded by a toggle button.
        """
        url = self.url + "?jupyter=true"
        style = f"border:none;width:100%;height:{height}px;"
        prefix = ""
        if hidden:
            style += "display:none;"
            prefix = ipython.toggle_button()
        return prefix + f"<iframe src={url!r} style={style!r}></iframe>"

    def _repr_html_(self) -> str:
        return self.to_html()

    def __repr__(self):
        return "<Run {} ({})>".format("/".join(self.path), self.state)
|
|