File size: 10,229 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 |
import json
import os
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import yaml
import wandb
from wandb import util
from wandb.sdk.launch.errors import LaunchError
if TYPE_CHECKING:
from wandb.apis.public import Api as PublicApi
# Default sweep command template. Each element is a ``${...}`` placeholder
# macro; ``${args}`` corresponds to a key produced by create_sweep_command_args.
DEFAULT_SWEEP_COMMAND: List[str] = ["${env}", "${interpreter}", "${program}", "${args}"]

# Matches ``${envvar:NAME}`` macros inside a command element; group 1 captures
# NAME, which may only contain uppercase letters, digits and underscores.
SWEEP_COMMAND_ENV_VAR_REGEX = re.compile(r"\$\{envvar\:([A-Z0-9_]*)\}")
def parse_sweep_id(parts_dict: dict) -> Optional[str]:
    """Split the possibly-qualified sweep path stored in ``parts_dict["name"]``.

    Accepted forms are ``sweep``, ``project/sweep`` and
    ``entity/project/sweep``. On success ``parts_dict`` is updated in place
    with ``name``, ``project`` and ``entity``. Components that are absent or
    empty are written as None, overwriting any value already in the dict.

    Arguments:
        parts_dict (dict): dict(entity=,project=,name=). Modified in place.

    Returns:
        None on success, otherwise an error message string. On error the dict
        is left untouched.
    """
    raw_path = parts_dict.get("name")
    if not isinstance(raw_path, str):
        return "Expected string sweep_id"

    segments = raw_path.split("/")
    if len(segments) > 3:
        return (
            "Expected sweep_id in form of sweep, project/sweep, or entity/project/sweep"
        )

    # Left-pad with None so every accepted path unpacks as (entity, project, sweep).
    entity, project, sweep_name = [None] * (3 - len(segments)) + segments
    parts_dict.update(
        dict(name=sweep_name, project=project or None, entity=entity or None)
    )
    return None
def sweep_config_err_text_from_jsonschema_violations(violations: List[str]) -> str:
    """Consolidate schema violation strings from wandb/sweeps into a single string.

    Parameters
    ----------
    violations: list of str
        The warnings to render. The list is not modified.

    Returns
    -------
    violation: str
        A fixed two-line header followed by one numbered line per violation,
        joined with newlines.
    """
    violation_base = (
        "Malformed sweep config detected! This may cause your sweep to behave in unexpected ways.\n"
        "To avoid this, please fix the sweep config schema violations below:"
    )
    # Build a fresh list instead of rewriting ``violations`` element by element,
    # so formatting has no side effect on the caller's list.
    numbered = [
        f" Violation {number}. {warning}"
        for number, warning in enumerate(violations, start=1)
    ]
    return "\n".join([violation_base, *numbered])
def handle_sweep_config_violations(warnings: List[str]) -> None:
    """Print sweep config schema violations (from Gorilla) as one terminal warning.

    Nothing is printed when ``warnings`` is empty.

    Parameters
    ----------
    warnings: list of str
        The warnings to render.
    """
    if not warnings:
        return
    wandb.termwarn(sweep_config_err_text_from_jsonschema_violations(warnings))
def load_sweep_config(sweep_config_path: str) -> Optional[Dict[str, Any]]:
    """Load a sweep yaml from path.

    Problems are reported with ``wandb.termerror`` rather than raised.

    Arguments:
        sweep_config_path: Path to the YAML sweep config file.

    Returns:
        The parsed config. Returns None if the file cannot be opened, is not
        valid YAML, or parses to an empty/falsy value.
    """
    try:
        yaml_file = open(sweep_config_path)
    except OSError:
        wandb.termerror(f"Couldn't open sweep file: {sweep_config_path}")
        return None

    # Manage the handle with ``with`` so it is closed on every path, including
    # the early return on a YAML parse error.
    with yaml_file:
        try:
            config: Optional[Dict[str, Any]] = yaml.safe_load(yaml_file)
        except yaml.YAMLError as err:
            wandb.termerror(f"Error in configuration file: {err}")
            return None

    if not config:
        wandb.termerror("Configuration file is empty")
        return None
    return config
def load_launch_sweep_config(config: Optional[str]) -> Any:
    """Parse a launch sweep config with ``util.load_json_yaml_dict``.

    ``config`` is presumably a JSON/YAML string or a path to such a file;
    confirm against ``util.load_json_yaml_dict``.

    Returns:
        ``{}`` when ``config`` is None or empty, otherwise the parsed value.

    Raises:
        LaunchError: the helper could not parse ``config`` (returned None).
    """
    if config:
        parsed = util.load_json_yaml_dict(config)
        if parsed is not None:
            return parsed
        raise LaunchError(f"Could not load config from {config}. Check formatting")
    return {}
