File size: 10,229 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
import json
import os
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import yaml

import wandb
from wandb import util
from wandb.sdk.launch.errors import LaunchError

if TYPE_CHECKING:
    from wandb.apis.public import Api as PublicApi

# Command template used by create_sweep_command when no command is given.
# The "${...}" entries are placeholders; presumably they are expanded by the
# agent/launch code that consumes the command — confirm against callers.
DEFAULT_SWEEP_COMMAND: List[str] = ["${env}", "${interpreter}", "${program}", "${args}"]

# Matches "${envvar:NAME}" macros inside a command element. Group 1 captures
# NAME, restricted to uppercase ASCII letters, digits and underscores.
SWEEP_COMMAND_ENV_VAR_REGEX = re.compile(r"\$\{envvar\:([A-Z0-9_]*)\}")


def parse_sweep_id(parts_dict: dict) -> Optional[str]:
    """In place parse sweep path from parts dict.

    ``parts_dict["name"]`` may be ``sweep``, ``project/sweep`` or
    ``entity/project/sweep``. On success the dict is updated so that ``name``
    holds only the sweep id. ``project`` and ``entity`` hold the components
    parsed from the path. A component the path does not carry, or that is empty
    (e.g. ``"/sweep"``), keeps the value already present in ``parts_dict``
    (``None`` if absent).

    Arguments:
        parts_dict (dict): dict(entity=,project=,name=).  Modifies dict inplace.

    Returns:
        None on success, or an error message string if ``name`` is not a
        string or has more than three ``/``-separated components.
    """
    sweep_id = parts_dict.get("name")
    if not isinstance(sweep_id, str):
        return "Expected string sweep_id"

    # Start from the caller-supplied values so components absent from the
    # path are preserved instead of being reset to None.
    entity = parts_dict.get("entity")
    project = parts_dict.get("project")

    sweep_split = sweep_id.split("/")
    if len(sweep_split) == 2:
        split_project, sweep_id = sweep_split
        project = split_project or project
    elif len(sweep_split) == 3:
        split_entity, split_project, sweep_id = sweep_split
        project = split_project or project
        entity = split_entity or entity
    elif len(sweep_split) != 1:
        return (
            "Expected sweep_id in form of sweep, project/sweep, or entity/project/sweep"
        )
    parts_dict.update(dict(name=sweep_id, project=project, entity=entity))
    return None


def sweep_config_err_text_from_jsonschema_violations(violations: List[str]) -> str:
    """Consolidate schema violation strings from wandb/sweeps into a single string.

    The input list is not modified.

    Parameters
    ----------
    violations: list of str
        The warnings to render.

    Returns
    -------
    violation: str
        A two-line header followed by one numbered line per violation, joined
        with newlines.
    """
    violation_base = (
        "Malformed sweep config detected! This may cause your sweep to behave in unexpected ways.\n"
        "To avoid this, please fix the sweep config schema violations below:"
    )
    # Build new lines instead of rewriting ``violations`` in place, so the
    # caller's list is left intact and repeated calls don't re-prefix entries.
    numbered = [
        f"  Violation {i}. {warning}" for i, warning in enumerate(violations, start=1)
    ]
    return "\n".join([violation_base, *numbered])


def handle_sweep_config_violations(warnings: Optional[List[str]]) -> None:
    """Echo sweep config schema violation warnings from Gorilla to the terminal.

    Nothing is printed (and no message is built) when ``warnings`` is empty
    or None.

    Parameters
    ----------
    warnings: list of str, optional
        The warnings to render.
    """
    if not warnings:
        return
    wandb.termwarn(sweep_config_err_text_from_jsonschema_violations(warnings))


def load_sweep_config(sweep_config_path: str) -> Optional[Dict[str, Any]]:
    """Load a sweep yaml from path.

    Errors are reported with ``wandb.termerror`` rather than raised.

    Arguments:
        sweep_config_path: Path of the YAML sweep configuration file.

    Returns:
        The parsed configuration, or None if the file cannot be opened, is
        not valid YAML, or parses to an empty/falsy value.
    """
    try:
        yaml_file = open(sweep_config_path)
    except OSError:
        wandb.termerror(f"Couldn't open sweep file: {sweep_config_path}")
        return None
    # The context manager closes the file on every path, including YAML errors.
    with yaml_file:
        try:
            config: Optional[Dict[str, Any]] = yaml.safe_load(yaml_file)
        except yaml.YAMLError as err:
            wandb.termerror(f"Error in configuration file: {err}")
            return None
    if not config:
        wandb.termerror("Configuration file is empty")
        return None
    return config


