|
"""Use the Public API to export or update data that you have saved to W&B. |
|
|
|
Before using this API, you'll want to log data from your script — check the |
|
[Quickstart](https://docs.wandb.ai/quickstart) for more details. |
|
|
|
You might use the Public API to |
|
- update metadata or metrics for an experiment after it has been completed, |
|
- pull down your results as a dataframe for post-hoc analysis in a Jupyter notebook, or |
|
- check your saved model artifacts for those tagged as `ready-to-deploy`. |
|
|
|
For more on using the Public API, check out [our guide](https://docs.wandb.com/guides/track/public-api-guide). |
|
""" |
|
|
|
import json |
|
import logging |
|
import os |
|
import urllib |
|
from http import HTTPStatus |
|
from typing import ( |
|
TYPE_CHECKING, |
|
Any, |
|
Dict, |
|
Iterator, |
|
List, |
|
Literal, |
|
Optional, |
|
Set, |
|
Union, |
|
) |
|
|
|
import requests |
|
from pydantic import ValidationError |
|
from typing_extensions import Unpack |
|
from wandb_gql import Client, gql |
|
from wandb_gql.client import RetryError |
|
|
|
import wandb |
|
from wandb import env, util |
|
from wandb._iterutils import one |
|
from wandb.apis import public |
|
from wandb.apis.normalize import normalize_exceptions |
|
from wandb.apis.public.const import RETRY_TIMEDELTA |
|
from wandb.apis.public.registries.registries_search import Registries |
|
from wandb.apis.public.registries.registry import Registry |
|
from wandb.apis.public.registries.utils import _fetch_org_entity_from_organization |
|
from wandb.apis.public.utils import ( |
|
PathType, |
|
fetch_org_from_settings_or_entity, |
|
gql_compat, |
|
parse_org_from_registry_path, |
|
) |
|
from wandb.proto.wandb_deprecated import Deprecated |
|
from wandb.proto.wandb_internal_pb2 import ServerFeature |
|
from wandb.sdk.artifacts._validators import is_artifact_registry_project |
|
from wandb.sdk.internal.internal_api import Api as InternalApi |
|
from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings |
|
from wandb.sdk.launch.utils import LAUNCH_DEFAULT_PROJECT |
|
from wandb.sdk.lib import retry, runid |
|
from wandb.sdk.lib.deprecate import deprecate |
|
from wandb.sdk.lib.gql_request import GraphQLSession |
|
|
|
if TYPE_CHECKING: |
|
from wandb.automations import ( |
|
ActionType, |
|
Automation, |
|
EventType, |
|
Integration, |
|
NewAutomation, |
|
SlackIntegration, |
|
WebhookIntegration, |
|
) |
|
from wandb.automations._utils import WriteAutomationsKwargs |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class RetryingClient:
    """Wrap a GraphQL `Client` so that `execute` calls are retried.

    Server metadata is fetched on first access to `server_info` and kept on
    the instance once a non-None value has been obtained.
    """

    INFO_QUERY = gql(
        """
        query ServerInfo{
            serverInfo {
                cliVersionInfo
                latestLocalVersionInfo {
                    outOfDate
                    latestVersionString
                    versionOnThisInstanceString
                }
            }
        }
        """
    )

    def __init__(self, client: Client):
        """Store the wrapped client; server info is looked up lazily."""
        self._client = client
        self._server_info = None

    @property
    def app_url(self):
        """Return the app URL derived from the transport's GraphQL URL."""
        api_root = self._client.transport.url.replace("/graphql", "")
        return util.app_url(api_root) + "/"

    @retry.retriable(
        retry_timedelta=RETRY_TIMEDELTA,
        check_retry_fn=util.no_retry_auth,
        retryable_exceptions=(RetryError, requests.RequestException),
    )
    def execute(self, *args, **kwargs):
        """Forward a request to the wrapped client.

        On a read timeout the exception is re-raised; when the caller did not
        pass an explicit `timeout` keyword, a warning explaining how to raise
        the default timeout is printed first.
        """
        try:
            return self._client.execute(*args, **kwargs)
        except requests.exceptions.ReadTimeout:
            caller_chose_timeout = "timeout" in kwargs
            if not caller_chose_timeout:
                timeout = self._client.transport.default_timeout
                wandb.termwarn(
                    f"A graphql request initiated by the public wandb API timed out (timeout={timeout} sec). "
                    f"Create a new API with an integer timeout larger than {timeout}, e.g., `api = wandb.Api(timeout={timeout + 10})` "
                    f"to increase the graphql timeout."
                )
            raise

    @property
    def server_info(self):
        """Return the `serverInfo` payload, querying the server if needed."""
        info = self._server_info
        if info is None:
            info = self.execute(self.INFO_QUERY).get("serverInfo")
            self._server_info = info
        return info

    def version_supported(self, min_version: str) -> bool:
        """Return whether `min_version` is at most the server's `max_cli_version`.

        Args:
            min_version: Version string to compare against the
                `max_cli_version` entry of the server's `cliVersionInfo`.
        """
        from packaging.version import parse

        requested = parse(min_version)
        ceiling = parse(self.server_info["cliVersionInfo"]["max_cli_version"])
        return requested <= ceiling
|
|
|
|
|
class Api: |
|
"""Used for querying the wandb server. |
|
|
|
Examples: |
|
Most common way to initialize |
|
>>> wandb.Api() |
|
|
|
Args: |
|
overrides: (dict) You can set `base_url` if you are using a wandb server |
|
other than https://api.wandb.ai. |
|
You can also set defaults for `entity`, `project`, and `run`. |
|
""" |
|
|
|
_HTTP_TIMEOUT = env.get_http_timeout(19) |
|
DEFAULT_ENTITY_QUERY = gql( |
|
""" |
|
query Viewer{ |
|
viewer { |
|
id |
|
entity |
|
} |
|
} |
|
""" |
|
) |
|
|
|
VIEWER_QUERY = gql( |
|
""" |
|
query Viewer{ |
|
viewer { |
|
id |
|
flags |
|
entity |
|
username |
|
email |
|
admin |
|
apiKeys { |
|
edges { |
|
node { |
|
id |
|
name |
|
description |
|
} |
|
} |
|
} |
|
teams { |
|
edges { |
|
node { |
|
name |
|
} |
|
} |
|
} |
|
} |
|
} |
|
""" |
|
) |
|
USERS_QUERY = gql( |
|
""" |
|
query SearchUsers($query: String) { |
|
users(query: $query) { |
|
edges { |
|
node { |
|
id |
|
flags |
|
entity |
|
admin |
|
email |
|
deletedAt |
|
username |
|
apiKeys { |
|
edges { |
|
node { |
|
id |
|
name |
|
description |
|
} |
|
} |
|
} |
|
teams { |
|
edges { |
|
node { |
|
name |
|
} |
|
} |
|
} |
|
} |
|
} |
|
} |
|
} |
|
""" |
|
) |
|
|
|
CREATE_PROJECT = gql( |
|
""" |
|
mutation upsertModel( |
|
$description: String |
|
$entityName: String |
|
$id: String |
|
$name: String |
|
$framework: String |
|
$access: String |
|
$views: JSONString |
|
) { |
|
upsertModel( |
|
input: { |
|
description: $description |
|
entityName: $entityName |
|
id: $id |
|
name: $name |
|
framework: $framework |
|
access: $access |
|
views: $views |
|
} |
|
) { |
|
project { |
|
id |
|
name |
|
entityName |
|
description |
|
access |
|
views |
|
} |
|
model { |
|
id |
|
name |
|
entityName |
|
description |
|
access |
|
views |
|
} |
|
inserted |
|
} |
|
} |
|
""" |
|
) |
|
|
|
def __init__(
    self,
    overrides: Optional[Dict[str, Any]] = None,
    timeout: Optional[int] = None,
    api_key: Optional[str] = None,
) -> None:
    """Prepare settings, caches and the retrying GraphQL client.

    Args:
        overrides: Settings applied on top of `InternalApi().settings()`.
            A legacy `username` key is used as `entity` (with a warning)
            when `entity` itself is not overridden.
        timeout: HTTP timeout passed to the GraphQL session. Falls back to
            the class-level `_HTTP_TIMEOUT` when None.
        api_key: Explicit API key. When no key can be resolved and no
            thread-local cookies are set, `wandb.login` is invoked.
    """
    user_overrides = overrides or {}

    self.settings = InternalApi().settings()
    self.settings.update(user_overrides)
    self.settings["base_url"] = self.settings["base_url"].rstrip("/")
    if "organization" in user_overrides:
        self.settings["organization"] = user_overrides["organization"]
    legacy_username_only = (
        "username" in user_overrides and "entity" not in user_overrides
    )
    if legacy_username_only:
        wandb.termwarn(
            'Passing "username" to Api is deprecated. please use "entity" instead.'
        )
        self.settings["entity"] = user_overrides["username"]

    self._api_key = api_key
    if self.api_key is None and _thread_local_api_settings.cookies is None:
        wandb.login(host=user_overrides.get("base_url"))

    # Per-instance caches and lazily resolved values.
    self._viewer = None
    self._projects = {}
    self._runs = {}
    self._sweeps = {}
    self._reports = {}
    self._default_entity = None
    self._timeout = self._HTTP_TIMEOUT if timeout is None else timeout

    # Basic auth is only used when no thread-local cookies are available.
    auth = None if _thread_local_api_settings.cookies else ("api", self.api_key)

    proxies = self.settings.get("_proxies")
    if not proxies:
        proxies = json.loads(os.environ.get("WANDB__PROXIES", "{}"))

    request_headers = {
        "User-Agent": self.user_agent,
        "Use-Admin-Privileges": "true",
        **(_thread_local_api_settings.headers or {}),
    }
    session = GraphQLSession(
        headers=request_headers,
        use_json=True,
        timeout=self._timeout,
        auth=auth,
        url="{}/graphql".format(self.settings["base_url"]),
        cookies=_thread_local_api_settings.cookies,
        proxies=proxies,
    )
    self._base_client = Client(transport=session)
    self._client = RetryingClient(self._base_client)
