import json
import os
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import yaml

import wandb
from wandb import util
from wandb.sdk.launch.errors import LaunchError

if TYPE_CHECKING:
    from wandb.apis.public import Api as PublicApi

DEFAULT_SWEEP_COMMAND: List[str] = [
    "${env}",
    "${interpreter}",
    "${program}",
    "${args}",
]
SWEEP_COMMAND_ENV_VAR_REGEX = re.compile(r"\$\{envvar\:([A-Z0-9_]*)\}")


def parse_sweep_id(parts_dict: dict) -> Optional[str]:
    """Parse a sweep path out of ``parts_dict["name"]``, updating the dict in place.

    The name may be given as ``sweep``, ``project/sweep`` or
    ``entity/project/sweep``. On success the ``name``, ``project`` and
    ``entity`` keys of ``parts_dict`` are overwritten; components that are
    absent from the path (or empty) are written as ``None``.

    Args:
        parts_dict: dict(entity=, project=, name=). Modified in place.

    Returns:
        ``None`` on success, or an error message string if the name is not a
        string or has more than three ``/``-separated components.
    """
    entity = None
    project = None
    sweep_id = parts_dict.get("name")
    if not isinstance(sweep_id, str):
        return "Expected string sweep_id"

    sweep_split = sweep_id.split("/")
    if len(sweep_split) == 1:
        pass
    elif len(sweep_split) == 2:
        split_project, sweep_id = sweep_split
        project = split_project or project
    elif len(sweep_split) == 3:
        split_entity, split_project, sweep_id = sweep_split
        project = split_project or project
        entity = split_entity or entity
    else:
        return (
            "Expected sweep_id in form of sweep, project/sweep, or entity/project/sweep"
        )
    parts_dict.update(dict(name=sweep_id, project=project, entity=entity))
    return None


def sweep_config_err_text_from_jsonschema_violations(violations: List[str]) -> str:
    """Consolidate schema violation strings from wandb/sweeps into a single string.

    The input list is left untouched; numbered lines are built in a new list so
    that calling this function repeatedly with the same list is safe.

    Args:
        violations: The warnings to render.

    Returns:
        The consolidated violation text: a fixed header followed by one
        numbered line per violation, joined with newlines.
    """
    violation_base = (
        "Malformed sweep config detected! This may cause your sweep to behave in unexpected ways.\n"
        "To avoid this, please fix the sweep config schema violations below:"
    )
    numbered = [
        f" Violation {i}. {warning}" for i, warning in enumerate(violations, start=1)
    ]
    return "\n".join([violation_base] + numbered)


def handle_sweep_config_violations(warnings: List[str]) -> None:
    """Echo sweep config schema violation warnings to the terminal.

    Nothing is printed when ``warnings`` is empty.

    Args:
        warnings: The warnings to render.
    """
    if len(warnings) > 0:
        wandb.termwarn(sweep_config_err_text_from_jsonschema_violations(warnings))


def load_sweep_config(sweep_config_path: str) -> Optional[Dict[str, Any]]:
    """Load a sweep yaml from path.

    Args:
        sweep_config_path: Path of the YAML file to read.

    Returns:
        The parsed config, or ``None`` (after printing an error) if the file
        cannot be opened, is not valid YAML, or parses to an empty value.
    """
    try:
        yaml_file = open(sweep_config_path)
    except OSError:
        wandb.termerror(f"Couldn't open sweep file: {sweep_config_path}")
        return None

    # Close the handle on every exit path, including parse failures.
    with yaml_file:
        try:
            config: Optional[Dict[str, Any]] = yaml.safe_load(yaml_file)
        except yaml.YAMLError as err:
            wandb.termerror(f"Error in configuration file: {err}")
            return None

    if not config:
        wandb.termerror("Configuration file is empty")
        return None
    return config


def load_launch_sweep_config(config: Optional[str]) -> Any:
    """Load a launch-sweep config via ``util.load_json_yaml_dict``.

    Args:
        config: Config reference handed to ``util.load_json_yaml_dict``.
            A falsy value yields an empty dict.

    Returns:
        The parsed config, or ``{}`` when ``config`` is falsy.

    Raises:
        LaunchError: If the loader returns ``None``.
    """
    if not config:
        return {}

    parsed_config = util.load_json_yaml_dict(config)
    if parsed_config is None:
        raise LaunchError(f"Could not load config from {config}. Check formatting")
    return parsed_config