def load_launch_sweep_config(config: Optional[str]) -> Any:
    """Load a launch-sweep config through ``util.load_json_yaml_dict``.

    A missing or empty ``config`` yields an empty dict without calling the
    loader.

    Raises:
        LaunchError: if ``util.load_json_yaml_dict`` returns None.
    """
    if config:
        loaded = util.load_json_yaml_dict(config)
        if loaded is None:
            raise LaunchError(f"Could not load config from {config}. Check formatting")
        return loaded
    return {}


def construct_scheduler_args(
    sweep_config: Dict[str, Any],
    queue: str,
    project: str,
    author: Optional[str] = None,
    return_job: bool = False,
) -> Union[List[str], Dict[str, str], None]:
    """Construct sweep scheduler args.

    The sweep config must contain exactly one truthy top-level ``job`` or
    ``image_uri``; otherwise an error is printed and None is returned.

    Returns:
        If ``return_job`` is True, a dict of scheduler-job args
        (``sweep_id`` set to the literal ``"WANDB_SWEEP_ID"``).
        Otherwise, a flat list of CLI flags and values.
        In both forms, ``author`` is included only when truthy.
        None if the config is misconfigured.
    """
    job = sweep_config.get("job")
    image_uri = sweep_config.get("image_uri")

    # Exactly one of job / image_uri may be set; empty strings count as unset.
    if job and image_uri:
        wandb.termerror(
            "Sweep config has both 'job' and 'image_uri' but a launch-sweep can use only one"
        )
        return None
    if not (job or image_uri):
        wandb.termerror(
            "No 'job' nor 'image_uri' top-level key found in sweep config, exactly one is required for a launch-sweep"
        )
        return None

    if return_job:
        scheduler_args: Dict[str, str] = {
            "sweep_id": "WANDB_SWEEP_ID",
            "queue": queue,
            "project": project,
        }
        if job:
            scheduler_args["job"] = job
        else:
            scheduler_args["image_uri"] = image_uri
        if author:
            scheduler_args["author"] = author
        return scheduler_args

    # CLI form: values are repr()-quoted.
    cli_args: List[str] = ["--queue", repr(queue), "--project", repr(project)]
    if author:
        cli_args.extend(["--author", repr(author)])
    if job:
        cli_args.extend(["--job", repr(job)])
    else:
        # NOTE(review): image_uri is passed unquoted, unlike every other
        # value above — confirm whether this asymmetry is intended.
        cli_args.extend(["--image_uri", image_uri])
    return cli_args


def create_sweep_command(command: Optional[List] = None) -> List:
    """Return sweep command, filling in environment variable macros.

    Every ``${envvar:NAME}`` macro inside a string element is replaced by the
    value of environment variable ``NAME``. If ``NAME`` is not set, the macro
    is replaced by the text ``NAME`` itself. Non-string elements are returned
    unchanged.

    Arguments:
        command: Command template; None or an empty list selects
            ``DEFAULT_SWEEP_COMMAND``.

    Returns:
        A new list. Neither ``command`` nor ``DEFAULT_SWEEP_COMMAND`` is
        modified.
    """

    def _expand(match: Any) -> str:
        # Unset variables fall back to the variable name, not the macro text.
        return os.environ.get(match.group(1), match.group(1))

    expanded: List = []
    for chunk in command or DEFAULT_SWEEP_COMMAND:
        if isinstance(chunk, str):
            # re.sub replaces all macros left-to-right using the original
            # match positions, so substituted values are never re-scanned.
            chunk = SWEEP_COMMAND_ENV_VAR_REGEX.sub(_expand, chunk)
        expanded.append(chunk)
    return expanded