|
|
|
def create_project(self, name: str, entity: str) -> None: |
|
"""Create a new project. |
|
|
|
Args: |
|
name: (str) The name of the new project. |
|
entity: (str) The entity of the new project. |
|
""" |
|
self.client.execute(self.CREATE_PROJECT, {"entityName": entity, "name": name}) |
|
|
|
def create_run(
    self,
    *,
    run_id: Optional[str] = None,
    project: Optional[str] = None,
    entity: Optional[str] = None,
) -> "public.Run":
    """Create a new run via `public.Run.create`.

    Args:
        run_id: (str, optional) ID to assign to the run. IDs are generated
            automatically by default, so supplying one is rarely needed and
            is done at your own risk.
        project: (str, optional) Project of the new run.
        entity: (str, optional) Entity of the new run. When None, the
            default entity of this `Api` is used.

    Returns:
        The newly created `Run`.
    """
    owner = self.default_entity if entity is None else entity
    return public.Run.create(self, run_id=run_id, project=project, entity=owner)
|
|
|
def create_run_queue(
    self,
    name: str,
    type: "public.RunQueueResourceType",
    entity: Optional[str] = None,
    prioritization_mode: Optional["public.RunQueuePrioritizationMode"] = None,
    config: Optional[dict] = None,
    template_variables: Optional[dict] = None,
) -> "public.RunQueue":
    """Create a new run queue (launch).

    Args:
        name: (str) Name of the queue to create. Must be non-empty and at
            most 64 characters long.
        type: (str) Resource type for the queue. One of "local-container",
            "local-process", "kubernetes", "sagemaker", or "gcp-vertex".
        entity: (str) Optional entity to create the queue under. If None,
            the configured or default entity is used.
        prioritization_mode: (str) Optional prioritization version. Either
            "V0" (case-insensitive) or None.
        config: (dict) Optional default resource configuration for the
            queue. Use handlebars (eg. `{{var}}`) to specify template
            variables. Defaults to an empty dict.
        template_variables: (dict) Template variable schemas to be used with
            the config. Expected format of:
            `{
                "var-name": {
                    "schema": {
                        "type": ("string", "number", or "integer"),
                        "default": (optional value),
                        "minimum": (optional minimum),
                        "maximum": (optional maximum),
                        "enum": [..."(options)"]
                    }
                }
            }`

    Returns:
        The newly created `RunQueue`

    Raises:
        ValueError if any of the parameters are invalid
        wandb.Error on wandb API errors
    """
    # --- validate arguments (order determines which error surfaces first) ---
    if entity is None:
        entity = self.settings["entity"] or self.default_entity
    if entity is None:
        raise ValueError("entity must be passed as a parameter, or set in settings")

    name_length = len(name)
    if name_length == 0:
        raise ValueError("name must be non-empty")
    if name_length > 64:
        raise ValueError("name must be less than 64 characters")

    supported_types = (
        "local-container",
        "local-process",
        "kubernetes",
        "sagemaker",
        "gcp-vertex",
    )
    if type not in supported_types:
        raise ValueError(
            "resource_type must be one of 'local-container', 'local-process', 'kubernetes', 'sagemaker', or 'gcp-vertex'"
        )

    if prioritization_mode:
        prioritization_mode = prioritization_mode.upper()
        if prioritization_mode != "V0":
            raise ValueError("prioritization_mode must be 'V0' if present")

    config = {} if config is None else config

    # --- step 1: make sure the default launch project exists for the entity ---
    self.create_project(LAUNCH_DEFAULT_PROJECT, entity)

    internal_api = InternalApi(
        default_settings={
            "entity": entity,
            "project": self.project(LAUNCH_DEFAULT_PROJECT),
        },
        retry_timedelta=RETRY_TIMEDELTA,
    )

    # --- step 2: register the default resource config and keep its id ---
    serialized_config = json.dumps({"resource_args": {type: config}})
    config_result = internal_api.create_default_resource_config(
        entity, type, serialized_config, template_variables
    )
    if not config_result["success"]:
        raise wandb.Error("failed to create default resource config")
    config_id = config_result["defaultResourceConfigID"]

    # --- step 3: create the queue itself ---
    queue_result = internal_api.create_run_queue(
        entity,
        LAUNCH_DEFAULT_PROJECT,
        name,
        "PROJECT",
        prioritization_mode,
        config_id,
    )
    if not queue_result["success"]:
        raise wandb.Error("failed to create run queue")

    return public.RunQueue(
        client=self.client,
        name=name,
        entity=entity,
        prioritization_mode=prioritization_mode,
        _access="PROJECT",
        _default_resource_config_id=config_id,
        _default_resource_config=config,
    )
|
|
|
def upsert_run_queue(
    self,
    name: str,
    resource_config: dict,
    resource_type: "public.RunQueueResourceType",
    entity: Optional[str] = None,
    template_variables: Optional[dict] = None,
    external_links: Optional[dict] = None,
    prioritization_mode: Optional["public.RunQueuePrioritizationMode"] = None,
):
    """Upsert a run queue (launch).

    Args:
        name: (str) Name of the queue to upsert. Must be non-empty and at
            most 64 characters long.
        resource_config: (dict) Default resource configuration for the
            queue. Use handlebars (eg. `{{var}}`) to specify template
            variables.
        resource_type: (str) Resource type for the queue. One of
            "local-container", "local-process", "kubernetes", "sagemaker",
            or "gcp-vertex".
        entity: (str) Optional entity to create the queue under. If None,
            the configured or default entity is used.
        template_variables: (dict) Template variable schemas to be used with
            the config. Expected format of:
            `{
                "var-name": {
                    "schema": {
                        "type": ("string", "number", or "integer"),
                        "default": (optional value),
                        "minimum": (optional minimum),
                        "maximum": (optional maximum),
                        "enum": [..."(options)"]
                    }
                }
            }`
        external_links: (dict) Optional external links for the queue.
            Expected format of:
            `{
                "name": "url"
            }`
        prioritization_mode: (str) Optional prioritization version. "V0" or
            "DISABLED" (case-insensitive); None is treated as "DISABLED".

    Returns:
        The upserted `RunQueue`.

    Raises:
        ValueError if any of the parameters are invalid
        wandb.Error on wandb API errors
    """
    # --- validate arguments (order determines which error surfaces first) ---
    if entity is None:
        entity = self.settings["entity"] or self.default_entity
    if entity is None:
        raise ValueError("entity must be passed as a parameter, or set in settings")

    name_length = len(name)
    if name_length == 0:
        raise ValueError("name must be non-empty")
    if name_length > 64:
        raise ValueError("name must be less than 64 characters")

    prioritization_mode = (prioritization_mode or "DISABLED").upper()
    if prioritization_mode not in ("V0", "DISABLED"):
        raise ValueError("prioritization_mode must be 'V0' or 'DISABLED' if present")

    supported_types = (
        "local-container",
        "local-process",
        "kubernetes",
        "sagemaker",
        "gcp-vertex",
    )
    if resource_type not in supported_types:
        raise ValueError(
            "resource_type must be one of 'local-container', 'local-process', 'kubernetes', 'sagemaker', or 'gcp-vertex'"
        )

    # The default launch project must exist before the queue is upserted.
    self.create_project(LAUNCH_DEFAULT_PROJECT, entity)
    internal_api = InternalApi(
        default_settings={
            "entity": entity,
            "project": self.project(LAUNCH_DEFAULT_PROJECT),
        },
        retry_timedelta=RETRY_TIMEDELTA,
    )

    # Convert the {label: url} mapping into the list-of-links payload.
    links_payload = {
        "links": [
            {"label": label, "url": url}
            for label, url in (external_links or {}).items()
        ]
    }
    result = internal_api.upsert_run_queue(
        name,
        entity,
        resource_type,
        {"resource_args": {resource_type: resource_config}},
        template_variables=template_variables,
        external_links=links_payload,
        prioritization_mode=prioritization_mode,
    )
    if not result["success"]:
        raise wandb.Error("failed to create run queue")

    # Schema validation problems are reported as warnings, not errors.
    for error in result.get("configSchemaValidationErrors") or []:
        wandb.termwarn(f"resource config validation: {error}")

    return public.RunQueue(
        client=self.client,
        name=name,
        entity=entity,
    )
