|
from __future__ import annotations |
|
|
|
import os |
|
import secrets |
|
import socket |
|
import string |
|
|
|
import wandb |
|
|
|
from . import config |
|
from . import files as sm_files |
|
|
|
|
|
def set_run_id(run_settings: wandb.Settings) -> bool:
    """Set a run ID and group when using SageMaker.

    The run group is taken from the ``TRAINING_JOB_NAME`` environment
    variable. The run ID has the form ``"<group>-<random>-<host>"``, where:

    - ``<random>`` is 6 characters drawn from lowercase ASCII letters and
      digits.
    - ``<host>`` is ``CURRENT_HOST``, or ``socket.gethostname()`` when that
      variable is unset or empty.

    Nothing is changed when ``WANDB_RUN_ID`` is set and non-empty, or when
    ``TRAINING_JOB_NAME`` is unset or empty.

    Args:
        run_settings: Settings object whose ``run_id`` and ``run_group``
            attributes are assigned when this function returns ``True``.

    Returns:
        Whether the ID and group were updated.
    """
    # Respect a run ID that was configured explicitly through the environment.
    if os.getenv("WANDB_RUN_ID"):
        return False

    run_group = os.getenv("TRAINING_JOB_NAME")
    if not run_group:
        return False

    # Random segment so repeated invocations with the same job name and host
    # still produce distinct run IDs.
    alphabet = string.ascii_lowercase + string.digits
    suffix = "".join(secrets.choice(alphabet) for _ in range(6))

    # Treat an empty CURRENT_HOST like an unset one (consistent with the
    # checks above) and only look up the hostname when it is actually needed.
    host = os.getenv("CURRENT_HOST") or socket.gethostname()

    run_settings.run_id = f"{run_group}-{suffix}-{host}"
    run_settings.run_group = run_group
    return True
|
|
|
|
|
def set_global_settings(settings: wandb.Settings) -> None:
    """Apply W&B settings found in the SageMaker environment to ``settings``.

    Two sources are consulted, in this order:

    1. The ``KEY=VALUE`` pairs returned by ``parse_sm_secrets()``. If there
       are any, they are passed to ``settings.update_from_env_vars``.
    2. The ``"wandb_api_key"`` entry of ``config.parse_sm_config()``. If it
       is truthy, it is assigned to ``settings.api_key``. Because this step
       runs last, it takes precedence over any API key that step 1 may
       have set.
    """
    secrets_env = parse_sm_secrets()
    if secrets_env:
        settings.update_from_env_vars(secrets_env)

    # NOTE(review): parse_sm_config() is assumed to return a dict-like
    # mapping (only .get() is used here) — confirm against config module.
    hyperparams = config.parse_sm_config()
    wandb_key = hyperparams.get("wandb_api_key")
    if wandb_key:
        settings.api_key = wandb_key
|
|
|
|
|
def parse_sm_secrets() -> dict[str, str]: |
|
"""We read our api_key from secrets.env in SageMaker.""" |
|
env_dict = dict() |
|
|
|
if os.path.exists(sm_files.SM_SECRETS): |
|
for line in open(sm_files.SM_SECRETS): |
|
key, val = line.strip().split("=", 1) |
|
env_dict[key] = val |
|
return env_dict |
|
|