def construct_scheduler_args(
    sweep_config: Dict[str, Any],
    queue: str,
    project: str,
    author: Optional[str] = None,
    return_job: bool = False,
) -> Union[List[str], Dict[str, str], None]:
    """Build the arguments for the sweep scheduler.

    Exactly one of the top-level ``job`` / ``image_uri`` keys must be set
    (non-empty) in ``sweep_config``. Otherwise an error is printed with
    ``wandb.termerror`` and None is returned.

    When ``return_job`` is True, the args are returned as a dict whose
    ``sweep_id`` is the literal placeholder ``"WANDB_SWEEP_ID"``. Otherwise
    they are returned as a flat list of CLI tokens, in which queue, project,
    author and job are rendered with ``repr()``.
    """
    job = sweep_config.get("job")
    image_uri = sweep_config.get("image_uri")

    # Empty strings count as missing.
    if not (job or image_uri):
        wandb.termerror(
            "No 'job' nor 'image_uri' top-level key found in sweep config, exactly one is required for a launch-sweep"
        )
        return None
    if job and image_uri:
        wandb.termerror(
            "Sweep config has both 'job' and 'image_uri' but a launch-sweep can use only one"
        )
        return None
    # From here on exactly one of ``job`` / ``image_uri`` is truthy.

    if return_job:
        scheduler_dict: Dict[str, str] = {
            "sweep_id": "WANDB_SWEEP_ID",
            "queue": queue,
            "project": project,
        }
        if job:
            scheduler_dict["job"] = job
        else:
            scheduler_dict["image_uri"] = image_uri
        if author:
            scheduler_dict["author"] = author
        return scheduler_dict

    cli_args = ["--queue", f"{queue!r}", "--project", f"{project!r}"]
    if author:
        cli_args.extend(["--author", f"{author!r}"])
    if job:
        cli_args.extend(["--job", f"{job!r}"])
    else:
        # NOTE(review): image_uri is passed without repr()-quoting, unlike the
        # other values; confirm this asymmetry is intended.
        cli_args.extend(["--image_uri", image_uri])
    return cli_args
def create_sweep_command(command: Optional[List] = None) -> List:
    """Return sweep command, filling in environment variable macros.

    Arguments:
        command: The command template. When None or empty,
            ``DEFAULT_SWEEP_COMMAND`` is used.

    Returns:
        A new list. Neither ``command`` nor ``DEFAULT_SWEEP_COMMAND`` is
        modified. Each element whose string form contains ``${envvar:NAME}``
        macros is replaced by that string, with every macro substituted by the
        value of environment variable NAME. If NAME is not set, the macro is
        replaced by the bare text ``NAME``. Elements without macros are kept
        unchanged, including their type (e.g. ints stay ints).
    """

    def _expand(match: "re.Match") -> str:
        name = match.group(1)
        # Unset variables fall back to the bare variable name.
        return os.environ.get(name, name)

    expanded: List = []
    for chunk in command or DEFAULT_SWEEP_COMMAND:
        text = str(chunk)
        if SWEEP_COMMAND_ENV_VAR_REGEX.search(text):
            # Substitute on str(chunk) so non-str chunks containing a macro do
            # not raise TypeError. A callable replacement inserts the env value
            # literally (no backslash/group-reference processing).
            chunk = SWEEP_COMMAND_ENV_VAR_REGEX.sub(_expand, text)
        expanded.append(chunk)
    return expanded