|
|
|
def create_user(self, email, admin=False):
    """Create a new user via `public.User.create`.

    Args:
        email: (str) Email address of the user.
        admin: (bool) Whether the user should be a global instance admin.

    Returns:
        A `User` object
    """
    created = public.User.create(self, email, admin)
    return created
|
|
|
def sync_tensorboard(self, root_dir, run_id=None, project=None, entity=None):
    """Sync a local directory containing tfevent files to wandb.

    Missing values are filled in as follows: `run_id` is generated,
    `project` falls back to the `project` setting and then to
    "uncategorized", and `entity` falls back to the default entity.

    Returns:
        The result of `self.run` for the path `entity/project/run_id`.
    """
    from wandb.sync import SyncManager

    if not run_id:
        run_id = runid.generate_id()
    if not project:
        project = self.settings.get("project") or "uncategorized"
    if not entity:
        entity = self.default_entity

    manager = SyncManager(
        project=project,
        entity=entity,
        run_id=run_id,
        mark_synced=False,
        app_url=self.client.app_url,
        view=False,
        verbose=False,
        sync_tensorboard=True,
    )
    manager.add(root_dir)
    manager.start()
    # Keep polling until the manager reports completion.
    while not manager.is_done():
        manager.poll()

    run_path = "/".join((entity, project, run_id))
    return self.run(run_path)
|
|
|
@property |
|
def client(self) -> RetryingClient: |
|
return self._client |
|
|
|
@property
def user_agent(self) -> str:
    """Return the User-Agent string, which embeds the installed wandb version."""
    version = wandb.__version__
    return "W&B Public Client {}".format(version)
|
|
|
@property
def api_key(self) -> Optional[str]:
    """Resolve the API key to authenticate with.

    Lookup order:
      1. the thread-local API key, when set;
      2. the key already stored on this instance;
      3. the netrc entry for the configured `base_url`, overridden by the
         `WANDB_API_KEY` environment variable when that is non-empty.

    The result of step 3 is stored on the instance before being returned.
    """
    thread_local_key = _thread_local_api_settings.api_key
    if thread_local_key:
        return thread_local_key

    if self._api_key is not None:
        return self._api_key

    netrc_auth = requests.utils.get_netrc_auth(self.settings["base_url"])
    resolved = netrc_auth[-1] if netrc_auth else None

    # The environment variable wins over netrc.
    if os.getenv("WANDB_API_KEY"):
        resolved = os.environ["WANDB_API_KEY"]

    self._api_key = resolved
    return resolved
|
|
|
@property |
|
def default_entity(self) -> Optional[str]: |
|
if self._default_entity is None: |
|
res = self._client.execute(self.DEFAULT_ENTITY_QUERY) |
|
self._default_entity = (res.get("viewer") or {}).get("entity") |
|
return self._default_entity |
|
|
|
@property |
|
def viewer(self) -> "public.User": |
|
if self._viewer is None: |
|
self._viewer = public.User( |
|
self._client, self._client.execute(self.VIEWER_QUERY).get("viewer") |
|
) |
|
self._default_entity = self._viewer.entity |
|
return self._viewer |
|
|
|
def flush(self): |
|
"""Flush the local cache. |
|
|
|
The api object keeps a local cache of runs, so if the state of the run may |
|
change while executing your script you must clear the local cache with |
|
`api.flush()` to get the latest values associated with the run. |
|
""" |
|
self._runs = {} |
|
|
|
def from_path(self, path):
    """Return a run, sweep, project or report from a path.

    The number of `/`-separated segments decides what is returned: one or
    two segments give a project, three give a run, and four are dispatched
    on the third segment (`runs`, `sweeps` or `reports`).

    Examples:
        ```
        project = api.from_path("my_project")
        team_project = api.from_path("my_team/my_project")
        run = api.from_path("my_team/my_project/runs/id")
        sweep = api.from_path("my_team/my_project/sweeps/id")
        report = api.from_path("my_team/my_project/reports/My-Report-Vm11dsdf")
        ```

    Args:
        path: (str) The path to the project, run, sweep or report

    Returns:
        A `Project`, `Run`, `Sweep`, or `BetaReport` instance.

    Raises:
        wandb.Error if path is invalid or the object doesn't exist
    """
    parts = path.strip("/ ").split("/")
    depth = len(parts)

    if depth == 1:
        return self.project(path)
    if depth == 2:
        return self.project(parts[1], parts[0])
    if depth == 3:
        return self.run(path)

    if depth == 4:
        kind = parts[2]
        if kind.startswith("run"):
            return self.run(path)
        if kind.startswith("sweep"):
            return self.sweep(path)
        if kind.startswith("report"):
            slug = parts[-1]
            if "--" not in slug:
                # A lone "-" means a name without the "--<id>" suffix.
                if "-" in slug:
                    raise wandb.Error(
                        "Invalid report path, should be team/project/reports/Name--XXXX"
                    )
                # A bare id: give it an empty name part.
                slug = "--" + slug
            name, report_id = slug.split("--")
            report_attrs = {
                "display_name": urllib.parse.unquote(name.replace("-", " ")),
                "id": report_id,
                "spec": "{}",
            }
            return public.BetaReport(self.client, report_attrs, parts[0], parts[1])

    raise wandb.Error(
        "Invalid path, should be TEAM/PROJECT/TYPE/ID where TYPE is runs, sweeps, or reports"
    )
|
|
|
def _parse_project_path(self, path): |
|
"""Return project and entity for project specified by path.""" |
|
project = self.settings["project"] or "uncategorized" |
|
entity = self.settings["entity"] or self.default_entity |
|
if path is None: |
|
return entity, project |
|
parts = path.split("/", 1) |
|
if len(parts) == 1: |
|
return entity, path |
|
return parts |
|
|
|
def _parse_path(self, path): |
|
"""Parse url, filepath, or docker paths. |
|
|
|
Allows paths in the following formats: |
|
- url: entity/project/runs/id |
|
- path: entity/project/id |
|
- docker: entity/project:id |
|
|
|
Entity is optional and will fall back to the current logged-in user. |
|
""" |
|
project = self.settings["project"] or "uncategorized" |
|
entity = self.settings["entity"] or self.default_entity |
|
parts = ( |
|
path.replace("/runs/", "/").replace("/sweeps/", "/").strip("/ ").split("/") |
|
) |
|
if ":" in parts[-1]: |
|
id = parts[-1].split(":")[-1] |
|
parts[-1] = parts[-1].split(":")[0] |
|
elif parts[-1]: |
|
id = parts[-1] |
|
if len(parts) == 1 and project != "uncategorized": |
|
pass |
|
elif len(parts) > 1: |
|
project = parts[1] |
|
if entity and id == project: |
|
project = parts[0] |
|
else: |
|
entity = parts[0] |
|
if len(parts) == 3: |
|
entity = parts[0] |
|
else: |
|
project = parts[0] |
|
return entity, project, id |
|
|
|
def _parse_artifact_path(self, path): |
|
"""Return project, entity and artifact name for project specified by path.""" |
|
project = self.settings["project"] or "uncategorized" |
|
entity = self.settings["entity"] or self.default_entity |
|
if path is None: |
|
return entity, project |
|
|
|
path, colon, alias = path.partition(":") |
|
full_alias = colon + alias |
|
|
|
parts = path.split("/") |
|
if len(parts) > 3: |
|
raise ValueError("Invalid artifact path: {}".format(path)) |
|
elif len(parts) == 1: |
|
return entity, project, path + full_alias |
|
elif len(parts) == 2: |
|
return entity, parts[0], parts[1] + full_alias |
|
parts[-1] += full_alias |
|
return parts |
|
|
|
def projects( |
|
self, entity: Optional[str] = None, per_page: int = 200 |
|
) -> "public.Projects": |
|
"""Get projects for a given entity. |
|
|
|
Args: |
|
entity: (str) Name of the entity requested. If None, will fall back to the |
|
default entity passed to `Api`. If no default entity, will raise a `ValueError`. |
|
per_page: (int) Sets the page size for query pagination. Usually there is no reason to change this. |
|
|
|
Returns: |
|
A `Projects` object which is an iterable collection of `Project` objects. |
|
""" |
|
if entity is None: |
|
entity = self.settings["entity"] or self.default_entity |
|
if entity is None: |
|
raise ValueError( |
|
"entity must be passed as a parameter, or set in settings" |
|
) |
|
if entity not in self._projects: |
|
self._projects[entity] = public.Projects( |
|
self.client, entity, per_page=per_page |
|
) |
|
return self._projects[entity] |
|
|
|
def project(self, name: str, entity: Optional[str] = None) -> "public.Project":
    """Return the `Project` with the given name (and entity, if given).

    Args:
        name: (str) The project name.
        entity: (str) Name of the entity requested. If None, falls back to
            the `entity` setting and then to the default entity. For
            registry projects the value passed here is used as the
            organization when resolving the org entity.

    Returns:
        A `Project` object.
    """
    is_registry = is_artifact_registry_project(name)

    # Remember what the caller passed before `entity` is defaulted below:
    # for registry projects it is interpreted as the organization.
    org = entity if is_registry else ""

    if entity is None:
        entity = self.settings["entity"] or self.default_entity

    if is_registry:
        settings_entity = self.settings["entity"] or self.default_entity
        entity = InternalApi()._resolve_org_entity_name(
            entity=settings_entity, organization=org
        )

    return public.Project(self.client, entity, name, {})
