File size: 4,685 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import json
import os
from typing import Any, Dict
import yaml
from ..errors import LaunchError
# Environment variable that carries the file-overrides JSON document. When it is
# unset, FileOverrides.load instead concatenates the numbered chunks
# `WANDB_LAUNCH_FILE_OVERRIDES_0`, `_1`, ... up to the first missing index.
FILE_OVERRIDE_ENV_VAR = "WANDB_LAUNCH_FILE_OVERRIDES"
class FileOverrides:
    """Singleton that reads file overrides JSON from environment variables.

    The overrides are parsed once, when the first instance is created, and
    are exposed as ``overrides``: a dict keyed by config file path (this is
    how ``override_file`` looks them up).
    """

    _instance = None  # the shared instance, created lazily by __new__

    def __new__(cls) -> "FileOverrides":
        if cls._instance is None:
            instance = object.__new__(cls)
            instance.overrides = {}
            # Publish the instance only after load() succeeds. Caching it
            # first would mean a parse error is raised once and every later
            # call silently gets a singleton with empty overrides.
            instance.load()
            cls._instance = instance
        return cls._instance

    def load(self) -> None:
        """Load overrides from environment variables.

        The JSON document is read from ``FILE_OVERRIDE_ENV_VAR``. If that
        variable is unset, the document is reassembled by concatenating
        ``FILE_OVERRIDE_ENV_VAR_0``, ``FILE_OVERRIDE_ENV_VAR_1``, ... up to
        the first missing index (presumably so large payloads can be split
        across several variables -- TODO confirm). An empty or absent value
        leaves ``self.overrides`` unchanged.

        Raises:
            LaunchError: If the value is not valid JSON or is not a JSON object.
        """
        overrides = os.environ.get(FILE_OVERRIDE_ENV_VAR)
        if overrides is None:
            chunks = []
            idx = 0
            while f"{FILE_OVERRIDE_ENV_VAR}_{idx}" in os.environ:
                chunks.append(os.environ[f"{FILE_OVERRIDE_ENV_VAR}_{idx}"])
                idx += 1
            if chunks:
                overrides = "".join(chunks)
        if not overrides:
            return
        try:
            contents = json.loads(overrides)
        except json.JSONDecodeError as e:
            raise LaunchError(f"Invalid JSON in {FILE_OVERRIDE_ENV_VAR}") from e
        if not isinstance(contents, dict):
            raise LaunchError(f"Invalid JSON in {FILE_OVERRIDE_ENV_VAR}")
        self.overrides = contents
def config_path_is_valid(path: str) -> None:
    """Validate a config file path.

    A valid path must:

    - be relative and contain no upward path traversal (e.g. ``../config.json``);
    - point to an existing regular file;
    - have a supported extension (``.json``, ``.yaml`` or ``.yml``).

    Args:
        path (str): The path to validate.

    Raises:
        LaunchError: If the path is not valid.
    """
    if os.path.isabs(path):
        raise LaunchError(
            f"Invalid config path: {path}. Please provide a relative path."
        )
    # Inspect individual path segments (treating "/" and "\" alike) so that only
    # real traversal components are rejected. A plain substring test would also
    # reject legitimate names such as `config..json`. Any segment made solely of
    # two or more dots is refused, to stay conservative across platforms.
    segments = path.replace("\\", "/").split("/")
    if any(len(seg) >= 2 and not seg.strip(".") for seg in segments):
        raise LaunchError(
            f"Invalid config path: {path}. Please provide a relative path "
            "without any upward path traversal, e.g. `../config.json`."
        )
    path = os.path.normpath(path)
    # isfile (not exists): a directory named `x.json` is not a usable config.
    if not os.path.isfile(path):
        raise LaunchError(f"Invalid config path: {path}. File does not exist.")
    if not path.endswith((".json", ".yaml", ".yml")):
        raise LaunchError(
            f"Invalid config path: {path}. Only JSON and YAML files are supported."
        )
def override_file(path: str) -> None:
    """Apply any environment-provided overrides for ``path`` to the file on disk.

    Looks ``path`` up in the ``FileOverrides`` singleton. If an override mapping
    exists for it, the file is read and deep-merged with the overrides via
    ``_update_dict``, then written back in place. Does nothing when there is no
    override entry (or the entry is ``None``).

    Args:
        path (str): Config file path, used verbatim as the lookup key.

    Raises:
        LaunchError: If the overrides or the file contents are not mappings,
            or if the file extension is not supported.
    """
    overrides = FileOverrides().overrides.get(path)
    if overrides is None:
        return
    if not isinstance(overrides, dict):
        raise LaunchError(f"Invalid overrides for {path}: expected a mapping.")
    config = _read_config_file(path)
    if config is None:
        # An empty YAML document (or a JSON `null`) loads as None; treat it as
        # an empty mapping so the overrides can still be applied.
        config = {}
    if not isinstance(config, dict):
        raise LaunchError(
            f"Cannot apply overrides to {path}: file must contain a mapping."
        )
    _update_dict(config, overrides)
    _write_config_file(path, config)
def _write_config_file(path: str, config: Any) -> None:
"""Write a config file to disk.
Args:
path (str): The path to the config file.
config (Any): The contents of the config file as a Python object.
Raises:
LaunchError: If the file extension is not supported.
"""
_, ext = os.path.splitext(path)
if ext == ".json":
with open(path, "w") as f:
json.dump(config, f, indent=2)
elif ext in [".yaml", ".yml"]:
with open(path, "w") as f:
yaml.safe_dump(config, f)
else:
raise LaunchError(f"Unsupported file extension: {ext}")
def _read_config_file(path: str) -> Any:
"""Read a config file from disk.
Args:
path (str): The path to the config file.
Returns:
Any: The contents of the config file as a Python object.
"""
_, ext = os.path.splitext(path)
if ext == ".json":
with open(
path,
) as f:
return json.load(f)
elif ext in [".yaml", ".yml"]:
with open(
path,
) as f:
return yaml.safe_load(f)
else:
raise LaunchError(f"Unsupported file extension: {ext}")
def _update_dict(target: Dict, source: Dict) -> None:
"""Update a dictionary with the contents of another dictionary.
Args:
target (Dict): The dictionary to update.
source (Dict): The dictionary to update from.
"""
for key, value in source.items():
if isinstance(value, dict):
if key not in target:
target[key] = {}
_update_dict(target[key], value)
else:
target[key] = value
|