def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
    """Create various formats of command arguments for the agent.

    Arguments:
        command: Mapping with an ``"args"`` entry. That entry maps each
            parameter name to a dict holding a ``"value"`` key; None is an
            allowed value.

    Returns:
        Dict mapping a format name (``args``, ``args_no_equals``,
        ``args_no_hyphens``, ``args_no_boolean_flags``, ``args_json``,
        ``args_dict``, ``args_append_hydra``, ``args_override_hydra``) to the
        arguments rendered in that format. Each follows the iteration order
        of ``command["args"]``.

    Raises:
        ValueError: improperly formatted command dict
    """
    if "args" not in command:
        raise ValueError(f'No "args" found in command: {command}')
    # seven different formats of command args
    # (1) standard command line flags (e.g. --foo=bar)
    flags: List[str] = []
    # (2) flags without hyphens (e.g. foo=bar)
    flags_no_hyphens: List[str] = []
    # (3) flags with false booleans omitted  (e.g. --foo)
    flags_no_booleans: List[str] = []
    # (4) flags as a dictionary (used for constructing a json)
    flags_dict: Dict[str, Any] = {}
    # (5) flags without equals (e.g. --foo bar)
    args_no_equals: List[str] = []
    # (6) flags for hydra append config value (e.g. +foo=bar)
    flags_append_hydra: List[str] = []
    # (7) flags for hydra override config value (e.g. ++foo=bar)
    flags_override_hydra: List[str] = []
    for param, config in command["args"].items():
        # allow 'None' as a valid value, but error if no value is found
        try:
            _value: Any = config["value"]
        except KeyError:
            # The KeyError adds nothing beyond this message; suppress the
            # implicit "during handling of the above exception" chain.
            raise ValueError(
                f'No "value" found for command["args"]["{param}"]'
            ) from None

        _flag: str = f"{param}={_value}"
        flags.append("--" + _flag)
        flags_no_hyphens.append(_flag)
        args_no_equals += [f"--{param}", str(_value)]
        flags_append_hydra.append("+" + _flag)
        flags_override_hydra.append("++" + _flag)
        if isinstance(_value, bool):
            # booleans become bare flags: true -> "--param", false -> omitted
            if _value:
                flags_no_booleans.append("--" + param)
        else:
            flags_no_booleans.append("--" + _flag)
        flags_dict[param] = _value
    return {
        "args": flags,
        "args_no_equals": args_no_equals,
        "args_no_hyphens": flags_no_hyphens,
        "args_no_boolean_flags": flags_no_booleans,
        "args_json": [json.dumps(flags_dict)],
        "args_dict": flags_dict,
        "args_append_hydra": flags_append_hydra,
        "args_override_hydra": flags_override_hydra,
    }


def make_launch_sweep_entrypoint(
    args: Dict[str, Any], command: Optional[List[str]]
) -> Tuple[Optional[List[str]], Any]:
    """Split an args macro out of ``command`` to build a launch entrypoint.

    ``command`` first goes through ``create_sweep_command``. Then, for each key
    of ``args`` (as produced by ``create_sweep_command_args``), the first
    entrypoint element equal to ``"${<key>}"`` is removed. The matching value
    of ``args`` is returned separately. Only one such macro is supported: if
    several match, the value for the last matching key is returned.

    Returns:
        ``(None, None)`` when ``command`` is None or empty. Otherwise a tuple
        ``(entrypoint, macro_args)``: ``entrypoint`` is None if nothing
        remains, and ``macro_args`` is ``{}`` if no macro matched.
    """
    if not command:
        return None, None

    entry_point = create_sweep_command(command)
    macro_args: Any = {}
    for macro_name, macro_value in args.items():
        placeholder = "${" + macro_name + "}"
        if placeholder not in entry_point:
            continue
        position = entry_point.index(placeholder)
        # only supports 1 macro per entrypoint; later matches overwrite
        macro_args = macro_value
        entry_point = entry_point[:position] + entry_point[position + 1 :]

    return (entry_point or None), macro_args


def check_job_exists(public_api: "PublicApi", job: Optional[str]) -> bool:
    """Verify, via ``public_api.job``, that ``job`` can be loaded.

    Returns:
        True when ``job`` is empty/None or loads without raising.
        False (after printing the error) when loading raises any exception.
    """
    if job:
        try:
            public_api.job(job)
        except Exception as e:
            wandb.termerror(f"Failed to load job. {e}")
            return False
    return True


def get_previous_args(
    run_spec: Dict[str, Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Parse through previous scheduler run_spec.

    Reads ``run_spec["overrides"]["run_config"]["scheduler"]`` and
    ``run_spec["overrides"]["run_config"]["settings"]``. Missing or null
    levels are treated as empty. Truthy top-level ``resource`` and
    ``resource_args`` values are also copied into the scheduler args.

    Returns:
        ``(scheduler_args, settings)`` as new shallow copies; ``run_spec`` is
        not modified.
    """
    overrides = run_spec.get("overrides") or {}
    run_config = overrides.get("run_config") or {}
    # Shallow-copy so the resource keys added below are not written back into
    # the nested dict owned by run_spec.
    scheduler_args: Dict[str, Any] = dict(run_config.get("scheduler") or {})
    settings: Dict[str, Any] = dict(run_config.get("settings") or {})

    # also pipe through top level resource setup
    for key in ("resource", "resource_args"):
        if run_spec.get(key):
            scheduler_args[key] = run_spec[key]

    return scheduler_args, settings