|
|
|
def reports( |
|
self, path: str = "", name: Optional[str] = None, per_page: int = 50 |
|
) -> "public.Reports": |
|
"""Get reports for a given project path. |
|
|
|
WARNING: This api is in beta and will likely change in a future release |
|
|
|
Args: |
|
path: (str) path to project the report resides in, should be in the form: "entity/project" |
|
name: (str, optional) optional name of the report requested. |
|
per_page: (int) Sets the page size for query pagination. Usually there is no reason to change this. |
|
|
|
Returns: |
|
A `Reports` object which is an iterable collection of `BetaReport` objects. |
|
""" |
|
entity, project, _ = self._parse_path(path + "/fake_run") |
|
|
|
if name: |
|
name = urllib.parse.unquote(name) |
|
key = "/".join([entity, project, str(name)]) |
|
else: |
|
key = "/".join([entity, project]) |
|
|
|
if key not in self._reports: |
|
self._reports[key] = public.Reports( |
|
self.client, |
|
public.Project(self.client, entity, project, {}), |
|
name=name, |
|
per_page=per_page, |
|
) |
|
return self._reports[key] |
|
|
|
def create_team(self, team, admin_username=None):
    """Create a new team via `public.Team.create`.

    Args:
        team: (str) The name of the team
        admin_username: (str) optional username of the admin user of the team, defaults to the current user.

    Returns:
        A `Team` object
    """
    new_team = public.Team.create(self, team, admin_username)
    return new_team
|
|
|
def team(self, team: str) -> "public.Team":
    """Return the `Team` object for the given team name.

    Args:
        team: (str) The name of the team.

    Returns:
        A `Team` object.
    """
    found = public.Team(self.client, team)
    return found
|
|
|
def user(self, username_or_email: str) -> Optional["public.User"]: |
|
"""Return a user from a username or email address. |
|
|
|
Note: This function only works for Local Admins, if you are trying to get your own user object, please use `api.viewer`. |
|
|
|
Args: |
|
username_or_email: (str) The username or email address of the user |
|
|
|
Returns: |
|
A `User` object or None if a user couldn't be found |
|
""" |
|
res = self._client.execute(self.USERS_QUERY, {"query": username_or_email}) |
|
if len(res["users"]["edges"]) == 0: |
|
return None |
|
elif len(res["users"]["edges"]) > 1: |
|
wandb.termwarn( |
|
"Found multiple users, returning the first user matching {}".format( |
|
username_or_email |
|
) |
|
) |
|
return public.User(self._client, res["users"]["edges"][0]["node"]) |
|
|
|
def users(self, username_or_email: str) -> List["public.User"]: |
|
"""Return all users from a partial username or email address query. |
|
|
|
Note: This function only works for Local Admins, if you are trying to get your own user object, please use `api.viewer`. |
|
|
|
Args: |
|
username_or_email: (str) The prefix or suffix of the user you want to find |
|
|
|
Returns: |
|
An array of `User` objects |
|
""" |
|
res = self._client.execute(self.USERS_QUERY, {"query": username_or_email}) |
|
return [ |
|
public.User(self._client, edge["node"]) for edge in res["users"]["edges"] |
|
] |
|
|
|
def runs( |
|
self, |
|
path: Optional[str] = None, |
|
filters: Optional[Dict[str, Any]] = None, |
|
order: str = "+created_at", |
|
per_page: int = 50, |
|
include_sweeps: bool = True, |
|
): |
|
"""Return a set of runs from a project that match the filters provided. |
|
|
|
Fields you can filter by include: |
|
- `createdAt`: The timestamp when the run was created. (in ISO 8601 format, e.g. "2023-01-01T12:00:00Z") |
|
- `displayName`: The human-readable display name of the run. (e.g. "eager-fox-1") |
|
- `duration`: The total runtime of the run in seconds. |
|
- `group`: The group name used to organize related runs together. |
|
- `host`: The hostname where the run was executed. |
|
- `jobType`: The type of job or purpose of the run. |
|
- `name`: The unique identifier of the run. (e.g. "a1b2cdef") |
|
- `state`: The current state of the run. |
|
- `tags`: The tags associated with the run. |
|
- `username`: The username of the user who initiated the run |
|
|
|
Additionally, you can filter by items in the run config or summary metrics. |
|
Such as `config.experiment_name`, `summary_metrics.loss`, etc. |
|
|
|
For more complex filtering, you can use MongoDB query operators. |
|
For details, see: https://docs.mongodb.com/manual/reference/operator/query |
|
The following operations are supported: |
|
- `$and` |
|
- `$or` |
|
- `$nor` |
|
- `$eq` |
|
- `$ne` |
|
- `$gt` |
|
- `$gte` |
|
- `$lt` |
|
- `$lte` |
|
- `$in` |
|
- `$nin` |
|
- `$exists` |
|
- `$regex` |
|
|
|
|
|
Examples: |
|
Find runs in my_project where config.experiment_name has been set to "foo" |
|
``` |
|
api.runs( |
|
path="my_entity/my_project", |
|
filters={"config.experiment_name": "foo"}, |
|
) |
|
``` |
|
|
|
Find runs in my_project where config.experiment_name has been set to "foo" or "bar" |
|
``` |
|
api.runs( |
|
path="my_entity/my_project", |
|
filters={ |
|
"$or": [ |
|
{"config.experiment_name": "foo"}, |
|
{"config.experiment_name": "bar"}, |
|
] |
|
}, |
|
) |
|
``` |
|
|
|
Find runs in my_project where config.experiment_name matches a regex (anchors are not supported) |
|
``` |
|
api.runs( |
|
path="my_entity/my_project", |
|
filters={"config.experiment_name": {"$regex": "b.*"}}, |
|
) |
|
``` |
|
|
|
Find runs in my_project where the run name matches a regex (anchors are not supported) |
|
``` |
|
api.runs( |
|
path="my_entity/my_project", |
|
filters={"display_name": {"$regex": "^foo.*"}}, |
|
) |
|
``` |
|
|
|
Find runs in my_project where config.experiment contains a nested field "category" with value "testing" |
|
``` |
|
api.runs( |
|
path="my_entity/my_project", |
|
filters={"config.experiment.category": "testing"}, |
|
) |
|
``` |
|
|
|
Find runs in my_project with a loss value of 0.5 nested in a dictionary under model1 in the summary metrics |
|
``` |
|
api.runs( |
|
path="my_entity/my_project", |
|
filters={"summary_metrics.model1.loss": 0.5}, |
|
) |
|
``` |
|
|
|
Find runs in my_project sorted by ascending loss |
|
``` |
|
api.runs(path="my_entity/my_project", order="+summary_metrics.loss") |
|
``` |
|
|
|
Args: |
|
path: (str) path to project, should be in the form: "entity/project" |
|
filters: (dict) queries for specific runs using the MongoDB query language. |
|
You can filter by run properties such as config.key, summary_metrics.key, state, entity, createdAt, etc. |
|
For example: `{"config.experiment_name": "foo"}` would find runs with a config entry |
|
of experiment name set to "foo" |
|
order: (str) Order can be `created_at`, `heartbeat_at`, `config.*.value`, or `summary_metrics.*`. |
|
If you prepend order with a + order is ascending. |
|
If you prepend order with a - order is descending (default). |
|
The default order is run.created_at from oldest to newest. |
|
per_page: (int) Sets the page size for query pagination. |
|
include_sweeps: (bool) Whether to include the sweep runs in the results. |
|
|
|
Returns: |
|
A `Runs` object, which is an iterable collection of `Run` objects. |
|
""" |
|
entity, project = self._parse_project_path(path) |
|
filters = filters or {} |
|
key = (path or "") + str(filters) + str(order) |
|
if not self._runs.get(key): |
|
self._runs[key] = public.Runs( |
|
self.client, |
|
entity, |
|
project, |
|
filters=filters, |
|
order=order, |
|
per_page=per_page, |
|
include_sweeps=include_sweeps, |
|
) |
|
return self._runs[key] |
|
|
|
@normalize_exceptions
def run(self, path=""):
    """Return a single run by parsing path in the form entity/project/run_id.

    Runs are cached on the api object under the exact `path` string given.

    Args:
        path: (str) path to run in the form `entity/project/run_id`.
            If `api.entity` is set, this can be in the form `project/run_id`
            and if `api.project` is set this can just be the run_id.

    Returns:
        A `Run` object.
    """
    entity, project, run_id = self._parse_path(path)
    cached_run = self._runs.get(path)
    if not cached_run:
        cached_run = public.Run(self.client, entity, project, run_id)
        self._runs[path] = cached_run
    return cached_run
|
|
|
def queued_run(
    self,
    entity,
    project,
    queue_name,
    run_queue_item_id,
    project_queue=None,
    priority=None,
):
    """Return a single queued run built from the given identifiers.

    The positional identifiers correspond to the path form
    entity/project/queue_id/run_queue_item_id; `project_queue` and
    `priority` are forwarded to `public.QueuedRun` unchanged.
    """
    queued = public.QueuedRun(
        self.client,
        entity,
        project,
        queue_name,
        run_queue_item_id,
        project_queue=project_queue,
        priority=priority,
    )
    return queued
