|
import json |
|
import os |
|
from typing import Any, Dict |
|
|
|
import yaml |
|
|
|
from ..errors import LaunchError |
|
|
|
FILE_OVERRIDE_ENV_VAR = "WANDB_LAUNCH_FILE_OVERRIDES" |
|
|
|
|
|
class FileOverrides: |
|
"""Singleton that read file overrides json from environment variables.""" |
|
|
|
_instance = None |
|
|
|
def __new__(cls): |
|
if cls._instance is None: |
|
cls._instance = object.__new__(cls) |
|
cls._instance.overrides = {} |
|
cls._instance.load() |
|
return cls._instance |
|
|
|
def load(self) -> None: |
|
"""Load overrides from an environment variable.""" |
|
overrides = os.environ.get(FILE_OVERRIDE_ENV_VAR) |
|
if overrides is None: |
|
if f"{FILE_OVERRIDE_ENV_VAR}_0" in os.environ: |
|
overrides = "" |
|
idx = 0 |
|
while f"{FILE_OVERRIDE_ENV_VAR}_{idx}" in os.environ: |
|
overrides += os.environ[f"{FILE_OVERRIDE_ENV_VAR}_{idx}"] |
|
idx += 1 |
|
if overrides: |
|
try: |
|
contents = json.loads(overrides) |
|
if not isinstance(contents, dict): |
|
raise LaunchError(f"Invalid JSON in {FILE_OVERRIDE_ENV_VAR}") |
|
self.overrides = contents |
|
except json.JSONDecodeError: |
|
raise LaunchError(f"Invalid JSON in {FILE_OVERRIDE_ENV_VAR}") |
|
|
|
|
|
def config_path_is_valid(path: str) -> None: |
|
"""Validate a config file path. |
|
|
|
This function checks if a given config file path is valid. A valid path |
|
should meet the following criteria: |
|
|
|
- The path must be expressed as a relative path without any upwards path |
|
traversal, e.g. `../config.json`. |
|
- The file specified by the path must exist. |
|
- The file must have a supported extension (`.json`, `.yaml`, or `.yml`). |
|
|
|
Args: |
|
path (str): The path to validate. |
|
|
|
Raises: |
|
LaunchError: If the path is not valid. |
|
""" |
|
if os.path.isabs(path): |
|
raise LaunchError( |
|
f"Invalid config path: {path}. Please provide a relative path." |
|
) |
|
if ".." in path: |
|
raise LaunchError( |
|
f"Invalid config path: {path}. Please provide a relative path " |
|
"without any upward path traversal, e.g. `../config.json`." |
|
) |
|
path = os.path.normpath(path) |
|
if not os.path.exists(path): |
|
raise LaunchError(f"Invalid config path: {path}. File does not exist.") |
|
if not any(path.endswith(ext) for ext in [".json", ".yaml", ".yml"]): |
|
raise LaunchError( |
|
f"Invalid config path: {path}. Only JSON and YAML files are supported." |
|
) |
|
|
|
|
|
def override_file(path: str) -> None: |
|
"""Check for file overrides in the environment and apply them if found.""" |
|
file_overrides = FileOverrides() |
|
if path in file_overrides.overrides: |
|
overrides = file_overrides.overrides.get(path) |
|
if overrides is not None: |
|
config = _read_config_file(path) |
|
_update_dict(config, overrides) |
|
_write_config_file(path, config) |
|
|
|
|
|
def _write_config_file(path: str, config: Any) -> None: |
|
"""Write a config file to disk. |
|
|
|
Args: |
|
path (str): The path to the config file. |
|
config (Any): The contents of the config file as a Python object. |
|
|
|
Raises: |
|
LaunchError: If the file extension is not supported. |
|
""" |
|
_, ext = os.path.splitext(path) |
|
if ext == ".json": |
|
with open(path, "w") as f: |
|
json.dump(config, f, indent=2) |
|
elif ext in [".yaml", ".yml"]: |
|
with open(path, "w") as f: |
|
yaml.safe_dump(config, f) |
|
else: |
|
raise LaunchError(f"Unsupported file extension: {ext}") |
|
|
|
|
|
def _read_config_file(path: str) -> Any: |
|
"""Read a config file from disk. |
|
|
|
Args: |
|
path (str): The path to the config file. |
|
|
|
Returns: |
|
Any: The contents of the config file as a Python object. |
|
""" |
|
_, ext = os.path.splitext(path) |
|
if ext == ".json": |
|
with open( |
|
path, |
|
) as f: |
|
return json.load(f) |
|
elif ext in [".yaml", ".yml"]: |
|
with open( |
|
path, |
|
) as f: |
|
return yaml.safe_load(f) |
|
else: |
|
raise LaunchError(f"Unsupported file extension: {ext}") |
|
|
|
|
|
def _update_dict(target: Dict, source: Dict) -> None: |
|
"""Update a dictionary with the contents of another dictionary. |
|
|
|
Args: |
|
target (Dict): The dictionary to update. |
|
source (Dict): The dictionary to update from. |
|
""" |
|
for key, value in source.items(): |
|
if isinstance(value, dict): |
|
if key not in target: |
|
target[key] = {} |
|
_update_dict(target[key], value) |
|
else: |
|
target[key] = value |
|
|