def construct_scheduler_args(
    sweep_config: Dict[str, Any],
    queue: str,
    project: str,
    author: Optional[str] = None,
    return_job: bool = False,
) -> Union[List[str], Dict[str, str], None]:
    """Construct sweep scheduler args.

    Exactly one of the top-level ``job`` / ``image_uri`` keys must be set
    (non-empty) in ``sweep_config``.

    Args:
        sweep_config: The sweep configuration.
        queue: Queue name passed through to the scheduler.
        project: Project name passed through to the scheduler.
        author: Optional author, included only when truthy.
        return_job: If True, return the args as a dict; otherwise return them
            as a list of CLI-style strings.

    Returns:
        ``None`` (after logging an error) if misconfigured, otherwise the args
        as a dict when ``return_job`` is True, else as a list of strings.
    """
    job = sweep_config.get("job")
    image_uri = sweep_config.get("image_uri")
    if not job and not image_uri:  # don't allow empty string
        wandb.termerror(
            "No 'job' nor 'image_uri' top-level key found in sweep config, exactly one is required for a launch-sweep"
        )
        return None
    elif job and image_uri:
        wandb.termerror(
            "Sweep config has both 'job' and 'image_uri' but a launch-sweep can use only one"
        )
        return None

    # if scheduler is a job, return args as dict
    if return_job:
        args_dict: Dict[str, str] = {
            "sweep_id": "WANDB_SWEEP_ID",
            "queue": queue,
            "project": project,
        }
        if job:
            args_dict["job"] = job
        elif image_uri:
            args_dict["image_uri"] = image_uri

        if author:
            args_dict["author"] = author

        return args_dict

    # scheduler uses cli commands, pass args as param list
    args = [
        "--queue",
        f"{queue!r}",
        "--project",
        f"{project!r}",
    ]
    if author:
        args += [
            "--author",
            f"{author!r}",
        ]
    if job:
        args += [
            "--job",
            f"{job!r}",
        ]
    elif image_uri:
        # NOTE: unlike the other values, image_uri is passed without repr quoting.
        args += ["--image_uri", image_uri]

    return args


def create_sweep_command(command: Optional[List] = None) -> List:
    """Return sweep command, filling in environment variable macros.

    Every ``${envvar:NAME}`` macro inside a chunk is replaced by the value of
    the environment variable ``NAME``; if the variable is not set, the macro is
    replaced by the bare name ``NAME``.

    Args:
        command: The command chunks. Falls back to ``DEFAULT_SWEEP_COMMAND``
            when falsy. The input list is not modified.

    Returns:
        A new list of chunks. Chunks containing a macro are returned as
        strings; all other chunks are passed through unchanged.
    """

    def _expand(match: "re.Match[str]") -> str:
        # Default to the bare variable name if the environment variable does not exist
        name = match.group(1)
        return os.environ.get(name, name)

    # Work on a copy so neither the caller's list nor the module-level default
    # is mutated.
    resolved: List = list(command or DEFAULT_SWEEP_COMMAND)
    for i, chunk in enumerate(resolved):
        # Chunks may be of any type (ex: int); match against their string form.
        text = str(chunk)
        if SWEEP_COMMAND_ENV_VAR_REGEX.search(text):
            resolved[i] = SWEEP_COMMAND_ENV_VAR_REGEX.sub(_expand, text)
    return resolved