|
|
|
def run_queue(self, entity, name):
    """Return the `RunQueue` called `name` for the given entity.

    To create a new `RunQueue`, use `wandb.Api().create_run_queue(...)`.
    """
    # NOTE: `RunQueue` receives the queue name before the entity.
    return public.RunQueue(self.client, name, entity)
|
|
|
@normalize_exceptions |
|
def sweep(self, path=""):
    """Fetch one sweep identified by a path of the form `entity/project/sweep_id`.

    `Sweep` objects are memoized in `self._sweeps`, keyed by the raw `path`
    string, so looking up the same path again reuses the stored object.

    Args:
        path: (str, optional) Sweep path as `entity/project/sweep_id`. If
            `api.entity` is set, `project/sweep_id` is accepted, and if
            `api.project` is set as well the bare `sweep_id` is enough.

    Returns:
        A `Sweep` object.
    """
    entity, project, sweep_id = self._parse_path(path)
    cached = self._sweeps.get(path)
    if not cached:
        # Nothing usable stored for this path yet: build it and remember it.
        cached = public.Sweep(self.client, entity, project, sweep_id)
        self._sweeps[path] = cached
    return cached
|
|
|
@normalize_exceptions |
|
def artifact_types(self, project: Optional[str] = None) -> "public.ArtifactTypes":
    """Return the artifact types found in a project.

    Args:
        project: (str, optional) If given, a project name or path to filter on.

    Returns:
        An iterable `ArtifactTypes` object.
    """
    entity, project_name = self._parse_project_path(project)
    if not is_artifact_registry_project(project_name):
        return public.ArtifactTypes(self.client, entity, project_name)

    # Registry project: the entity parsed from the path is replaced by the
    # org entity resolved from the settings/default entity and the
    # organization taken from the original path.
    settings_entity = self.settings["entity"] or self.default_entity
    organization = parse_org_from_registry_path(project, PathType.PROJECT)
    org_entity = InternalApi()._resolve_org_entity_name(
        entity=settings_entity, organization=organization
    )
    return public.ArtifactTypes(self.client, org_entity, project_name)
|
|
|
@normalize_exceptions |
|
def artifact_type(
    self, type_name: str, project: Optional[str] = None
) -> "public.ArtifactType":
    """Return the `ArtifactType` named `type_name` in a project.

    Args:
        type_name: (str) The name of the artifact type to retrieve.
        project: (str, optional) If given, a project name or path to filter on.

    Returns:
        An `ArtifactType` object.
    """
    entity, project_name = self._parse_project_path(project)
    if not is_artifact_registry_project(project_name):
        return public.ArtifactType(self.client, entity, project_name, type_name)

    # Registry project: the entity parsed from the path is replaced by the
    # org entity resolved from the organization in the original path and the
    # settings/default entity.
    organization = parse_org_from_registry_path(project, PathType.PROJECT)
    settings_entity = self.settings["entity"] or self.default_entity
    org_entity = InternalApi()._resolve_org_entity_name(
        entity=settings_entity, organization=organization
    )
    return public.ArtifactType(self.client, org_entity, project_name, type_name)
|
|
|
@normalize_exceptions |
|
def artifact_collections(
    self, project_name: str, type_name: str, per_page: int = 50
) -> "public.ArtifactCollections":
    """Return the artifact collections of one type within a project.

    Args:
        project_name: (str) The name of the project to filter on.
        type_name: (str) The name of the artifact type to filter on.
        per_page: (int) Sets the page size for query pagination. Usually
            there is no reason to change this.

    Returns:
        An iterable `ArtifactCollections` object.
    """
    entity, project = self._parse_project_path(project_name)
    if not is_artifact_registry_project(project):
        return public.ArtifactCollections(
            self.client, entity, project, type_name, per_page=per_page
        )

    # Registry project: the entity parsed from the path is replaced by the
    # org entity resolved from the organization in the original path and the
    # settings/default entity.
    organization = parse_org_from_registry_path(project_name, PathType.PROJECT)
    settings_entity = self.settings["entity"] or self.default_entity
    org_entity = InternalApi()._resolve_org_entity_name(
        entity=settings_entity, organization=organization
    )
    return public.ArtifactCollections(
        self.client, org_entity, project, type_name, per_page=per_page
    )
|
|
|
@normalize_exceptions |
|
def artifact_collection(
    self, type_name: str, name: str
) -> "public.ArtifactCollection":
    """Return one artifact collection, located by type and an `entity/project/name` path.

    Args:
        type_name: (str) The type of artifact collection to fetch.
        name: (str) An artifact collection name. May be prefixed with
            entity/project.

    Returns:
        An `ArtifactCollection` object.

    Raises:
        ValueError: If no entity could be determined for the collection.
    """
    entity, project, collection_name = self._parse_artifact_path(name)

    registry_project = is_artifact_registry_project(project)
    if registry_project:
        # Registry project: swap the parsed entity for the resolved org entity.
        organization = parse_org_from_registry_path(name, PathType.ARTIFACT)
        settings_entity = self.settings["entity"] or self.default_entity
        entity = InternalApi()._resolve_org_entity_name(
            entity=settings_entity, organization=organization
        )

    if entity is not None:
        return public.ArtifactCollection(
            self.client, entity, project, collection_name, type_name
        )

    raise ValueError(
        "Could not determine entity. Please include the entity as part of the collection name path."
    )
|
|
|
@normalize_exceptions |
|
def artifact_versions(self, type_name, name, per_page=50):
    """Deprecated alias of `artifacts(type_name, name)`; use that method instead.

    Emits a deprecation notice, then forwards to `artifacts` with the same
    `type_name`, `name` and `per_page`.
    """
    message = (
        "Api.artifact_versions(type_name, name) is deprecated, "
        "use Api.artifacts(type_name, name) instead."
    )
    deprecate(field_name=Deprecated.api__artifact_versions, warning_message=message)
    return self.artifacts(type_name, name, per_page=per_page)
|
|
|
@normalize_exceptions |
|
def artifacts(
    self,
    type_name: str,
    name: str,
    per_page: int = 50,
    tags: Optional[List[str]] = None,
) -> "public.Artifacts":
    """Return the `Artifacts` collection matching the given parameters.

    Args:
        type_name: (str) The type of artifacts to fetch.
        name: (str) An artifact collection name. May be prefixed with
            entity/project.
        per_page: (int) Sets the page size for query pagination. Usually
            there is no reason to change this.
        tags: (list[str], optional) Only return artifacts with all of these
            tags.

    Returns:
        An iterable `Artifacts` object.
    """
    entity, project, collection_name = self._parse_artifact_path(name)

    if is_artifact_registry_project(project):
        # Registry project: swap the parsed entity for the resolved org entity.
        organization = parse_org_from_registry_path(name, PathType.ARTIFACT)
        settings_entity = self.settings["entity"] or self.default_entity
        entity = InternalApi()._resolve_org_entity_name(
            entity=settings_entity, organization=organization
        )

    paging = {"per_page": per_page, "tags": tags}
    return public.Artifacts(
        self.client, entity, project, collection_name, type_name, **paging
    )
|
|
|
@normalize_exceptions |
|
def _artifact( |
|
self, name: str, type: Optional[str] = None, enable_tracking: bool = False |
|
): |
|
if name is None: |
|
raise ValueError("You must specify name= to fetch an artifact.") |
|
entity, project, artifact_name = self._parse_artifact_path(name) |
|
|
|
|
|
if is_artifact_registry_project(project): |
|
organization = ( |
|
name.split("/")[0] |
|
if name.count("/") == 2 |
|
else self.settings["organization"] |
|
) |
|
|
|
settings_entity = self.settings["entity"] or self.default_entity |
|
|
|
|
|
entity = InternalApi()._resolve_org_entity_name( |
|
entity=settings_entity, organization=organization |
|
) |
|
|
|
if entity is None: |
|
raise ValueError( |
|
"Could not determine entity. Please include the entity as part of the artifact name path." |
|
) |
|
|
|
artifact = wandb.Artifact._from_name( |
|
entity=entity, |
|
project=project, |
|
name=artifact_name, |
|
client=self.client, |
|
enable_tracking=enable_tracking, |
|
) |
|
if type is not None and artifact.type != type: |
|
raise ValueError( |
|
f"type {type} specified but this artifact is of type {artifact.type}" |
|
) |
|
return artifact |
|
|
|
@normalize_exceptions |
|
def artifact(self, name: str, type: Optional[str] = None):
    """Fetch a single artifact from a `project/name` or `entity/project/name` path.

    Delegates to `_artifact` with `enable_tracking=True`.

    Args:
        name: (str) An artifact name. May be prefixed with project/ or
            entity/project/. If no entity is specified in the name, the Run
            or API setting's entity is used.
            Valid names can be in the following forms:
                name:version
                name:alias
        type: (str, optional) The type of artifact to fetch.

    Returns:
        An `Artifact` object.

    Raises:
        ValueError: If the artifact name is not specified.
        ValueError: If the artifact type is specified but does not match the
            type of the fetched artifact.

    Note:
        This method is intended for external use only. Do not call
        `api.artifact()` within the wandb repository code.
    """
    fetch_options = {"name": name, "type": type, "enable_tracking": True}
    return self._artifact(**fetch_options)
