|
"""Public API: sweeps.""" |
|
|
|
import urllib
import urllib.parse
from typing import Optional

from wandb_gql import gql

import wandb
from wandb import util
from wandb.apis import public
from wandb.apis.attrs import Attrs
from wandb.sdk.lib import ipython
|
|
|
# GraphQL fragment named `SweepFragment` selecting a fixed set of Sweep
# fields. It is not referenced by the code visible in this module;
# presumably it is imported and spliced into queries elsewhere (verify
# against callers before changing the field list).
SWEEP_FRAGMENT = """fragment SweepFragment on Sweep {
id
name
method
state
description
displayName
bestLoss
config
createdAt
updatedAt
runCount
}
"""
|
|
|
|
|
class Sweep(Attrs):
    """A set of runs associated with a sweep.

    Examples:
        Instantiate with:
        ```
        api = wandb.Api()
        sweep = api.sweep("<entity>/<project>/<sweep_id>")
        ```

    Attributes:
        runs: (`Runs`) list of runs
        id: (str) sweep id
        project: (str) name of project
        config: (dict) sweep configuration, parsed from the YAML `config` field
        state: (str) the state of the sweep
        expected_run_count: (int or None) number of expected runs for the sweep
    """

    # Primary query used by `get`; unlike LEGACY_QUERY it also selects
    # `runCountExpected`.
    QUERY = gql(
        """
    query Sweep($project: String, $entity: String, $name: String!) {
        project(name: $project, entityName: $entity) {
            sweep(sweepName: $name) {
                id
                name
                state
                runCountExpected
                bestLoss
                config
            }
        }
    }
    """
    )

    # Fallback query used by `get` when QUERY fails. It is identical to QUERY
    # except that it omits `runCountExpected` -- presumably for servers whose
    # schema lacks that field (TODO confirm).
    LEGACY_QUERY = gql(
        """
    query Sweep($project: String, $entity: String, $name: String!) {
        project(name: $project, entityName: $entity) {
            sweep(sweepName: $name) {
                id
                name
                state
                bestLoss
                config
            }
        }
    }
    """
    )

    def __init__(self, client, entity, project, sweep_id, attrs=None):
        """Create a sweep handle.

        Args:
            client: client used to execute GraphQL queries (`client.execute`).
            entity: entity name the sweep belongs to.
            project: name of the project containing the sweep.
            sweep_id: id of the sweep.
            attrs: pre-fetched sweep attributes. When falsy, the sweep is
                fetched from the backend immediately.
        """
        super().__init__(dict(attrs or {}))
        self.client = client
        self._entity = entity
        self.project = project
        self.id = sweep_id
        self.runs = []

        # Only hit the backend when the caller did not supply attributes.
        self.load(force=not attrs)

    @property
    def entity(self):
        """Return the entity name this sweep belongs to."""
        return self._entity

    @property
    def username(self):
        """Deprecated alias for `entity`; emits a warning on every access."""
        wandb.termwarn("Sweep.username is deprecated. please use Sweep.entity instead.")
        return self._entity

    @property
    def config(self):
        """Return the sweep configuration parsed from its YAML `config` attribute."""
        return util.load_yaml(self._attrs["config"])

    def load(self, force: bool = False):
        """Fetch the sweep's attributes from the backend if needed.

        Args:
            force: re-fetch even if attributes are already loaded.

        Returns:
            The sweep's raw attribute dict.

        Raises:
            ValueError: if the backend does not return the sweep.
        """
        if force or not self._attrs:
            sweep = self.get(self.client, self.entity, self.project, self.id)
            if sweep is None:
                raise ValueError(f"Could not find sweep {self}")
            self._attrs = sweep._attrs
            self.runs = sweep.runs
        return self._attrs

    @property
    def order(self):
        """Return the run sort key derived from the sweep config's `metric`.

        The key is "+<metric name>" when `metric.goal` is "minimize" (the
        default) and "-<metric name>" otherwise, passed through
        `QueryGenerator.format_order_key`. Returns None when the sweep has no
        config or the config defines no metric.
        """
        if not self._attrs.get("config"):
            return None
        # Parse the YAML once instead of on every `self.config` access.
        metric = self.config.get("metric")
        if not metric:
            return None
        goal = metric.get("goal", "minimize")
        prefix = "+" if goal == "minimize" else "-"
        return public.QueryGenerator.format_order_key(prefix + metric["name"])

    def best_run(self, order=None):
        """Return the best run sorted by the metric defined in config or the order passed in.

        Args:
            order: optional sort key such as "-accuracy" ("+" ascending,
                "-" descending). Defaults to `self.order`.

        Returns:
            The first run of this sweep under that ordering, or None if the
            sweep has no runs. When no ordering can be determined, runs are
            requested with `order=None` and a warning is emitted.
        """
        if order is None:
            order = self.order
        else:
            order = public.QueryGenerator.format_order_key(order)

        if order is None:
            wandb.termwarn(
                "No order specified and couldn't find metric in sweep config, returning most recent run"
            )
        else:
            wandb.termlog("Sorting runs by {}".format(order))

        filters = {"$and": [{"sweep": self.id}]}
        try:
            # per_page=1: only the top run is needed.
            return public.Runs(
                self.client,
                self.entity,
                self.project,
                order=order,
                filters=filters,
                per_page=1,
            )[0]
        except IndexError:
            return None

    @property
    def expected_run_count(self) -> Optional[int]:
        """Return the number of expected runs in the sweep or None for infinite runs."""
        return self._attrs.get("runCountExpected")

    @property
    def path(self):
        """Return a new list `[entity, project, sweep_id]`, each URL-quoted."""
        return [
            urllib.parse.quote_plus(str(part))
            for part in (self.entity, self.project, self.id)
        ]

    @property
    def url(self):
        """Return the sweep's web URL: `<app_url><entity>/<project>/sweeps/<id>`."""
        entity, project, sweep_id = self.path
        return self.client.app_url + "/".join([entity, project, "sweeps", sweep_id])

    @property
    def name(self):
        """Return the `name` from the sweep config, falling back to the sweep id."""
        return self.config.get("name") or self.id

    @classmethod
    def get(
        cls,
        client,
        entity=None,
        project=None,
        sid=None,
        order=None,
        query=None,
        **kwargs,
    ):
        """Execute a query against the cloud backend.

        Args:
            client: client used to execute the GraphQL query.
            entity: entity name of the sweep.
            project: project name of the sweep.
            sid: sweep id (sent as the `$name` variable).
            order: sort key applied to the attached `runs` collection.
            query: GraphQL query to run. Defaults to `QUERY`; if the default
                query fails, `LEGACY_QUERY` is tried instead. A custom query
                is never silently replaced: its errors propagate.
            **kwargs: extra query variables; they override the defaults above.

        Returns:
            A `Sweep` with `runs` populated (10 per page), or None if the
            response contains no such project/sweep.
        """
        use_default_query = query is None
        if use_default_query:
            query = cls.QUERY

        variables = {
            "entity": entity,
            "project": project,
            "name": sid,
        }
        variables.update(kwargs)

        try:
            response = client.execute(query, variable_values=variables)
        except Exception:
            if not use_default_query:
                # The caller asked for a specific query; falling back to a
                # different one would hide the failure.
                raise
            # Retry with the query that omits `runCountExpected`.
            response = client.execute(cls.LEGACY_QUERY, variable_values=variables)

        if (
            not response
            or not response.get("project")
            or not response["project"].get("sweep")
        ):
            return None

        sweep_response = response["project"]["sweep"]
        sweep = cls(client, entity, project, sid, attrs=sweep_response)
        sweep.runs = public.Runs(
            client,
            entity,
            project,
            order=order,
            per_page=10,
            filters={"$and": [{"sweep": sweep.id}]},
        )

        return sweep

    def to_html(self, height=420, hidden=False):
        """Generate HTML containing an iframe displaying this sweep.

        Args:
            height: iframe height in pixels.
            hidden: if True, hide the iframe and prepend a toggle button.
        """
        url = self.url + "?jupyter=true"
        style = f"border:none;width:100%;height:{height}px;"
        prefix = ""
        if hidden:
            style += "display:none;"
            prefix = ipython.toggle_button("sweep")
        return prefix + f"<iframe src={url!r} style={style!r}></iframe>"

    def _repr_html_(self) -> str:
        """Jupyter rich-display hook; delegates to `to_html`."""
        return self.to_html()

    def __repr__(self):
        state = self._attrs.get("state", "Unknown State")
        return f"<Sweep {'/'.join(self.path)} ({state})>"
|
|