"""Public API: runs."""
import json
import os
import tempfile
import time
import urllib.parse
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
List,
Literal,
Mapping,
Optional,
)
from wandb_gql import gql
import wandb
from wandb import env, util
from wandb.apis import public
from wandb.apis.attrs import Attrs
from wandb.apis.internal import Api as InternalApi
from wandb.apis.normalize import normalize_exceptions
from wandb.apis.paginator import SizedPaginator
from wandb.apis.public.const import RETRY_TIMEDELTA
from wandb.sdk.lib import ipython, json_util, runid
from wandb.sdk.lib.paths import LogicalPath
if TYPE_CHECKING:
from wandb.apis.public import RetryingClient
WANDB_INTERNAL_KEYS = {"_wandb", "wandb_version"}
RUN_FRAGMENT = """fragment RunFragment on Run {
id
tags
name
displayName
sweepName
state
config
group
jobType
commit
readOnly
createdAt
heartbeatAt
description
notes
systemMetrics
summaryMetrics
historyLineCount
user {
name
username
}
historyKeys
}"""
@normalize_exceptions
def _server_provides_internal_id_for_project(client) -> bool:
"""Returns True if the server allows us to query the internalId field for a project.
This check is done by utilizing GraphQL introspection in the available fields on the Project type.
"""
query_string = """
query ProbeRunInput {
RunType: __type(name:"Run") {
fields {
name
}
}
}
"""
    # A single introspection query suffices to probe the available fields.
query = gql(query_string)
res = client.execute(query)
return "projectId" in [
x["name"] for x in (res.get("RunType", {}).get("fields", [{}]))
]
class Runs(SizedPaginator["Run"]):
"""An iterable collection of runs associated with a project and optional filter.
    This is generally used indirectly via the `Api.runs` method.
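
    Example (illustrative; the entity, project, and filter values are hypothetical):

    ```python
    import wandb

    api = wandb.Api()
    runs = api.runs(
        "my-team/my-project",
        filters={"state": "finished"},
        order="-created_at",
    )
    for run in runs:
        print(run.id, run.name, run.state)
    ```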
"""
def __init__(
self,
client: "RetryingClient",
entity: str,
project: str,
filters: Optional[Dict[str, Any]] = None,
order: Optional[str] = None,
per_page: int = 50,
include_sweeps: bool = True,
):
self.QUERY = gql(
f"""#graphql
query Runs($project: String!, $entity: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{
project(name: $project, entityName: $entity) {{
internalId
runCount(filters: $filters)
readOnly
runs(filters: $filters, after: $cursor, first: $perPage, order: $order) {{
edges {{
node {{
{"" if _server_provides_internal_id_for_project(client) else "internalId"}
...RunFragment
}}
cursor
}}
pageInfo {{
endCursor
hasNextPage
}}
}}
}}
}}
{RUN_FRAGMENT}
"""
)
self.entity = entity
self.project = project
self._project_internal_id = None
self.filters = filters or {}
self.order = order
self._sweeps = {}
self._include_sweeps = include_sweeps
variables = {
"project": self.project,
"entity": self.entity,
"order": self.order,
"filters": json.dumps(self.filters),
}
super().__init__(client, variables, per_page)
@property
def length(self):
if self.last_response:
return self.last_response["project"]["runCount"]
else:
return None
@property
def more(self):
if self.last_response:
return self.last_response["project"]["runs"]["pageInfo"]["hasNextPage"]
else:
return True
@property
def cursor(self):
if self.last_response:
return self.last_response["project"]["runs"]["edges"][-1]["cursor"]
else:
return None
def convert_objects(self):
objs = []
if self.last_response is None or self.last_response.get("project") is None:
raise ValueError("Could not find project {}".format(self.project))
for run_response in self.last_response["project"]["runs"]["edges"]:
run = Run(
self.client,
self.entity,
self.project,
run_response["node"]["name"],
run_response["node"],
include_sweeps=self._include_sweeps,
)
objs.append(run)
if self._include_sweeps and run.sweep_name:
if run.sweep_name in self._sweeps:
sweep = self._sweeps[run.sweep_name]
else:
sweep = public.Sweep.get(
self.client,
self.entity,
self.project,
run.sweep_name,
withRuns=False,
)
self._sweeps[run.sweep_name] = sweep
if sweep is None:
continue
run.sweep = sweep
return objs
@normalize_exceptions
def histories(
self,
samples: int = 500,
keys: Optional[List[str]] = None,
x_axis: str = "_step",
format: Literal["default", "pandas", "polars"] = "default",
stream: Literal["default", "system"] = "default",
):
"""Return sampled history metrics for all runs that fit the filters conditions.
Args:
samples : (int, optional) The number of samples to return per run
keys : (list[str], optional) Only return metrics for specific keys
x_axis : (str, optional) Use this metric as the xAxis defaults to _step
format : (Literal, optional) Format to return data in, options are "default", "pandas", "polars"
stream : (Literal, optional) "default" for metrics, "system" for machine metrics
Returns:
pandas.DataFrame: If format="pandas", returns a `pandas.DataFrame` of history metrics.
polars.DataFrame: If format="polars", returns a `polars.DataFrame` of history metrics.
list of dicts: If format="default", returns a list of dicts containing history metrics with a run_id key.
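
        Example (a sketch; assumes pandas is installed and that the runs logged
        a hypothetical "loss" metric):

        ```python
        api = wandb.Api()
        runs = api.runs("my-team/my-project")
        df = runs.histories(samples=100, keys=["loss"], format="pandas")
        # Each row carries a run_id column identifying its source run.
        print(df.groupby("run_id")["loss"].min())
        ```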
"""
if format not in ("default", "pandas", "polars"):
raise ValueError(
f"Invalid format: {format}. Must be one of 'default', 'pandas', 'polars'"
)
histories = []
if format == "default":
for run in self:
history_data = run.history(
samples=samples,
keys=keys,
x_axis=x_axis,
pandas=False,
stream=stream,
)
if not history_data:
continue
for entry in history_data:
entry["run_id"] = run.id
histories.extend(history_data)
return histories
if format == "pandas":
pd = util.get_module(
"pandas", required="Exporting pandas DataFrame requires pandas"
)
for run in self:
history_data = run.history(
samples=samples,
keys=keys,
x_axis=x_axis,
pandas=False,
stream=stream,
)
if not history_data:
continue
df = pd.DataFrame.from_records(history_data)
df["run_id"] = run.id
histories.append(df)
if not histories:
return pd.DataFrame()
combined_df = pd.concat(histories)
combined_df.reset_index(drop=True, inplace=True)
# sort columns for consistency
            combined_df = combined_df[sorted(combined_df.columns)]
return combined_df
if format == "polars":
pl = util.get_module(
"polars", required="Exporting polars DataFrame requires polars"
)
for run in self:
history_data = run.history(
samples=samples,
keys=keys,
x_axis=x_axis,
pandas=False,
stream=stream,
)
if not history_data:
continue
df = pl.from_records(history_data)
df = df.with_columns(pl.lit(run.id).alias("run_id"))
histories.append(df)
if not histories:
return pl.DataFrame()
combined_df = pl.concat(histories, how="vertical")
# sort columns for consistency
combined_df = combined_df.select(sorted(combined_df.columns))
return combined_df
def __repr__(self):
return f"<Runs {self.entity}/{self.project}>"
class Run(Attrs):
"""A single run associated with an entity and project.
Attributes:
tags ([str]): a list of tags associated with the run
url (str): the url of this run
id (str): unique identifier for the run (defaults to eight characters)
name (str): the name of the run
        state (str): one of: running, finished, crashed, failed, killed, preempting, preempted
config (dict): a dict of hyperparameters associated with the run
created_at (str): ISO timestamp when the run was started
system_metrics (dict): the latest system metrics recorded for the run
summary (dict): A mutable dict-like property that holds the current summary.
Calling update will persist any changes.
project (str): the project associated with the run
entity (str): the name of the entity associated with the run
project_internal_id (int): the internal id of the project
user (str): the name of the user who created the run
path (str): Unique identifier [entity]/[project]/[run_id]
notes (str): Notes about the run
        read_only (boolean): Whether the run is read-only and cannot be edited
        history_keys (dict): Keys of the history metrics that have been logged
            with `wandb.log({key: value})`
        metadata (dict): Metadata about the run from wandb-metadata.json
"""
def __init__(
self,
client: "RetryingClient",
entity: str,
project: str,
run_id: str,
attrs: Optional[Mapping] = None,
include_sweeps: bool = True,
):
"""Initialize a Run object.
Run is always initialized by calling api.runs() where api is an instance of
wandb.Api.
"""
_attrs = attrs or {}
super().__init__(dict(_attrs))
self.client = client
self._entity = entity
self.project = project
self._files = {}
self._base_dir = env.get_dir(tempfile.gettempdir())
self.id = run_id
self.sweep = None
self._include_sweeps = include_sweeps
self.dir = os.path.join(self._base_dir, *self.path)
try:
os.makedirs(self.dir)
except OSError:
pass
self._summary = None
self._metadata: Optional[Dict[str, Any]] = None
self._state = _attrs.get("state", "not found")
self.server_provides_internal_id_field: Optional[bool] = None
self.load(force=not _attrs)
@property
def state(self):
return self._state
@property
def entity(self):
return self._entity
@property
def username(self):
wandb.termwarn("Run.username is deprecated. Please use Run.entity instead.")
return self._entity
@property
def storage_id(self):
# For compatibility with wandb.Run, which has storage IDs
# in self.storage_id and names in self.id.
return self._attrs.get("id")
@property
def id(self):
return self._attrs.get("name")
    @id.setter
    def id(self, new_id):
        self._attrs["name"] = new_id
@property
def name(self):
return self._attrs.get("displayName")
    @name.setter
    def name(self, new_name):
        self._attrs["displayName"] = new_name
@classmethod
def create(
cls,
api,
run_id=None,
project=None,
entity=None,
state: Literal["running", "pending"] = "running",
):
"""Create a run for the given project."""
run_id = run_id or runid.generate_id()
project = project or api.settings.get("project") or "uncategorized"
mutation = gql(
"""
mutation UpsertBucket($project: String, $entity: String, $name: String!, $state: String) {
upsertBucket(input: {modelName: $project, entityName: $entity, name: $name, state: $state}) {
bucket {
project {
name
entity { name }
}
id
name
}
inserted
}
}
"""
)
variables = {
"entity": entity,
"project": project,
"name": run_id,
"state": state,
}
res = api.client.execute(mutation, variable_values=variables)
res = res["upsertBucket"]["bucket"]
return Run(
api.client,
res["project"]["entity"]["name"],
res["project"]["name"],
res["name"],
{
"id": res["id"],
"config": "{}",
"systemMetrics": "{}",
"summaryMetrics": "{}",
"tags": [],
"description": None,
"notes": None,
"state": state,
},
)
def load(self, force=False):
query = gql(
"""
query Run($project: String!, $entity: String!, $name: String!) {{
project(name: $project, entityName: $entity) {{
run(name: $name) {{
{}
...RunFragment
}}
}}
}}
{}
""".format(
"projectId"
if _server_provides_internal_id_for_project(self.client)
else "",
RUN_FRAGMENT,
)
)
if force or not self._attrs:
response = self._exec(query)
if (
response is None
or response.get("project") is None
or response["project"].get("run") is None
):
raise ValueError("Could not find run {}".format(self))
self._attrs = response["project"]["run"]
self._state = self._attrs["state"]
if self._include_sweeps and self.sweep_name and not self.sweep:
# There may be a lot of runs. Don't bother pulling them all
# just for the sake of this one.
self.sweep = public.Sweep.get(
self.client,
self.entity,
self.project,
self.sweep_name,
withRuns=False,
)
if "projectId" in self._attrs:
self._project_internal_id = int(self._attrs["projectId"])
else:
self._project_internal_id = None
try:
self._attrs["summaryMetrics"] = (
json.loads(self._attrs["summaryMetrics"])
if self._attrs.get("summaryMetrics")
else {}
)
except json.decoder.JSONDecodeError:
# ignore invalid utf-8 or control characters
self._attrs["summaryMetrics"] = json.loads(
self._attrs["summaryMetrics"],
strict=False,
)
self._attrs["systemMetrics"] = (
json.loads(self._attrs["systemMetrics"])
if self._attrs.get("systemMetrics")
else {}
)
if self._attrs.get("user"):
self.user = public.User(self.client, self._attrs["user"])
config_user, config_raw = {}, {}
for key, value in json.loads(self._attrs.get("config") or "{}").items():
config = config_raw if key in WANDB_INTERNAL_KEYS else config_user
if isinstance(value, dict) and "value" in value:
config[key] = value["value"]
else:
config[key] = value
config_raw.update(config_user)
self._attrs["config"] = config_user
self._attrs["rawconfig"] = config_raw
return self._attrs
@normalize_exceptions
def wait_until_finished(self):
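        """Block until this run reaches a terminal state.

        Polls the backend every 5 seconds and returns once the run state
        is "finished", "crashed", or "failed".
        """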
query = gql(
"""
query RunState($project: String!, $entity: String!, $name: String!) {
project(name: $project, entityName: $entity) {
run(name: $name) {
state
}
}
}
"""
)
while True:
res = self._exec(query)
state = res["project"]["run"]["state"]
if state in ["finished", "crashed", "failed"]:
self._attrs["state"] = state
self._state = state
return
time.sleep(5)
@normalize_exceptions
def update(self):
"""Persist changes to the run object to the wandb backend."""
mutation = gql(
"""
mutation UpsertBucket($id: String!, $description: String, $display_name: String, $notes: String, $tags: [String!], $config: JSONString!, $groupName: String, $jobType: String) {{
upsertBucket(input: {{id: $id, description: $description, displayName: $display_name, notes: $notes, tags: $tags, config: $config, groupName: $groupName, jobType: $jobType}}) {{
bucket {{
...RunFragment
}}
}}
}}
{}
""".format(RUN_FRAGMENT)
)
_ = self._exec(
mutation,
id=self.storage_id,
tags=self.tags,
description=self.description,
notes=self.notes,
display_name=self.display_name,
config=self.json_config,
groupName=self.group,
jobType=self.job_type,
)
self.summary.update()
@normalize_exceptions
def delete(self, delete_artifacts=False):
"""Delete the given run from the wandb backend."""
mutation = gql(
"""
mutation DeleteRun(
$id: ID!,
{}
) {{
deleteRun(input: {{
id: $id,
{}
}}) {{
clientMutationId
}}
}}
""".format(
"$deleteArtifacts: Boolean" if delete_artifacts else "",
"deleteArtifacts: $deleteArtifacts" if delete_artifacts else "",
)
)
self.client.execute(
mutation,
variable_values={
"id": self.storage_id,
"deleteArtifacts": delete_artifacts,
},
)
def save(self):
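        """Persist changes to the run to the wandb backend (alias for `update`)."""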
self.update()
@property
def json_config(self):
config = {}
if "_wandb" in self.rawconfig:
config["_wandb"] = {"value": self.rawconfig["_wandb"], "desc": None}
for k, v in self.config.items():
config[k] = {"value": v, "desc": None}
return json.dumps(config)
def _exec(self, query, **kwargs):
"""Execute a query against the cloud backend."""
variables = {"entity": self.entity, "project": self.project, "name": self.id}
variables.update(kwargs)
return self.client.execute(query, variable_values=variables)
def _sampled_history(self, keys, x_axis="_step", samples=500):
spec = {"keys": [x_axis] + keys, "samples": samples}
query = gql(
"""
query RunSampledHistory($project: String!, $entity: String!, $name: String!, $specs: [JSONString!]!) {
project(name: $project, entityName: $entity) {
run(name: $name) { sampledHistory(specs: $specs) }
}
}
"""
)
response = self._exec(query, specs=[json.dumps(spec)])
# sampledHistory returns one list per spec, we only send one spec
return response["project"]["run"]["sampledHistory"][0]
def _full_history(self, samples=500, stream="default"):
node = "history" if stream == "default" else "events"
query = gql(
"""
query RunFullHistory($project: String!, $entity: String!, $name: String!, $samples: Int) {{
project(name: $project, entityName: $entity) {{
run(name: $name) {{ {}(samples: $samples) }}
}}
}}
""".format(node)
)
response = self._exec(query, samples=samples)
return [json.loads(line) for line in response["project"]["run"][node]]
@normalize_exceptions
def files(self, names=None, per_page=50):
"""Return a file path for each file named.
Args:
names (list): names of the requested files, if empty returns all files
per_page (int): number of results per page.
Returns:
A `Files` object, which is an iterator over `File` objects.
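
        Example (illustrative; the run path is hypothetical):

        ```python
        api = wandb.Api()
        run = api.run("my-team/my-project/abc123")
        for f in run.files():
            print(f.name, f.size)
        ```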
"""
return public.Files(self.client, self, names or [], per_page)
@normalize_exceptions
def file(self, name):
"""Return the path of a file with a given name in the artifact.
Args:
name (str): name of requested file.
Returns:
A `File` matching the name argument.
"""
return public.Files(self.client, self, [name])[0]
@normalize_exceptions
def upload_file(self, path, root="."):
"""Upload a file.
Args:
path (str): name of file to upload.
            root (str): the root path to save the file relative to, e.g.
                if you want to have the file saved in the run as "my_dir/file.txt"
                and you're currently in "my_dir" you would set root to "../".
        Returns:
            A `File` object for the uploaded file.
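
        Example (a sketch; the run path and file path are hypothetical):

        ```python
        api = wandb.Api()
        run = api.run("my-team/my-project/abc123")
        # Uploads ./outputs/model.txt and stores it as "outputs/model.txt".
        run.upload_file("outputs/model.txt")
        ```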
"""
api = InternalApi(
default_settings={"entity": self.entity, "project": self.project},
retry_timedelta=RETRY_TIMEDELTA,
)
api.set_current_run_id(self.id)
root = os.path.abspath(root)
name = os.path.relpath(path, root)
with open(os.path.join(root, name), "rb") as f:
api.push({LogicalPath(name): f})
return public.Files(self.client, self, [name])[0]
@normalize_exceptions
def history(
self, samples=500, keys=None, x_axis="_step", pandas=True, stream="default"
):
"""Return sampled history metrics for a run.
This is simpler and faster if you are ok with the history records being sampled.
Args:
samples : (int, optional) The number of samples to return
pandas : (bool, optional) Return a pandas dataframe
keys : (list, optional) Only return metrics for specific keys
x_axis : (str, optional) Use this metric as the xAxis defaults to _step
stream : (str, optional) "default" for metrics, "system" for machine metrics
Returns:
pandas.DataFrame: If pandas=True returns a `pandas.DataFrame` of history
metrics.
list of dicts: If pandas=False returns a list of dicts of history metrics.
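
        Example (illustrative; assumes the run logged a hypothetical "acc" metric):

        ```python
        api = wandb.Api()
        run = api.run("my-team/my-project/abc123")
        rows = run.history(keys=["acc"], pandas=False)
        print([row["acc"] for row in rows])
        ```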
"""
if keys is not None and not isinstance(keys, list):
wandb.termerror("keys must be specified in a list")
return []
if keys is not None and len(keys) > 0 and not isinstance(keys[0], str):
wandb.termerror("keys argument must be a list of strings")
return []
if keys and stream != "default":
wandb.termerror("stream must be default when specifying keys")
return []
elif keys:
lines = self._sampled_history(keys=keys, x_axis=x_axis, samples=samples)
else:
lines = self._full_history(samples=samples, stream=stream)
if pandas:
pd = util.get_module("pandas")
if pd:
lines = pd.DataFrame.from_records(lines)
else:
wandb.termwarn("Unable to load pandas, call history with pandas=False")
return lines
@normalize_exceptions
def scan_history(self, keys=None, page_size=1000, min_step=None, max_step=None):
"""Returns an iterable collection of all history records for a run.
Example:
Export all the loss values for an example run
```python
run = api.run("l2k2/examples-numpy-boston/i0wt6xua")
history = run.scan_history(keys=["Loss"])
losses = [row["Loss"] for row in history]
```
Args:
keys ([str], optional): only fetch these keys, and only fetch rows that have all of keys defined.
page_size (int, optional): size of pages to fetch from the api.
            min_step (int, optional): the first history step to fetch.
            max_step (int, optional): fetch steps up to, but not including, this step.
Returns:
An iterable collection over history records (dict).
"""
if keys is not None and not isinstance(keys, list):
wandb.termerror("keys must be specified in a list")
return []
if keys is not None and len(keys) > 0 and not isinstance(keys[0], str):
wandb.termerror("keys argument must be a list of strings")
return []
last_step = self.lastHistoryStep
# set defaults for min/max step
if min_step is None:
min_step = 0
if max_step is None:
max_step = last_step + 1
# if the max step is past the actual last step, clamp it down
if max_step > last_step:
max_step = last_step + 1
if keys is None:
return public.HistoryScan(
run=self,
client=self.client,
page_size=page_size,
min_step=min_step,
max_step=max_step,
)
else:
return public.SampledHistoryScan(
run=self,
client=self.client,
keys=keys,
page_size=page_size,
min_step=min_step,
max_step=max_step,
)
@normalize_exceptions
def logged_artifacts(self, per_page: int = 100) -> public.RunArtifacts:
"""Fetches all artifacts logged by this run.
Retrieves all output artifacts that were logged during the run. Returns a
paginated result that can be iterated over or collected into a single list.
Args:
per_page: Number of artifacts to fetch per API request.
Returns:
An iterable collection of all Artifact objects logged as outputs during this run.
Example:
>>> import wandb
>>> import tempfile
>>> with tempfile.NamedTemporaryFile(
... mode="w", delete=False, suffix=".txt"
... ) as tmp:
... tmp.write("This is a test artifact")
... tmp_path = tmp.name
>>> run = wandb.init(project="artifact-example")
>>> artifact = wandb.Artifact("test_artifact", type="dataset")
>>> artifact.add_file(tmp_path)
>>> run.log_artifact(artifact)
>>> run.finish()
>>> api = wandb.Api()
>>> finished_run = api.run(f"{run.entity}/{run.project}/{run.id}")
>>> for logged_artifact in finished_run.logged_artifacts():
... print(logged_artifact.name)
test_artifact
"""
return public.RunArtifacts(self.client, self, mode="logged", per_page=per_page)
@normalize_exceptions
def used_artifacts(self, per_page: int = 100) -> public.RunArtifacts:
"""Fetches artifacts explicitly used by this run.
Retrieves only the input artifacts that were explicitly declared as used
during the run, typically via `run.use_artifact()`. Returns a paginated
result that can be iterated over or collected into a single list.
Args:
per_page: Number of artifacts to fetch per API request.
Returns:
An iterable collection of Artifact objects explicitly used as inputs in this run.
Example:
>>> import wandb
>>> run = wandb.init(project="artifact-example")
>>> run.use_artifact("test_artifact:latest")
>>> run.finish()
>>> api = wandb.Api()
>>> finished_run = api.run(f"{run.entity}/{run.project}/{run.id}")
>>> for used_artifact in finished_run.used_artifacts():
... print(used_artifact.name)
test_artifact
"""
return public.RunArtifacts(self.client, self, mode="used", per_page=per_page)
@normalize_exceptions
def use_artifact(self, artifact, use_as=None):
"""Declare an artifact as an input to a run.
Args:
artifact (`Artifact`): An artifact returned from
`wandb.Api().artifact(name)`
use_as (string, optional): A string identifying
how the artifact is used in the script. Used
to easily differentiate artifacts used in a
run, when using the beta wandb launch
feature's artifact swapping functionality.
Returns:
            An `Artifact` object.
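
        Example (illustrative; the run path and artifact name are hypothetical):

        ```python
        api = wandb.Api()
        run = api.run("my-team/my-project/abc123")
        artifact = api.artifact("my-team/my-project/dataset:latest")
        run.use_artifact(artifact)
        ```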
"""
api = InternalApi(
default_settings={"entity": self.entity, "project": self.project},
retry_timedelta=RETRY_TIMEDELTA,
)
api.set_current_run_id(self.id)
if isinstance(artifact, wandb.Artifact) and not artifact.is_draft():
api.use_artifact(
artifact.id,
use_as=use_as or artifact.name,
artifact_entity_name=artifact.entity,
artifact_project_name=artifact.project,
)
return artifact
elif isinstance(artifact, wandb.Artifact) and artifact.is_draft():
raise ValueError(
"Only existing artifacts are accepted by this api. "
"Manually create one with `wandb artifact put`"
)
else:
raise ValueError("You must pass a wandb.Api().artifact() to use_artifact")
@normalize_exceptions
def log_artifact(
self,
artifact: "wandb.Artifact",
aliases: Optional[Collection[str]] = None,
tags: Optional[Collection[str]] = None,
):
"""Declare an artifact as output of a run.
Args:
artifact (`Artifact`): An artifact returned from
`wandb.Api().artifact(name)`.
aliases (list, optional): Aliases to apply to this artifact.
tags: (list, optional) Tags to apply to this artifact, if any.
Returns:
            An `Artifact` object.
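
        Example (illustrative; assumes an artifact already committed in the
        same entity and project as the run):

        ```python
        api = wandb.Api()
        run = api.run("my-team/my-project/abc123")
        artifact = api.artifact("my-team/my-project/dataset:v0")
        run.log_artifact(artifact, aliases=["best"])
        ```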
"""
api = InternalApi(
default_settings={"entity": self.entity, "project": self.project},
retry_timedelta=RETRY_TIMEDELTA,
)
api.set_current_run_id(self.id)
if not isinstance(artifact, wandb.Artifact):
raise TypeError("You must pass a wandb.Api().artifact() to use_artifact")
if artifact.is_draft():
raise ValueError(
"Only existing artifacts are accepted by this api. "
"Manually create one with `wandb artifact put`"
)
if (
self.entity != artifact.source_entity
or self.project != artifact.source_project
):
raise ValueError("A run can't log an artifact to a different project.")
artifact_collection_name = artifact.source_name.split(":")[0]
api.create_artifact(
artifact.type,
artifact_collection_name,
artifact.digest,
aliases=aliases,
tags=tags,
)
return artifact
@property
def summary(self):
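        """A mutable dict-like object of the run's summary metrics.

        Calling `update` on the returned object persists any changes.
        """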
if self._summary is None:
from wandb.old.summary import HTTPSummary
# TODO: fix the outdir issue
self._summary = HTTPSummary(self, self.client, summary=self.summary_metrics)
return self._summary
@property
def path(self):
return [
urllib.parse.quote_plus(str(self.entity)),
urllib.parse.quote_plus(str(self.project)),
urllib.parse.quote_plus(str(self.id)),
]
@property
def url(self):
path = self.path
path.insert(2, "runs")
return self.client.app_url + "/".join(path)
@property
def metadata(self):
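        """Metadata about the run from wandb-metadata.json.

        Returns None if the file doesn't exist or can't be downloaded or parsed.
        """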
if self._metadata is None:
try:
f = self.file("wandb-metadata.json")
session = self.client._client.transport.session
response = session.get(f.url, timeout=5)
response.raise_for_status()
contents = response.content
self._metadata = json_util.loads(contents)
except: # noqa: E722
# file doesn't exist, or can't be downloaded, or can't be parsed
pass
return self._metadata
@property
def lastHistoryStep(self): # noqa: N802
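        """The last history step logged to this run, or -1 if unavailable."""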
query = gql(
"""
query RunHistoryKeys($project: String!, $entity: String!, $name: String!) {
project(name: $project, entityName: $entity) {
run(name: $name) { historyKeys }
}
}
"""
)
response = self._exec(query)
if (
response is None
or response.get("project") is None
or response["project"].get("run") is None
or response["project"]["run"].get("historyKeys") is None
):
return -1
history_keys = response["project"]["run"]["historyKeys"]
return history_keys["lastStep"] if "lastStep" in history_keys else -1
def to_html(self, height=420, hidden=False):
"""Generate HTML containing an iframe displaying this run."""
url = self.url + "?jupyter=true"
style = f"border:none;width:100%;height:{height}px;"
prefix = ""
if hidden:
style += "display:none;"
prefix = ipython.toggle_button()
return prefix + f"<iframe src={url!r} style={style!r}></iframe>"
def _repr_html_(self) -> str:
return self.to_html()
def __repr__(self):
return "<Run {} ({})>".format("/".join(self.path), self.state)