|
|
|
@normalize_exceptions |
|
def job(self, name: Optional[str], path: Optional[str] = None) -> "public.Job":
    """Build a `Job` from a fully qualified job name.

    Args:
        name: (str) The job name, in the form
            `<entity>/<project>/<job-name>:<alias-or-version>`.
        path: (str, optional) If given, the root path in which to download
            the job artifact.

    Returns:
        A `Job` object.

    Raises:
        ValueError: If `name` is None, does not contain exactly two `/`
            separators, or contains no `:`.
    """
    if name is None:
        raise ValueError("You must specify name= to fetch a job.")

    well_formed = name.count("/") == 2 and ":" in name
    if not well_formed:
        raise ValueError(
            "Invalid job specification. A job must be of the form: <entity>/<project>/<job-name>:<alias-or-version>"
        )
    return public.Job(self, name, path)
|
|
|
@normalize_exceptions |
|
def list_jobs(self, entity: str, project: str) -> List[Dict[str, Any]]:
    """Return a list of jobs, if any, for the given entity and project.

    Queries the project's artifact collections of type `"job"` and returns
    the `artifacts` payload of each collection.

    Args:
        entity: (str) The entity for the listed job(s).
        project: (str) The project for the listed job(s).

    Returns:
        A list of matching jobs. The list is empty when the project is not
        found (an error is also printed), when the project has no `"job"`
        artifact type, or when the request fails with an HTTP error.

    Raises:
        ValueError: If `entity` or `project` is None.
    """
    if entity is None:
        raise ValueError("Specify an entity when listing jobs")
    if project is None:
        raise ValueError("Specify a project when listing jobs")

    query = gql(
        """
        query ArtifactOfType(
            $entityName: String!,
            $projectName: String!,
            $artifactTypeName: String!,
        ) {
            project(name: $projectName, entityName: $entityName) {
                artifactType(name: $artifactTypeName) {
                    artifactCollections {
                        edges {
                            node {
                                artifacts {
                                    edges {
                                        node {
                                            id
                                            state
                                            aliases {
                                                alias
                                            }
                                            artifactSequence {
                                                name
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
        """
    )
    variables = {
        "projectName": project,
        "entityName": entity,
        "artifactTypeName": "job",
    }

    # Only the network call can raise HTTPError, so keep the try narrow.
    try:
        artifact_query = self.client.execute(query, variables)
    except requests.exceptions.HTTPError:
        # Best-effort: report "no jobs" with the annotated return type rather
        # than a bare False, so callers can always iterate or take len().
        return []

    if not artifact_query or not artifact_query["project"]:
        wandb.termerror(
            f"Project: '{project}' not found in entity: '{entity}' or access denied."
        )
        return []

    artifact_type = artifact_query["project"]["artifactType"]
    if artifact_type is None:
        return []

    collections = artifact_type["artifactCollections"]["edges"]
    return [edge["node"]["artifacts"] for edge in collections]
|
|
|
@normalize_exceptions |
|
def artifact_exists(self, name: str, type: Optional[str] = None) -> bool:
    """Report whether an artifact version exists within a specified project and entity.

    Args:
        name: (str) An artifact name. May be prefixed with entity/project.
            If entity or project is not specified, it will be inferred from
            the override params if populated. Otherwise, entity will be
            pulled from the user settings and project will default to
            "uncategorized".
            Valid names can be in the following forms:
                name:version
                name:alias
        type: (str, optional) The type of artifact

    Returns:
        True if the artifact version exists, False otherwise.
    """
    try:
        self._artifact(name, type)
    except wandb.errors.CommError:
        # The lookup failed with a communication error: treat as "missing".
        found = False
    else:
        found = True
    return found
|
|
|
@normalize_exceptions |
|
def artifact_collection_exists(self, name: str, type: str) -> bool:
    """Report whether an artifact collection exists within a specified project and entity.

    Args:
        name: (str) An artifact collection name. May be prefixed with
            entity/project. If entity or project is not specified, it will
            be inferred from the override params if populated. Otherwise,
            entity will be pulled from the user settings and project will
            default to "uncategorized".
        type: (str) The type of artifact collection

    Returns:
        True if the artifact collection exists, False otherwise.
    """
    try:
        # NOTE: `artifact_collection` takes the type first, then the name.
        self.artifact_collection(type, name)
    except wandb.errors.CommError:
        # The lookup failed with a communication error: treat as "missing".
        found = False
    else:
        found = True
    return found
|
|
|
def registries(
    self,
    organization: Optional[str] = None,
    filter: Optional[Dict[str, Any]] = None,
) -> Registries:
    """Return a `Registries` iterator.

    Use the iterator to search and filter registries, collections,
    or artifact versions across your organization's registry.

    Examples:
        Find all registries with the names that contain "model"
        ```python
        import wandb

        api = wandb.Api()  # specify an org if your entity belongs to multiple orgs
        api.registries(filter={"name": {"$regex": "model"}})
        ```

        Find all collections in the registries with the name "my_collection" and the tag "my_tag"
        ```python
        api.registries().collections(filter={"name": "my_collection", "tag": "my_tag"})
        ```

        Find all artifact versions in the registries with a collection name that contains "my_collection" and a version that has the alias "best"
        ```python
        api.registries().collections(
            filter={"name": {"$regex": "my_collection"}}
        ).versions(filter={"alias": "best"})
        ```

        Find all artifact versions in the registries that contain "model" and have the tag "prod" or alias "best"
        ```python
        api.registries(filter={"name": {"$regex": "model"}}).versions(
            filter={"$or": [{"tag": "prod"}, {"alias": "best"}]}
        )
        ```

    Args:
        organization: (str, optional) The organization of the registry to fetch.
            If not specified, use the organization specified in the user's settings.
        filter: (dict, optional) MongoDB-style filter to apply to each object in the registry iterator.
            Fields available to filter for registries are
                `name`, `description`, `created_at`, `updated_at`.
            Fields available to filter for collections are
                `name`, `tag`, `description`, `created_at`, `updated_at`
            Fields available to filter for versions are
                `tag`, `alias`, `created_at`, `updated_at`, `metadata`

    Returns:
        A registry iterator.

    Raises:
        RuntimeError: If the server does not report support for
            `ServerFeature.ARTIFACT_REGISTRY_SEARCH`.
    """
    search_enabled = InternalApi()._server_supports(
        ServerFeature.ARTIFACT_REGISTRY_SEARCH
    )
    if not search_enabled:
        raise RuntimeError(
            "Registry search API is not enabled on this wandb server version. "
            "Please upgrade your server version or contact support at support@wandb.com."
        )

    if not organization:
        # No organization supplied: derive one from settings or the default entity.
        organization = fetch_org_from_settings_or_entity(
            self.settings, self.default_entity
        )
    return Registries(self.client, organization, filter)
|
|
|
def registry(self, name: str, organization: Optional[str] = None) -> Registry:
    """Fetch and load the registry with the given name.

    Args:
        name: The name of the registry. This is without the `wandb-registry-`
            prefix.
        organization: The organization of the registry.
            If no organization is set in the settings, the organization will be
            fetched from the entity if the entity only belongs to one
            organization.

    Returns:
        A registry object.

    Raises:
        RuntimeError: If the server does not report support for
            `ServerFeature.ARTIFACT_REGISTRY_SEARCH`.

    Examples:
        Fetch and update a registry
        ```python
        import wandb

        api = wandb.Api()
        registry = api.registry(name="my-registry", organization="my-org")
        registry.description = "This is an updated description"
        registry.save()
        ```
    """
    search_enabled = InternalApi()._server_supports(
        ServerFeature.ARTIFACT_REGISTRY_SEARCH
    )
    if not search_enabled:
        raise RuntimeError(
            "api.registry() is not enabled on this wandb server version. "
            "Please upgrade your server version or contact support at support@wandb.com."
        )

    if not organization:
        # No organization supplied: derive one from settings or the default entity.
        organization = fetch_org_from_settings_or_entity(
            self.settings, self.default_entity
        )
    org_entity = _fetch_org_entity_from_organization(self.client, organization)

    loaded = Registry(self.client, organization, org_entity, name)
    loaded.load()
    return loaded
