jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
import json
import os
from typing import Any, Dict
import yaml
from ..errors import LaunchError
# Name of the environment variable that carries the JSON-encoded file
# overrides. When it is unset, FileOverrides.load falls back to the numbered
# variables `<name>_0`, `<name>_1`, ... and concatenates their values in order.
FILE_OVERRIDE_ENV_VAR = "WANDB_LAUNCH_FILE_OVERRIDES"
class FileOverrides:
    """Singleton that reads file overrides JSON from environment variables.

    The first successful construction parses the environment and caches the
    resulting instance; later constructions return that same instance.

    Attributes:
        overrides: Mapping parsed from the environment, keyed by config file
            path (see `override_file`). Empty when no overrides are set.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            instance = object.__new__(cls)
            instance.overrides = {}
            # Publish the singleton only after a successful load. Caching it
            # first would make a failed load "stick": every later
            # construction would silently return an instance with no
            # overrides instead of raising again.
            instance.load()
            cls._instance = instance
        return cls._instance

    def load(self) -> None:
        """Load overrides from the environment.

        Reads `FILE_OVERRIDE_ENV_VAR`. If it is unset, the numbered variables
        `<name>_0`, `<name>_1`, ... are concatenated in order until the first
        missing index. An unset or empty value leaves `self.overrides`
        untouched.

        Raises:
            LaunchError: If the value is not valid JSON or does not decode to
                a JSON object.
        """
        raw = os.environ.get(FILE_OVERRIDE_ENV_VAR)
        if raw is None:
            # Large payloads are split across numbered variables.
            chunks = []
            idx = 0
            while f"{FILE_OVERRIDE_ENV_VAR}_{idx}" in os.environ:
                chunks.append(os.environ[f"{FILE_OVERRIDE_ENV_VAR}_{idx}"])
                idx += 1
            raw = "".join(chunks)
        if not raw:
            return
        try:
            contents = json.loads(raw)
        except json.JSONDecodeError as e:
            raise LaunchError(f"Invalid JSON in {FILE_OVERRIDE_ENV_VAR}") from e
        if not isinstance(contents, dict):
            raise LaunchError(f"Invalid JSON in {FILE_OVERRIDE_ENV_VAR}")
        self.overrides = contents
def config_path_is_valid(path: str) -> None:
    """Validate a config file path.

    This function checks if a given config file path is valid. A valid path
    should meet the following criteria:

    - The path must be expressed as a relative path without any upwards path
      traversal, e.g. `../config.json`.
    - The file specified by the path must exist.
    - The file must have a supported extension (`.json`, `.yaml`, or `.yml`).

    Args:
        path (str): The path to validate.

    Raises:
        LaunchError: If the path is not valid.
    """
    if os.path.isabs(path):
        raise LaunchError(
            f"Invalid config path: {path}. Please provide a relative path."
        )
    # Compare whole path components rather than searching for the substring
    # "..", so a file that merely contains two dots in its name (for example
    # `my..config.json`) is not mistaken for upward traversal.
    unified = path.replace(os.altsep, os.sep) if os.altsep else path
    if os.pardir in unified.split(os.sep):
        raise LaunchError(
            f"Invalid config path: {path}. Please provide a relative path "
            "without any upward path traversal, e.g. `../config.json`."
        )
    path = os.path.normpath(path)
    if not os.path.exists(path):
        raise LaunchError(f"Invalid config path: {path}. File does not exist.")
    # Use splitext, like the config reader and writer do, so that a path
    # accepted here is never rejected later as an unsupported extension
    # (a bare dotfile such as `.json` has no extension under splitext).
    _, ext = os.path.splitext(path)
    if ext not in (".json", ".yaml", ".yml"):
        raise LaunchError(
            f"Invalid config path: {path}. Only JSON and YAML files are supported."
        )
def override_file(path: str) -> None:
    """Check for file overrides in the environment and apply them if found.

    Looks up `path` in the overrides held by `FileOverrides`. When an entry
    exists, the config file is read, the overrides are merged into it
    recursively, and the result is written back to the same path.

    Args:
        path (str): The config file path, used both as the lookup key and as
            the file to rewrite.
    """
    overrides = FileOverrides().overrides.get(path)
    if overrides is None:
        # No entry for this file (or an explicit null): nothing to apply.
        return
    config = _read_config_file(path)
    if config is None:
        # An empty YAML document loads as None; start from an empty mapping
        # so the overrides can still be merged in instead of crashing.
        config = {}
    _update_dict(config, overrides)
    _write_config_file(path, config)
def _write_config_file(path: str, config: Any) -> None:
"""Write a config file to disk.
Args:
path (str): The path to the config file.
config (Any): The contents of the config file as a Python object.
Raises:
LaunchError: If the file extension is not supported.
"""
_, ext = os.path.splitext(path)
if ext == ".json":
with open(path, "w") as f:
json.dump(config, f, indent=2)
elif ext in [".yaml", ".yml"]:
with open(path, "w") as f:
yaml.safe_dump(config, f)
else:
raise LaunchError(f"Unsupported file extension: {ext}")
def _read_config_file(path: str) -> Any:
"""Read a config file from disk.
Args:
path (str): The path to the config file.
Returns:
Any: The contents of the config file as a Python object.
"""
_, ext = os.path.splitext(path)
if ext == ".json":
with open(
path,
) as f:
return json.load(f)
elif ext in [".yaml", ".yml"]:
with open(
path,
) as f:
return yaml.safe_load(f)
else:
raise LaunchError(f"Unsupported file extension: {ext}")
def _update_dict(target: Dict, source: Dict) -> None:
"""Update a dictionary with the contents of another dictionary.
Args:
target (Dict): The dictionary to update.
source (Dict): The dictionary to update from.
"""
for key, value in source.items():
if isinstance(value, dict):
if key not in target:
target[key] = {}
_update_dict(target[key], value)
else:
target[key] = value