"""Public API: runs.""" import json import os import tempfile import time import urllib from typing import ( TYPE_CHECKING, Any, Collection, Dict, List, Literal, Mapping, Optional, ) from wandb_gql import gql import wandb from wandb import env, util from wandb.apis import public from wandb.apis.attrs import Attrs from wandb.apis.internal import Api as InternalApi from wandb.apis.normalize import normalize_exceptions from wandb.apis.paginator import SizedPaginator from wandb.apis.public.const import RETRY_TIMEDELTA from wandb.sdk.lib import ipython, json_util, runid from wandb.sdk.lib.paths import LogicalPath if TYPE_CHECKING: from wandb.apis.public import RetryingClient WANDB_INTERNAL_KEYS = {"_wandb", "wandb_version"} RUN_FRAGMENT = """fragment RunFragment on Run { id tags name displayName sweepName state config group jobType commit readOnly createdAt heartbeatAt description notes systemMetrics summaryMetrics historyLineCount user { name username } historyKeys }""" @normalize_exceptions def _server_provides_internal_id_for_project(client) -> bool: """Returns True if the server allows us to query the internalId field for a project. This check is done by utilizing GraphQL introspection in the available fields on the Project type. """ query_string = """ query ProbeRunInput { RunType: __type(name:"Run") { fields { name } } } """ # Only perform the query once to avoid extra network calls query = gql(query_string) res = client.execute(query) return "projectId" in [ x["name"] for x in (res.get("RunType", {}).get("fields", [{}])) ] class Runs(SizedPaginator["Run"]): """An iterable collection of runs associated with a project and optional filter. This is generally used indirectly via the `Api`.runs method. """ def __init__( self, client: "RetryingClient", entity: str, project: str, filters: Optional[Dict[str, Any]] = None, order: Optional[str] = None, per_page: int = 50, include_sweeps: bool = True, ): self.QUERY = gql( f"""#graphql query Runs($project: String!, $entity: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{ project(name: $project, entityName: $entity) {{ internalId runCount(filters: $filters) readOnly runs(filters: $filters, after: $cursor, first: $perPage, order: $order) {{ edges {{ node {{ {"" if _server_provides_internal_id_for_project(client) else "internalId"} ...RunFragment }} cursor }} pageInfo {{ endCursor hasNextPage }} }} }} }} {RUN_FRAGMENT} """ ) self.entity = entity self.project = project self._project_internal_id = None self.filters = filters or {} self.order = order self._sweeps = {} self._include_sweeps = include_sweeps variables = { "project": self.project, "entity": self.entity, "order": self.order, "filters": json.dumps(self.filters), } super().__init__(client, variables, per_page) @property def length(self): if self.last_response: return self.last_response["project"]["runCount"] else: return None @property def more(self): if self.last_response: return self.last_response["project"]["runs"]["pageInfo"]["hasNextPage"] else: return True @property def cursor(self): if self.last_response: return self.last_response["project"]["runs"]["edges"][-1]["cursor"] else: return None def convert_objects(self): objs = [] if self.last_response is None or self.last_response.get("project") is None: raise ValueError("Could not find project {}".format(self.project)) for run_response in self.last_response["project"]["runs"]["edges"]: run = Run( self.client, self.entity, self.project, run_response["node"]["name"], run_response["node"], 
    @normalize_exceptions
    def histories(
        self,
        samples: int = 500,
        keys: Optional[List[str]] = None,
        x_axis: str = "_step",
        format: Literal["default", "pandas", "polars"] = "default",
        stream: Literal["default", "system"] = "default",
    ):
        """Return sampled history metrics for all runs that fit the filter conditions.

        Args:
            samples : (int, optional) The number of samples to return per run
            keys : (list[str], optional) Only return metrics for specific keys
            x_axis : (str, optional) Use this metric as the xAxis; defaults to _step
            format : (Literal, optional) Format to return data in, options are
                "default", "pandas", "polars"
            stream : (Literal, optional) "default" for metrics, "system" for machine metrics

        Returns:
            pandas.DataFrame: If format="pandas", returns a `pandas.DataFrame` of history metrics.
            polars.DataFrame: If format="polars", returns a `polars.DataFrame` of history metrics.
            list of dicts: If format="default", returns a list of dicts containing history metrics
                with a run_id key.
        """
        if format not in ("default", "pandas", "polars"):
            raise ValueError(
                f"Invalid format: {format}. Must be one of 'default', 'pandas', 'polars'"
            )

        histories = []

        if format == "default":
            for run in self:
                history_data = run.history(
                    samples=samples,
                    keys=keys,
                    x_axis=x_axis,
                    pandas=False,
                    stream=stream,
                )
                if not history_data:
                    continue
                for entry in history_data:
                    entry["run_id"] = run.id
                histories.extend(history_data)

            return histories

        if format == "pandas":
            pd = util.get_module(
                "pandas", required="Exporting pandas DataFrame requires pandas"
            )
            for run in self:
                history_data = run.history(
                    samples=samples,
                    keys=keys,
                    x_axis=x_axis,
                    pandas=False,
                    stream=stream,
                )
                if not history_data:
                    continue
                df = pd.DataFrame.from_records(history_data)
                df["run_id"] = run.id
                histories.append(df)
            if not histories:
                return pd.DataFrame()
            combined_df = pd.concat(histories)
            combined_df.reset_index(drop=True, inplace=True)
            # sort columns for consistency
            combined_df = combined_df[(sorted(combined_df.columns))]

            return combined_df

        if format == "polars":
            pl = util.get_module(
                "polars", required="Exporting polars DataFrame requires polars"
            )
            for run in self:
                history_data = run.history(
                    samples=samples,
                    keys=keys,
                    x_axis=x_axis,
                    pandas=False,
                    stream=stream,
                )
                if not history_data:
                    continue
                df = pl.from_records(history_data)
                df = df.with_columns(pl.lit(run.id).alias("run_id"))
                histories.append(df)
            if not histories:
                return pl.DataFrame()
            combined_df = pl.concat(histories, how="vertical")
            # sort columns for consistency
            combined_df = combined_df.select(sorted(combined_df.columns))

            return combined_df

    def __repr__(self):
        return f"<Runs {self.entity}/{self.project}>"
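
# Usage sketch for `Runs.histories` (illustrative only; assumes pandas is
# installed and that the project and the "loss" key below exist):
#
#   runs = wandb.Api().runs("my-entity/my-project")
#   df = runs.histories(samples=100, keys=["loss"], format="pandas")
#   # One row per sampled history record, tagged with a run_id column and
#   # with columns sorted alphabetically for consistency.
#   print(df.groupby("run_id")["loss"].min())
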
class Run(Attrs):
    """A single run associated with an entity and project.

    Attributes:
        tags ([str]): a list of tags associated with the run
        url (str): the url of this run
        id (str): unique identifier for the run (defaults to eight characters)
        name (str): the name of the run
        state (str): one of: running, finished, crashed, killed, preempting, preempted
        config (dict): a dict of hyperparameters associated with the run
        created_at (str): ISO timestamp when the run was started
        system_metrics (dict): the latest system metrics recorded for the run
        summary (dict): A mutable dict-like property that holds the current summary.
            Calling update will persist any changes.
        project (str): the project associated with the run
        entity (str): the name of the entity associated with the run
        project_internal_id (int): the internal id of the project
        user (str): the name of the user who created the run
        path (str): Unique identifier [entity]/[project]/[run_id]
        notes (str): Notes about the run
        read_only (boolean): Whether the run is read-only (not editable)
        history_keys (dict): Keys of the history metrics that have been logged
            with `wandb.log({key: value})`
        metadata (dict): Metadata about the run from wandb-metadata.json
    """

    def __init__(
        self,
        client: "RetryingClient",
        entity: str,
        project: str,
        run_id: str,
        attrs: Optional[Mapping] = None,
        include_sweeps: bool = True,
    ):
        """Initialize a Run object.

        Run is always initialized by calling api.runs() where api is an instance of
        wandb.Api.
        """
        _attrs = attrs or {}
        super().__init__(dict(_attrs))
        self.client = client
        self._entity = entity
        self.project = project
        self._files = {}
        self._base_dir = env.get_dir(tempfile.gettempdir())
        self.id = run_id
        self.sweep = None
        self._include_sweeps = include_sweeps
        self.dir = os.path.join(self._base_dir, *self.path)
        try:
            os.makedirs(self.dir)
        except OSError:
            pass
        self._summary = None
        self._metadata: Optional[Dict[str, Any]] = None
        self._state = _attrs.get("state", "not found")
        self.server_provides_internal_id_field: Optional[bool] = None

        self.load(force=not _attrs)

    @property
    def state(self):
        return self._state

    @property
    def entity(self):
        return self._entity

    @property
    def username(self):
        wandb.termwarn("Run.username is deprecated. Please use Run.entity instead.")
        return self._entity
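
    # Usage sketch for reading a single run (illustrative; the path components
    # are placeholders):
    #
    #   api = wandb.Api()
    #   run = api.run("my-entity/my-project/abc123")
    #   print(run.state)            # e.g. "finished"
    #   print(run.config)           # user config, wandb-internal keys stripped
    #   print(run.summary_metrics)  # parsed summaryMetrics JSON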
    @property
    def storage_id(self):
        # For compatibility with wandb.Run, which has storage IDs
        # in self.storage_id and names in self.id.
        return self._attrs.get("id")

    @property
    def id(self):
        return self._attrs.get("name")

    @id.setter
    def id(self, new_id):
        attrs = self._attrs
        attrs["name"] = new_id
        return new_id

    @property
    def name(self):
        return self._attrs.get("displayName")

    @name.setter
    def name(self, new_name):
        self._attrs["displayName"] = new_name
        return new_name

    @classmethod
    def create(
        cls,
        api,
        run_id=None,
        project=None,
        entity=None,
        state: Literal["running", "pending"] = "running",
    ):
        """Create a run for the given project."""
        run_id = run_id or runid.generate_id()
        project = project or api.settings.get("project") or "uncategorized"
        mutation = gql(
            """
        mutation UpsertBucket($project: String, $entity: String, $name: String!, $state: String) {
            upsertBucket(input: {modelName: $project, entityName: $entity, name: $name, state: $state}) {
                bucket {
                    project {
                        name
                        entity {
                            name
                        }
                    }
                    id
                    name
                }
                inserted
            }
        }
        """
        )
        variables = {
            "entity": entity,
            "project": project,
            "name": run_id,
            "state": state,
        }
        res = api.client.execute(mutation, variable_values=variables)
        res = res["upsertBucket"]["bucket"]
        return Run(
            api.client,
            res["project"]["entity"]["name"],
            res["project"]["name"],
            res["name"],
            {
                "id": res["id"],
                "config": "{}",
                "systemMetrics": "{}",
                "summaryMetrics": "{}",
                "tags": [],
                "description": None,
                "notes": None,
                "state": state,
            },
        )

    def load(self, force=False):
        if self.server_provides_internal_id_field is None:
            # Cache the introspection probe so repeated load() calls don't
            # re-query the server.
            self.server_provides_internal_id_field = (
                _server_provides_internal_id_for_project(self.client)
            )

        query = gql(
            """
        query Run($project: String!, $entity: String!, $name: String!) {{
            project(name: $project, entityName: $entity) {{
                run(name: $name) {{
                    {}
                    ...RunFragment
                }}
            }}
        }}
        {}
        """.format(
                "projectId" if self.server_provides_internal_id_field else "",
                RUN_FRAGMENT,
            )
        )
        if force or not self._attrs:
            response = self._exec(query)
            if (
                response is None
                or response.get("project") is None
                or response["project"].get("run") is None
            ):
                raise ValueError("Could not find run {}".format(self))
            self._attrs = response["project"]["run"]
            self._state = self._attrs["state"]

            if self._include_sweeps and self.sweep_name and not self.sweep:
                # There may be a lot of runs. Don't bother pulling them all
                # just for the sake of this one.
                self.sweep = public.Sweep.get(
                    self.client,
                    self.entity,
                    self.project,
                    self.sweep_name,
                    withRuns=False,
                )

        if "projectId" in self._attrs:
            self._project_internal_id = int(self._attrs["projectId"])
        else:
            self._project_internal_id = None

        try:
            self._attrs["summaryMetrics"] = (
                json.loads(self._attrs["summaryMetrics"])
                if self._attrs.get("summaryMetrics")
                else {}
            )
        except json.decoder.JSONDecodeError:
            # ignore invalid utf-8 or control characters
            self._attrs["summaryMetrics"] = json.loads(
                self._attrs["summaryMetrics"],
                strict=False,
            )
        self._attrs["systemMetrics"] = (
            json.loads(self._attrs["systemMetrics"])
            if self._attrs.get("systemMetrics")
            else {}
        )
        if self._attrs.get("user"):
            self.user = public.User(self.client, self._attrs["user"])
        config_user, config_raw = {}, {}
        for key, value in json.loads(self._attrs.get("config") or "{}").items():
            config = config_raw if key in WANDB_INTERNAL_KEYS else config_user
            if isinstance(value, dict) and "value" in value:
                config[key] = value["value"]
            else:
                config[key] = value
        config_raw.update(config_user)
        self._attrs["config"] = config_user
        self._attrs["rawconfig"] = config_raw
        return self._attrs
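
    # Usage sketch for `Run.create` (illustrative; entity/project names are
    # placeholders). This upserts a placeholder run on the backend without
    # starting a local process:
    #
    #   api = wandb.Api()
    #   run = Run.create(api, project="my-project", entity="my-entity",
    #                    state="pending")
    #   print(run.id, run.state)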
    @normalize_exceptions
    def wait_until_finished(self):
        query = gql(
            """
            query RunState($project: String!, $entity: String!, $name: String!) {
                project(name: $project, entityName: $entity) {
                    run(name: $name) {
                        state
                    }
                }
            }
        """
        )
        while True:
            res = self._exec(query)
            state = res["project"]["run"]["state"]
            if state in ["finished", "crashed", "failed"]:
                self._attrs["state"] = state
                self._state = state
                return
            time.sleep(5)

    @normalize_exceptions
    def update(self):
        """Persist changes to the run object to the wandb backend."""
        mutation = gql(
            """
        mutation UpsertBucket($id: String!, $description: String, $display_name: String, $notes: String, $tags: [String!], $config: JSONString!, $groupName: String, $jobType: String) {{
            upsertBucket(input: {{id: $id, description: $description, displayName: $display_name, notes: $notes, tags: $tags, config: $config, groupName: $groupName, jobType: $jobType}}) {{
                bucket {{
                    ...RunFragment
                }}
            }}
        }}
        {}
        """.format(RUN_FRAGMENT)
        )
        _ = self._exec(
            mutation,
            id=self.storage_id,
            tags=self.tags,
            description=self.description,
            notes=self.notes,
            display_name=self.display_name,
            config=self.json_config,
            groupName=self.group,
            jobType=self.job_type,
        )
        self.summary.update()

    @normalize_exceptions
    def delete(self, delete_artifacts=False):
        """Delete the given run from the wandb backend."""
        mutation = gql(
            """
            mutation DeleteRun(
                $id: ID!,
                {}
            ) {{
                deleteRun(input: {{
                    id: $id,
                    {}
                }}) {{
                    clientMutationId
                }}
            }}
        """.format(
                "$deleteArtifacts: Boolean" if delete_artifacts else "",
                "deleteArtifacts: $deleteArtifacts" if delete_artifacts else "",
            )
        )

        self.client.execute(
            mutation,
            variable_values={
                "id": self.storage_id,
                "deleteArtifacts": delete_artifacts,
            },
        )

    def save(self):
        self.update()

    @property
    def json_config(self):
        config = {}
        if "_wandb" in self.rawconfig:
            config["_wandb"] = {"value": self.rawconfig["_wandb"], "desc": None}
        for k, v in self.config.items():
            config[k] = {"value": v, "desc": None}
        return json.dumps(config)

    def _exec(self, query, **kwargs):
        """Execute a query against the cloud backend."""
        variables = {"entity": self.entity, "project": self.project, "name": self.id}
        variables.update(kwargs)
        return self.client.execute(query, variable_values=variables)

    def _sampled_history(self, keys, x_axis="_step", samples=500):
        spec = {"keys": [x_axis] + keys, "samples": samples}
        query = gql(
            """
            query RunSampledHistory($project: String!, $entity: String!, $name: String!, $specs: [JSONString!]!) {
                project(name: $project, entityName: $entity) {
                    run(name: $name) {
                        sampledHistory(specs: $specs)
                    }
                }
            }
        """
        )

        response = self._exec(query, specs=[json.dumps(spec)])
        # sampledHistory returns one list per spec, we only send one spec
        return response["project"]["run"]["sampledHistory"][0]

    def _full_history(self, samples=500, stream="default"):
        node = "history" if stream == "default" else "events"
        query = gql(
            """
            query RunFullHistory($project: String!, $entity: String!, $name: String!, $samples: Int) {{
                project(name: $project, entityName: $entity) {{
                    run(name: $name) {{
                        {}(samples: $samples)
                    }}
                }}
            }}
        """.format(node)
        )

        response = self._exec(query, samples=samples)
        return [json.loads(line) for line in response["project"]["run"][node]]

    @normalize_exceptions
    def files(self, names=None, per_page=50):
        """Return a file path for each file named.

        Args:
            names (list): names of the requested files, if empty returns all files
            per_page (int): number of results per page.

        Returns:
            A `Files` object, which is an iterator over `File` objects.
        """
        return public.Files(self.client, self, names or [], per_page)
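
    # Usage sketch for `files` (illustrative): download every file the run
    # uploaded into ./downloads, overwriting any local copies:
    #
    #   run = wandb.Api().run("my-entity/my-project/abc123")
    #   for f in run.files():
    #       f.download(root="downloads", replace=True)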
    @normalize_exceptions
    def file(self, name):
        """Return the path of a file with a given name in the run.

        Args:
            name (str): name of requested file.

        Returns:
            A `File` matching the name argument.
        """
        return public.Files(self.client, self, [name])[0]

    @normalize_exceptions
    def upload_file(self, path, root="."):
        """Upload a file.

        Args:
            path (str): name of file to upload.
            root (str): the root path to save the file relative to. For example,
                if you want the file saved in the run as "my_dir/file.txt" and
                you're currently in "my_dir", you would set root to "../".

        Returns:
            A `File` matching the name argument.
        """
        api = InternalApi(
            default_settings={"entity": self.entity, "project": self.project},
            retry_timedelta=RETRY_TIMEDELTA,
        )
        api.set_current_run_id(self.id)
        root = os.path.abspath(root)
        name = os.path.relpath(path, root)
        with open(os.path.join(root, name), "rb") as f:
            api.push({LogicalPath(name): f})
        return public.Files(self.client, self, [name])[0]

    @normalize_exceptions
    def history(
        self, samples=500, keys=None, x_axis="_step", pandas=True, stream="default"
    ):
        """Return sampled history metrics for a run.

        This is simpler and faster than `scan_history` if you are okay with
        the history records being sampled.

        Args:
            samples : (int, optional) The number of samples to return
            pandas : (bool, optional) Return a pandas dataframe
            keys : (list, optional) Only return metrics for specific keys
            x_axis : (str, optional) Use this metric as the xAxis; defaults to _step
            stream : (str, optional) "default" for metrics, "system" for machine metrics

        Returns:
            pandas.DataFrame: If pandas=True returns a `pandas.DataFrame` of history metrics.
            list of dicts: If pandas=False returns a list of dicts of history metrics.
        """
        if keys is not None and not isinstance(keys, list):
            wandb.termerror("keys must be specified in a list")
            return []
        if keys is not None and len(keys) > 0 and not isinstance(keys[0], str):
            wandb.termerror("keys argument must be a list of strings")
            return []

        if keys and stream != "default":
            wandb.termerror("stream must be default when specifying keys")
            return []
        elif keys:
            lines = self._sampled_history(keys=keys, x_axis=x_axis, samples=samples)
        else:
            lines = self._full_history(samples=samples, stream=stream)
        if pandas:
            pd = util.get_module("pandas")
            if pd:
                lines = pd.DataFrame.from_records(lines)
            else:
                wandb.termwarn("Unable to load pandas, call history with pandas=False")
        return lines
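
    # Usage sketch contrasting `history` and `scan_history` (illustrative; the
    # "loss" key is a placeholder): `history` returns at most `samples` rows,
    # while `scan_history` pages through every record.
    #
    #   sampled = run.history(keys=["loss"], samples=100, pandas=False)
    #   complete = [row for row in run.scan_history(keys=["loss"])]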
""" if keys is not None and not isinstance(keys, list): wandb.termerror("keys must be specified in a list") return [] if keys is not None and len(keys) > 0 and not isinstance(keys[0], str): wandb.termerror("keys argument must be a list of strings") return [] last_step = self.lastHistoryStep # set defaults for min/max step if min_step is None: min_step = 0 if max_step is None: max_step = last_step + 1 # if the max step is past the actual last step, clamp it down if max_step > last_step: max_step = last_step + 1 if keys is None: return public.HistoryScan( run=self, client=self.client, page_size=page_size, min_step=min_step, max_step=max_step, ) else: return public.SampledHistoryScan( run=self, client=self.client, keys=keys, page_size=page_size, min_step=min_step, max_step=max_step, ) @normalize_exceptions def logged_artifacts(self, per_page: int = 100) -> public.RunArtifacts: """Fetches all artifacts logged by this run. Retrieves all output artifacts that were logged during the run. Returns a paginated result that can be iterated over or collected into a single list. Args: per_page: Number of artifacts to fetch per API request. Returns: An iterable collection of all Artifact objects logged as outputs during this run. Example: >>> import wandb >>> import tempfile >>> with tempfile.NamedTemporaryFile( ... mode="w", delete=False, suffix=".txt" ... ) as tmp: ... tmp.write("This is a test artifact") ... tmp_path = tmp.name >>> run = wandb.init(project="artifact-example") >>> artifact = wandb.Artifact("test_artifact", type="dataset") >>> artifact.add_file(tmp_path) >>> run.log_artifact(artifact) >>> run.finish() >>> api = wandb.Api() >>> finished_run = api.run(f"{run.entity}/{run.project}/{run.id}") >>> for logged_artifact in finished_run.logged_artifacts(): ... print(logged_artifact.name) test_artifact """ return public.RunArtifacts(self.client, self, mode="logged", per_page=per_page) @normalize_exceptions def used_artifacts(self, per_page: int = 100) -> public.RunArtifacts: """Fetches artifacts explicitly used by this run. Retrieves only the input artifacts that were explicitly declared as used during the run, typically via `run.use_artifact()`. Returns a paginated result that can be iterated over or collected into a single list. Args: per_page: Number of artifacts to fetch per API request. Returns: An iterable collection of Artifact objects explicitly used as inputs in this run. Example: >>> import wandb >>> run = wandb.init(project="artifact-example") >>> run.use_artifact("test_artifact:latest") >>> run.finish() >>> api = wandb.Api() >>> finished_run = api.run(f"{run.entity}/{run.project}/{run.id}") >>> for used_artifact in finished_run.used_artifacts(): ... print(used_artifact.name) test_artifact """ return public.RunArtifacts(self.client, self, mode="used", per_page=per_page) @normalize_exceptions def use_artifact(self, artifact, use_as=None): """Declare an artifact as an input to a run. Args: artifact (`Artifact`): An artifact returned from `wandb.Api().artifact(name)` use_as (string, optional): A string identifying how the artifact is used in the script. Used to easily differentiate artifacts used in a run, when using the beta wandb launch feature's artifact swapping functionality. Returns: A `Artifact` object. 
""" api = InternalApi( default_settings={"entity": self.entity, "project": self.project}, retry_timedelta=RETRY_TIMEDELTA, ) api.set_current_run_id(self.id) if isinstance(artifact, wandb.Artifact) and not artifact.is_draft(): api.use_artifact( artifact.id, use_as=use_as or artifact.name, artifact_entity_name=artifact.entity, artifact_project_name=artifact.project, ) return artifact elif isinstance(artifact, wandb.Artifact) and artifact.is_draft(): raise ValueError( "Only existing artifacts are accepted by this api. " "Manually create one with `wandb artifact put`" ) else: raise ValueError("You must pass a wandb.Api().artifact() to use_artifact") @normalize_exceptions def log_artifact( self, artifact: "wandb.Artifact", aliases: Optional[Collection[str]] = None, tags: Optional[Collection[str]] = None, ): """Declare an artifact as output of a run. Args: artifact (`Artifact`): An artifact returned from `wandb.Api().artifact(name)`. aliases (list, optional): Aliases to apply to this artifact. tags: (list, optional) Tags to apply to this artifact, if any. Returns: A `Artifact` object. """ api = InternalApi( default_settings={"entity": self.entity, "project": self.project}, retry_timedelta=RETRY_TIMEDELTA, ) api.set_current_run_id(self.id) if not isinstance(artifact, wandb.Artifact): raise TypeError("You must pass a wandb.Api().artifact() to use_artifact") if artifact.is_draft(): raise ValueError( "Only existing artifacts are accepted by this api. " "Manually create one with `wandb artifact put`" ) if ( self.entity != artifact.source_entity or self.project != artifact.source_project ): raise ValueError("A run can't log an artifact to a different project.") artifact_collection_name = artifact.source_name.split(":")[0] api.create_artifact( artifact.type, artifact_collection_name, artifact.digest, aliases=aliases, tags=tags, ) return artifact @property def summary(self): if self._summary is None: from wandb.old.summary import HTTPSummary # TODO: fix the outdir issue self._summary = HTTPSummary(self, self.client, summary=self.summary_metrics) return self._summary @property def path(self): return [ urllib.parse.quote_plus(str(self.entity)), urllib.parse.quote_plus(str(self.project)), urllib.parse.quote_plus(str(self.id)), ] @property def url(self): path = self.path path.insert(2, "runs") return self.client.app_url + "/".join(path) @property def metadata(self): if self._metadata is None: try: f = self.file("wandb-metadata.json") session = self.client._client.transport.session response = session.get(f.url, timeout=5) response.raise_for_status() contents = response.content self._metadata = json_util.loads(contents) except: # noqa: E722 # file doesn't exist, or can't be downloaded, or can't be parsed pass return self._metadata @property def lastHistoryStep(self): # noqa: N802 query = gql( """ query RunHistoryKeys($project: String!, $entity: String!, $name: String!) 
    @property
    def lastHistoryStep(self):  # noqa: N802
        query = gql(
            """
            query RunHistoryKeys($project: String!, $entity: String!, $name: String!) {
                project(name: $project, entityName: $entity) {
                    run(name: $name) {
                        historyKeys
                    }
                }
            }
        """
        )
        response = self._exec(query)
        if (
            response is None
            or response.get("project") is None
            or response["project"].get("run") is None
            or response["project"]["run"].get("historyKeys") is None
        ):
            return -1
        history_keys = response["project"]["run"]["historyKeys"]
        return history_keys["lastStep"] if "lastStep" in history_keys else -1

    def to_html(self, height=420, hidden=False):
        """Generate HTML containing an iframe displaying this run."""
        url = self.url + "?jupyter=true"
        style = f"border:none;width:100%;height:{height}px;"
        prefix = ""
        if hidden:
            style += "display:none;"
            prefix = ipython.toggle_button()
        return prefix + f"<iframe src={url!r} style={style!r}></iframe>"

    def _repr_html_(self) -> str:
        return self.to_html()

    def __repr__(self):
        return "<Run {} ({})>".format("/".join(self.path), self.state)
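
# Notebook display note: in Jupyter, evaluating a `Run` at the end of a cell
# renders it as an iframe via `_repr_html_`/`to_html`. A hidden, toggleable
# embed (illustrative):
#
#   from IPython.display import HTML
#   HTML(run.to_html(height=600, hidden=True))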