|
|
|
def create_registry(
    self,
    name: str,
    visibility: Literal["organization", "restricted"],
    organization: Optional[str] = None,
    description: Optional[str] = None,
    artifact_types: Optional[List[str]] = None,
) -> Registry:
    """Create a new registry.

    Args:
        name: The name of the registry. Name must be unique within the organization.
        visibility: The visibility of the registry.
            organization: Anyone in the organization can view this registry. You can
                edit their roles later from the settings in the UI.
            restricted: Only invited members via the UI can access this registry.
                Public sharing is disabled.
        organization: The organization of the registry.
            If no organization is set in the settings, the organization will be
            fetched from the entity if the entity only belongs to one organization.
        description: The description of the registry.
        artifact_types: The accepted artifact types of the registry. A type is no
            more than 128 characters and does not include the characters `/` or `:`.
            If not specified, all types are accepted.
            Allowed types added to the registry cannot be removed later.

    Returns:
        A registry object.

    Raises:
        RuntimeError: If the server does not report support for
            `ServerFeature.INCLUDE_ARTIFACT_TYPES_IN_REGISTRY_CREATION`.
        ValueError: If a registry with this name already exists in the
            organization.

    Examples:
        ```python
        import wandb

        api = wandb.Api()
        registry = api.create_registry(
            name="my-registry",
            visibility="restricted",
            organization="my-org",
            description="This is a test registry",
            artifact_types=["model"],
        )
        ```
    """
    creation_enabled = InternalApi()._server_supports(
        ServerFeature.INCLUDE_ARTIFACT_TYPES_IN_REGISTRY_CREATION
    )
    if not creation_enabled:
        raise RuntimeError(
            "create_registry api is not enabled on this wandb server version. "
            "Please upgrade your server version or contact support at support@wandb.com."
        )

    if not organization:
        # No organization supplied: derive one from settings or the default entity.
        organization = fetch_org_from_settings_or_entity(
            self.settings, self.default_entity
        )

    # A ValueError from the lookup is taken to mean the name is free.
    clash = None
    try:
        clash = self.registry(name=name, organization=organization)
    except ValueError:
        clash = None
    if clash:
        raise ValueError(
            f"Registry {name!r} already exists in organization {organization!r},"
            " please use a different name."
        )

    creation_args = (organization, name, visibility, description, artifact_types)
    return Registry.create(self.client, *creation_args)
|
|
|
def integrations(
    self,
    entity: Optional[str] = None,
    *,
    per_page: int = 50,
) -> Iterator["Integration"]:
    """Return an iterator over every integration registered for an entity.

    Args:
        entity: The entity (e.g. team name) for which to
            fetch integrations. If not provided, the user's default entity
            will be used.
        per_page: Number of integrations to fetch per page.
            Defaults to 50. Usually there is no reason to change this.

    Yields:
        Iterator[SlackIntegration | WebhookIntegration]: An iterator of any supported integrations.
    """
    from wandb.apis.public.integrations import Integrations

    entity_name = entity or self.default_entity
    return Integrations(
        client=self.client,
        variables={"entityName": entity_name},
        per_page=per_page,
    )
|
|
|
def webhook_integrations(
    self, entity: Optional[str] = None, *, per_page: int = 50
) -> Iterator["WebhookIntegration"]:
    """Return an iterator over the webhook integrations of an entity.

    Args:
        entity: The entity (e.g. team name) for which to
            fetch integrations. If not provided, the user's default entity
            will be used.
        per_page: Number of integrations to fetch per page.
            Defaults to 50. Usually there is no reason to change this.

    Yields:
        Iterator[WebhookIntegration]: An iterator of webhook integrations.

    Examples:
        Get all registered webhook integrations for the team "my-team":
        ```python
        import wandb

        api = wandb.Api()
        webhook_integrations = api.webhook_integrations(entity="my-team")
        ```

        Find only webhook integrations that post requests to "https://my-fake-url.com":
        ```python
        webhook_integrations = api.webhook_integrations(entity="my-team")
        my_webhooks = [
            ig
            for ig in webhook_integrations
            if ig.url_endpoint.startswith("https://my-fake-url.com")
        ]
        ```
    """
    from wandb.apis.public.integrations import WebhookIntegrations

    entity_name = entity or self.default_entity
    return WebhookIntegrations(
        client=self.client,
        variables={"entityName": entity_name},
        per_page=per_page,
    )
|
|
|
def slack_integrations(
    self, *, entity: Optional[str] = None, per_page: int = 50
) -> Iterator["SlackIntegration"]:
    """Return an iterator over the Slack integrations of an entity.

    Both parameters are keyword-only.

    Args:
        entity: The entity (e.g. team name) for which to
            fetch integrations. If not provided, the user's default entity
            will be used.
        per_page: Number of integrations to fetch per page.
            Defaults to 50. Usually there is no reason to change this.

    Yields:
        Iterator[SlackIntegration]: An iterator of Slack integrations.

    Examples:
        Get all registered Slack integrations for the team "my-team":
        ```python
        import wandb

        api = wandb.Api()
        slack_integrations = api.slack_integrations(entity="my-team")
        ```

        Find only Slack integrations that post to channel names starting with "team-alerts-":
        ```python
        slack_integrations = api.slack_integrations(entity="my-team")
        team_alert_integrations = [
            ig
            for ig in slack_integrations
            if ig.channel_name.startswith("team-alerts-")
        ]
        ```
    """
    from wandb.apis.public.integrations import SlackIntegrations

    entity_name = entity or self.default_entity
    return SlackIntegrations(
        client=self.client,
        variables={"entityName": entity_name},
        per_page=per_page,
    )
|
|
|
def _supports_automation(
    self,
    *,
    event: Optional["EventType"] = None,
    action: Optional["ActionType"] = None,
) -> bool:
    """Report whether the server recognizes the automation event and/or action.

    An argument left as None counts as supported. Otherwise it is supported
    when it is in the corresponding always-supported collection, or when the
    server reports the `AUTOMATION_EVENT_<value>` /
    `AUTOMATION_ACTION_<value>` feature.
    """
    from wandb.automations._utils import (
        ALWAYS_SUPPORTED_ACTIONS,
        ALWAYS_SUPPORTED_EVENTS,
    )

    api = InternalApi()

    def _recognized(item, always_supported, feature_prefix):
        # Short-circuits so the server is only asked when the cheaper checks fail.
        return (
            (item is None)
            or (item in always_supported)
            or api._server_supports(f"{feature_prefix}{item.value}")
        )

    # Both checks always run, event first, before the results are combined.
    event_ok = _recognized(event, ALWAYS_SUPPORTED_EVENTS, "AUTOMATION_EVENT_")
    action_ok = _recognized(action, ALWAYS_SUPPORTED_ACTIONS, "AUTOMATION_ACTION_")
    return event_ok and action_ok
|
|
|
def _omitted_automation_fragments(self) -> Set[str]:
    """Collect the names of automation-related fragments the server does not support.

    Older servers won't recognize newer GraphQL types, so a valid request may
    unnecessarily error out because it won't recognize fragments defined on
    those types.

    So e.g. if a server does not support `NO_OP` action types, then the
    following need to be removed from the body of the GraphQL request:

    - Fragment definition:
        ```
        fragment NoOpActionFields on NoOpTriggeredAction {
            noOp
        }
        ```

    - Fragment spread in selection set:
        ```
        {
            ...NoOpActionFields
            # ... other fields ...
        }
        ```
    """
    from wandb.automations import ActionType
    from wandb.automations._generated import (
        GenericWebhookActionFields,
        NoOpActionFields,
        NotificationActionFields,
        QueueJobActionFields,
    )

    # Fragment name for each action type that has one.
    fragment_names: dict[ActionType, str] = {
        ActionType.NO_OP: NoOpActionFields.__name__,
        ActionType.QUEUE_JOB: QueueJobActionFields.__name__,
        ActionType.NOTIFICATION: NotificationActionFields.__name__,
        ActionType.GENERIC_WEBHOOK: GenericWebhookActionFields.__name__,
    }

    omitted = set()
    for action in ActionType:
        if self._supports_automation(action=action):
            continue
        # Unsupported action types without a mapped fragment add nothing.
        name = fragment_names.get(action)
        if name:
            omitted.add(name)
    return omitted
|
|
|
def automation(
    self,
    name: str,
    *,
    entity: Optional[str] = None,
) -> "Automation":
    """Return the single Automation that matches the parameters.

    Args:
        name: The name of the automation to fetch.
        entity: The entity to fetch the automation for.

    Raises:
        ValueError: If zero or multiple Automations match the search criteria.

    Examples:
        Get an existing automation named "my-automation":

        ```python
        import wandb

        api = wandb.Api()
        automation = api.automation(name="my-automation")
        ```

        Get an existing automation named "other-automation", from the entity "my-team":

        ```python
        automation = api.automation(name="other-automation", entity="my-team")
        ```
    """
    candidates = self.automations(entity=entity, name=name)
    none_found = ValueError("No automations found")
    several_found = ValueError("Multiple automations found")
    return one(candidates, too_short=none_found, too_long=several_found)
|
|
|
def automations(
    self,
    entity: Optional[str] = None,
    *,
    name: Optional[str] = None,
    per_page: int = 50,
) -> Iterator["Automation"]:
    """Iterate over all Automations that match the given parameters.

    If no parameters are provided, the returned iterator will contain all
    Automations that the user has access to.

    This is a generator function: no work is done until the result is
    iterated.

    Args:
        entity: The entity to fetch the automations for.
        name: The name of the automation to fetch.
        per_page: The number of automations to fetch per page.
            Defaults to 50. Usually there is no reason to change this.

    Yields:
        The matching automations.

    Examples:
        Fetch all existing automations for the entity "my-team":

        ```python
        import wandb

        api = wandb.Api()
        automations = api.automations(entity="my-team")
        ```
    """
    from wandb.apis.public.automations import Automations
    from wandb.automations._generated import (
        GET_AUTOMATIONS_BY_ENTITY_GQL,
        GET_AUTOMATIONS_GQL,
    )

    # A different query is used depending on whether an entity was given.
    gql_str = GET_AUTOMATIONS_GQL if entity is None else GET_AUTOMATIONS_BY_ENTITY_GQL
    variables = {"entityName": entity}

    # Strip fragments the server would not recognize before sending the query.
    query = gql_compat(gql_str, omit_fragments=self._omitted_automation_fragments())
    found = Automations(
        client=self.client, variables=variables, per_page=per_page, _query=query
    )

    if name is None:
        yield from found
    else:
        # Name matching is applied here, on the fetched results.
        yield from (item for item in found if item.name == name)