def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
    """Create various formats of command arguments for the agent.

    Args:
        command: Dict with an ``"args"`` mapping of
            ``param -> {"value": ...}``.

    Returns:
        A dict with the keys ``args``, ``args_no_equals``, ``args_no_hyphens``,
        ``args_no_boolean_flags``, ``args_json``, ``args_dict``,
        ``args_append_hydra`` and ``args_override_hydra``.

    Raises:
        ValueError: improperly formatted command dict
    """
    if "args" not in command:
        raise ValueError(f'No "args" found in command: {command}')
    # seven different formats of command args
    # (1) standard command line flags (e.g. --foo=bar)
    flags: List[str] = []
    # (2) flags without hyphens (e.g. foo=bar)
    flags_no_hyphens: List[str] = []
    # (3) flags with false booleans omitted (e.g. --foo)
    flags_no_booleans: List[str] = []
    # (4) flags as a dictionary (used for constructing a json)
    flags_dict: Dict[str, Any] = {}
    # (5) flags without equals (e.g. --foo bar)
    args_no_equals: List[str] = []
    # (6) flags for hydra append config value (e.g. +foo=bar)
    flags_append_hydra: List[str] = []
    # (7) flags for hydra override config value (e.g. ++foo=bar)
    flags_override_hydra: List[str] = []
    for param, config in command["args"].items():
        # allow 'None' as a valid value, but error if no value is found
        try:
            _value: Any = config["value"]
        except KeyError as err:
            raise ValueError(
                f'No "value" found for command["args"]["{param}"]'
            ) from err

        _flag: str = f"{param}={_value}"
        flags.append("--" + _flag)
        flags_no_hyphens.append(_flag)
        args_no_equals += [f"--{param}", str(_value)]
        flags_append_hydra.append("+" + _flag)
        flags_override_hydra.append("++" + _flag)
        if isinstance(_value, bool):
            # omit flags if they are boolean and false
            if _value:
                flags_no_booleans.append("--" + param)
        else:
            flags_no_booleans.append("--" + _flag)
        flags_dict[param] = _value
    return {
        "args": flags,
        "args_no_equals": args_no_equals,
        "args_no_hyphens": flags_no_hyphens,
        "args_no_boolean_flags": flags_no_booleans,
        "args_json": [json.dumps(flags_dict)],
        "args_dict": flags_dict,
        "args_append_hydra": flags_append_hydra,
        "args_override_hydra": flags_override_hydra,
    }


def make_launch_sweep_entrypoint(
    args: Dict[str, Any], command: Optional[List[str]]
) -> Tuple[Optional[List[str]], Any]:
    """Use args dict from create_sweep_command_args to construct entrypoint.

    Environment variable macros in ``command`` are expanded first. Then, for
    every key of ``args`` whose ``${key}`` macro appears as a chunk of the
    entrypoint, the first occurrence of that chunk is removed and the
    corresponding args value is returned separately.

    Args:
        args: Mapping of macro name to argument value(s).
        command: The command chunks; a falsy value short-circuits.

    Returns:
        ``(None, None)`` when ``command`` is falsy; otherwise
        ``(entry_point, macro_args)`` where ``entry_point`` is ``None`` if no
        chunks remain and ``macro_args`` is ``{}`` if no macro matched.
    """
    if not command:
        return None, None

    entry_point = create_sweep_command(command)
    macro_args: Any = {}
    for macro in args:
        mstr = "${" + macro + "}"
        if mstr in entry_point:
            idx = entry_point.index(mstr)
            # only supports 1 macro per entrypoint: if several match, the last
            # one seen determines macro_args
            macro_args = args[macro]
            entry_point = entry_point[:idx] + entry_point[idx + 1 :]

    if len(entry_point) == 0:
        return None, macro_args

    return entry_point, macro_args


def check_job_exists(public_api: "PublicApi", job: Optional[str]) -> bool:
    """Check if the job exists using the public api.

    Args:
        public_api: Api object whose ``job`` method is used for the lookup.
        job: The job reference to look up.

    Returns:
        True if no job is passed, or if the job exists. False (after printing
        an error) if the lookup raises, e.g. because the job is misformatted
        or doesn't exist.
    """
    if not job:
        return True

    try:
        public_api.job(job)
    except Exception as e:
        wandb.termerror(f"Failed to load job. {e}")
        return False
    return True


def get_previous_args(
    run_spec: Dict[str, Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Parse through previous scheduler run_spec.

    Reads ``overrides.run_config.scheduler`` and
    ``overrides.run_config.settings`` from ``run_spec`` (each defaulting to an
    empty dict) and copies the top-level ``resource`` / ``resource_args``
    values into the scheduler args when they are truthy.

    Note: when ``run_spec`` already contains a scheduler dict, that same dict
    object is updated in place and returned.

    Args:
        run_spec: The previous run spec.

    Returns:
        A ``(scheduler_args, settings)`` tuple.
    """
    scheduler_args = (
        run_spec.get("overrides", {}).get("run_config", {}).get("scheduler", {})
    )
    # also pipe through top level resource setup
    if run_spec.get("resource"):
        scheduler_args["resource"] = run_spec["resource"]
    if run_spec.get("resource_args"):
        scheduler_args["resource_args"] = run_spec["resource_args"]

    settings = run_spec.get("overrides", {}).get("run_config", {}).get("settings", {})

    return scheduler_args, settings