def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
    """Render a command's ``args`` in every format the agent can use.

    Arguments:
        command (dict): Must contain an ``"args"`` mapping of
            ``param -> {"value": ...}``. A value of None is allowed.

    Returns:
        A dict with the keys below. Each list preserves ``command["args"]`` order.

        - ``args``: ``--param=value``.
        - ``args_no_equals``: ``--param`` and ``value`` as separate tokens.
        - ``args_no_hyphens``: ``param=value``.
        - ``args_no_boolean_flags``: like ``args``, but a True boolean becomes
          a bare ``--param`` and a False boolean is dropped.
        - ``args_json``: a one-element list holding the JSON-encoded dict.
        - ``args_dict``: ``{param: value}``.
        - ``args_append_hydra``: ``+param=value`` (Hydra append).
        - ``args_override_hydra``: ``++param=value`` (Hydra override).

    Raises:
        ValueError: improperly formatted command dict (no ``"args"`` key, or
            an entry without a ``"value"`` key).
    """
    if "args" not in command:
        raise ValueError(f'No "args" found in command: {command}')

    values: Dict[str, Any] = {}
    for param, config in command["args"].items():
        # None is an acceptable value; only a missing "value" key is an error.
        try:
            values[param] = config["value"]
        except KeyError:
            raise ValueError(f'No "value" found for command["args"]["{param}"]')

    rendered: List[str] = [f"{param}={value}" for param, value in values.items()]

    split_tokens: List[str] = []
    boolean_aware: List[str] = []
    for (param, value), pair in zip(values.items(), rendered):
        split_tokens += [f"--{param}", str(value)]
        if not isinstance(value, bool):
            boolean_aware.append("--" + pair)
        elif value:
            # A true boolean becomes a bare switch; a false one is omitted.
            boolean_aware.append("--" + param)

    return {
        "args": ["--" + pair for pair in rendered],
        "args_no_equals": split_tokens,
        "args_no_hyphens": rendered,
        "args_no_boolean_flags": boolean_aware,
        "args_json": [json.dumps(values)],
        "args_dict": values,
        "args_append_hydra": ["+" + pair for pair in rendered],
        "args_override_hydra": ["++" + pair for pair in rendered],
    }
def make_launch_sweep_entrypoint(
    args: Dict[str, Any], command: Optional[List[str]]
) -> Tuple[Optional[List[str]], Any]:
    """Build a launch entry point from a sweep command.

    The command is first passed through ``create_sweep_command`` to expand
    env-var macros. Then, for each key of ``args`` (e.g. the dict returned by
    ``create_sweep_command_args``), the first ``"${<key>}"`` element found in
    the entry point is removed, and that key's value becomes the macro args.
    Only one macro per entry point is supported. If several keys match, the
    value of the last matching key (in ``args`` iteration order) is returned.

    Returns:
        ``(None, None)`` if ``command`` is None or empty. Otherwise a pair of:
        the remaining entry point (None if nothing is left), and the macro
        args (``{}`` if no macro matched).
    """
    if not command:
        return None, None

    entry_point = create_sweep_command(command)
    macro_args: Any = {}
    for macro_name in args:
        placeholder = "${" + macro_name + "}"
        if placeholder not in entry_point:
            continue
        # Only 1 macro per entry point is supported; a later match overrides.
        macro_args = args[macro_name]
        position = entry_point.index(placeholder)
        # Rebuild via slicing rather than deleting in place, so the list
        # returned by create_sweep_command is never mutated here.
        entry_point = entry_point[:position] + entry_point[position + 1 :]

    return (entry_point if entry_point else None), macro_args
def check_job_exists(public_api: "PublicApi", job: Optional[str]) -> bool:
    """Check that ``job`` can be loaded with the public API.

    Returns:
        True if ``job`` is None or empty, or if ``public_api.job(job)``
        succeeds. False if that call raises any exception; the error is
        printed with ``wandb.termerror``.
    """
    if job:
        try:
            public_api.job(job)
        except Exception as err:
            wandb.termerror(f"Failed to load job. {err}")
            return False
    return True
def get_previous_args(
    run_spec: Dict[str, Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Parse through previous scheduler run_spec.

    Arguments:
        run_spec: A previous scheduler run spec. The scheduler args are read
            from ``overrides.run_config.scheduler`` and the settings from
            ``overrides.run_config.settings``.

    Returns:
        ``(scheduler_args, settings)``. ``scheduler_args`` is a copy of the
        stored scheduler dict, with the top-level ``resource`` and
        ``resource_args`` added when they are truthy. Missing or None levels
        yield empty dicts. ``run_spec`` itself is not modified.
    """
    # ``or {}`` also covers keys that are present but explicitly None.
    overrides = run_spec.get("overrides") or {}
    run_config = overrides.get("run_config") or {}

    # Copy before adding keys so the caller's nested scheduler dict is untouched.
    scheduler_args: Dict[str, Any] = dict(run_config.get("scheduler") or {})
    # Also pipe through the top-level resource setup.
    for key in ("resource", "resource_args"):
        if run_spec.get(key):
            scheduler_args[key] = run_spec[key]

    settings: Dict[str, Any] = run_config.get("settings") or {}
    return scheduler_args, settings
|