|
|
|
@normalize_exceptions |
|
def create_automation(
    self,
    obj: "NewAutomation",
    *,
    fetch_existing: bool = False,
    **kwargs: Unpack["WriteAutomationsKwargs"],
) -> "Automation":
    """Create a new Automation.

    Args:
        obj:
            The automation to create.
        fetch_existing:
            If True, and a conflicting automation already exists, attempt
            to fetch the existing automation instead of raising an error.
        **kwargs:
            Any additional values to assign to the automation before
            creating it. If given, these will override any values that may
            already be set on the automation:
            - `name`: The name of the automation.
            - `description`: The description of the automation.
            - `enabled`: Whether the automation is enabled.
            - `scope`: The scope of the automation.
            - `event`: The event that triggers the automation.
            - `action`: The action that is triggered by the automation.

    Returns:
        The saved Automation.

    Raises:
        ValueError: If the server does not support the automation's event
            or action, or if an automation with the same name already
            exists and `fetch_existing` is False.
        RuntimeError: If the server response is invalid or empty.

    Examples:
        Create a new automation named "my-automation" that sends a Slack notification
        when a run within a specific project logs a metric exceeding a custom threshold:

        ```python
        import wandb
        from wandb.automations import OnRunMetric, RunEvent, SendNotification

        api = wandb.Api()

        project = api.project("my-project", entity="my-team")

        # Use the first Slack integration for the team
        slack_hook = next(api.slack_integrations(entity="my-team"))

        event = OnRunMetric(
            scope=project,
            filter=RunEvent.metric("custom-metric") > 10,
        )
        action = SendNotification.from_integration(slack_hook)

        automation = api.create_automation(
            event >> action,
            name="my-automation",
            description="Send a Slack message whenever 'custom-metric' exceeds 10.",
        )
        ```
    """
    from wandb.automations import Automation
    from wandb.automations._generated import CREATE_AUTOMATION_GQL, CreateAutomation
    from wandb.automations._utils import prepare_to_create

    gql_input = prepare_to_create(obj, **kwargs)

    event = gql_input.triggering_event_type
    action = gql_input.triggered_action_type
    if not self._supports_automation(event=event, action=action):
        raise ValueError(
            f"Automation event or action ({event!r} -> {action!r}) "
            "is not supported on this wandb server version. "
            "Please upgrade your server version, or contact support at "
            "support@wandb.com."
        )

    # Strip fragments the server would not recognize before sending the mutation.
    mutation = gql_compat(
        CREATE_AUTOMATION_GQL, omit_fragments=self._omitted_automation_fragments()
    )
    variables = {"params": gql_input.model_dump(exclude_none=True)}
    name = gql_input.name

    try:
        data = self.client.execute(mutation, variable_values=variables)
    except requests.HTTPError as e:
        status = HTTPStatus(e.response.status_code)
        if status is not HTTPStatus.CONFLICT:
            raise
        # Conflict: an automation with this name already exists.
        if fetch_existing:
            wandb.termlog(f"Automation {name!r} exists. Fetching it instead.")
            return self.automation(name=name)
        raise ValueError(
            f"Automation {name!r} exists. Unable to create another with the same name."
        ) from None

    try:
        result = CreateAutomation.model_validate(data).result
    except ValidationError as e:
        raise RuntimeError(
            f"Invalid response while creating automation {name!r}"
        ) from e

    if (result is None) or (result.trigger is None):
        raise RuntimeError(f"Empty response while creating automation {name!r}")

    return Automation.model_validate(result.trigger)
|
|
|
@normalize_exceptions |
|
def update_automation(
    self,
    obj: "Automation",
    *,
    create_missing: bool = False,
    **kwargs: Unpack["WriteAutomationsKwargs"],
) -> "Automation":
    """Update an existing automation.

    Args:
        obj: The automation to update. Must be an existing automation.
        create_missing (bool):
            If True, and the automation does not exist, create it. Any
            `**kwargs` overrides are applied to the created automation too.
        **kwargs:
            Any additional values to assign to the automation before
            updating it. If given, these will override any values that may
            already be set on the automation:
            - `name`: The name of the automation.
            - `description`: The description of the automation.
            - `enabled`: Whether the automation is enabled.
            - `scope`: The scope of the automation.
            - `event`: The event that triggers the automation.
            - `action`: The action that is triggered by the automation.

    Returns:
        The updated automation.

    Raises:
        RuntimeError: If the server does not support updating automations,
            or if the server response is invalid or empty.
        ValueError: If the server does not support the automation's event
            or action, or if the automation is not found and
            `create_missing` is False.

    Examples:
        Disable and edit the description of an existing automation ("my-automation"):

        ```python
        import wandb

        api = wandb.Api()

        automation = api.automation(name="my-automation")
        automation.enabled = False
        automation.description = "Kept for reference, but no longer used."

        updated_automation = api.update_automation(automation)
        ```

        OR:

        ```python
        import wandb

        api = wandb.Api()

        automation = api.automation(name="my-automation")

        updated_automation = api.update_automation(
            automation,
            enabled=False,
            description="Kept for reference, but no longer used.",
        )
        ```
    """
    from wandb.automations import ActionType, Automation
    from wandb.automations._generated import UPDATE_AUTOMATION_GQL, UpdateAutomation
    from wandb.automations._utils import prepare_to_update

    # Support for the NO_OP action is what gates updates here.
    # NOTE(review): presumably a proxy for the server release that introduced
    # the update mutation - confirm against the server feature list.
    if not self._supports_automation(action=ActionType.NO_OP):
        raise RuntimeError(
            "Updating existing automations is not enabled on this wandb server version. "
            "Please upgrade your server version, or contact support at support@wandb.com."
        )

    gql_input = prepare_to_update(obj, **kwargs)

    if not self._supports_automation(
        event=(event := gql_input.triggering_event_type),
        action=(action := gql_input.triggered_action_type),
    ):
        raise ValueError(
            f"Automation event or action ({event.value} -> {action.value}) "
            "is not supported on this wandb server version. "
            "Please upgrade your server version, or contact support at "
            "support@wandb.com."
        )

    # Strip fragments the server would not recognize before sending the mutation.
    omit_fragments = self._omitted_automation_fragments()
    mutation = gql_compat(UPDATE_AUTOMATION_GQL, omit_fragments=omit_fragments)
    variables = {"params": gql_input.model_dump(exclude_none=True)}

    name = gql_input.name
    try:
        data = self.client.execute(mutation, variable_values=variables)
    except requests.HTTPError as e:
        status = HTTPStatus(e.response.status_code)
        if status is HTTPStatus.NOT_FOUND:
            if create_missing:
                wandb.termlog(f"Automation {name!r} not found. Creating it.")
                # Forward the overrides so the created automation carries the
                # same values the caller asked to apply in the update.
                return self.create_automation(obj, **kwargs)

            raise ValueError(
                f"Automation {name!r} not found. Unable to edit it."
            ) from e

        # Any other HTTP failure: surface the response, then re-raise as-is.
        wandb.termerror(f"Got response status {status!r}: {e.response.text!r}")
        raise

    try:
        result = UpdateAutomation.model_validate(data).result
    except ValidationError as e:
        msg = f"Invalid response while updating automation {name!r}"
        raise RuntimeError(msg) from e

    if (result is None) or (result.trigger is None):
        msg = f"Empty response while updating automation {name!r}"
        raise RuntimeError(msg)

    return Automation.model_validate(result.trigger)
|
|
|
@normalize_exceptions |
|
def delete_automation(self, obj: Union["Automation", str]) -> Literal[True]:
    """Delete an automation.

    Args:
        obj: Either the automation object to delete, or its ID.

    Returns:
        True once the server reports that the deletion succeeded.

    Raises:
        RuntimeError: If the server response cannot be validated, is empty,
            or reports that the deletion did not succeed.
    """
    from wandb.automations._generated import DELETE_AUTOMATION_GQL, DeleteAutomation
    from wandb.automations._utils import extract_id

    # Resolve the ID up front; it is both the request variable and the
    # identifier used in every error message below.
    automation_id = extract_id(obj)

    response = self.client.execute(
        gql(DELETE_AUTOMATION_GQL),
        variable_values={"id": automation_id},
    )

    try:
        outcome = DeleteAutomation.model_validate(response).result
    except ValidationError as err:
        raise RuntimeError(
            f"Invalid response while deleting automation {automation_id!r}"
        ) from err

    if outcome is None:
        raise RuntimeError(
            f"Empty response while deleting automation {automation_id!r}"
        )

    if outcome.success:
        return outcome.success

    raise RuntimeError(f"Failed to delete automation: {automation_id!r}")
|
|