File size: 4,685 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import json
import os
from typing import Any, Dict

import yaml

from ..errors import LaunchError

# Environment variable holding the JSON object of file overrides. The payload
# may alternatively be split across FILE_OVERRIDE_ENV_VAR_0, _1, ... which
# FileOverrides.load() concatenates in index order.
FILE_OVERRIDE_ENV_VAR = "WANDB_LAUNCH_FILE_OVERRIDES"


class FileOverrides:
    """Singleton that reads file overrides JSON from environment variables.

    The parsed JSON object is exposed as ``overrides``. ``override_file`` looks
    config file paths up in it and merges the matching value into that file.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            instance = object.__new__(cls)
            instance.overrides = {}
            # Publish the singleton only after load() succeeds. Otherwise a
            # malformed environment would raise once and then leave behind a
            # cached instance with empty overrides on every later call.
            instance.load()
            cls._instance = instance
        return cls._instance

    def load(self) -> None:
        """Load overrides from the environment.

        Reads ``FILE_OVERRIDE_ENV_VAR``. If it is unset, the numbered chunks
        ``FILE_OVERRIDE_ENV_VAR_0``, ``FILE_OVERRIDE_ENV_VAR_1``, ... are
        concatenated in order, stopping at the first missing index. An unset
        or empty value leaves ``self.overrides`` unchanged.

        Raises:
            LaunchError: If the value is not valid JSON or is not a JSON object.
        """
        overrides = os.environ.get(FILE_OVERRIDE_ENV_VAR)
        if overrides is None:
            chunks = []
            idx = 0
            while True:
                key = f"{FILE_OVERRIDE_ENV_VAR}_{idx}"
                if key not in os.environ:
                    break
                chunks.append(os.environ[key])
                idx += 1
            overrides = "".join(chunks)
        if not overrides:
            return
        try:
            contents = json.loads(overrides)
        except json.JSONDecodeError as err:
            raise LaunchError(f"Invalid JSON in {FILE_OVERRIDE_ENV_VAR}") from err
        if not isinstance(contents, dict):
            raise LaunchError(f"Invalid JSON in {FILE_OVERRIDE_ENV_VAR}")
        self.overrides = contents


def config_path_is_valid(path: str) -> None:
    """Validate a config file path.

    A valid path must satisfy all of the following:

    - It is relative, not absolute.
    - None of its components is ``..``, so upward traversal such as
      ``../config.json`` or ``sub/../config.json`` is rejected. File names
      that merely contain two dots, e.g. ``config..v2.json``, are allowed.
    - It refers to an existing regular file, not a directory.
    - It has a supported extension (``.json``, ``.yaml``, or ``.yml``).

    Args:
        path (str): The path to validate.

    Raises:
        LaunchError: If the path is not valid.
    """
    if os.path.isabs(path):
        raise LaunchError(
            f"Invalid config path: {path}. Please provide a relative path."
        )
    # Compare whole path components rather than substrings so names like
    # "config..v2.json" are not rejected. The platform separator is mapped
    # to "/" so the split also works on Windows (where "/" is the altsep).
    components = path.replace(os.sep, "/").split("/")
    if ".." in components:
        raise LaunchError(
            f"Invalid config path: {path}. Please provide a relative path "
            "without any upward path traversal, e.g. `../config.json`."
        )
    path = os.path.normpath(path)
    # isfile (not exists) so a directory named e.g. "conf.json" is rejected
    # here instead of failing later when it is opened for reading.
    if not os.path.isfile(path):
        raise LaunchError(f"Invalid config path: {path}. File does not exist.")
    if not path.endswith((".json", ".yaml", ".yml")):
        raise LaunchError(
            f"Invalid config path: {path}. Only JSON and YAML files are supported."
        )


def override_file(path: str) -> None:
    """Apply any environment-provided overrides to the config file at ``path``.

    Looks ``path`` up in the ``FileOverrides`` singleton. If an override is
    registered, the file is read, the override is merged into its contents
    recursively via ``_update_dict``, and the result is written back in place.
    Does nothing when no override (or a ``None`` override) is registered.

    Args:
        path (str): The config file path, used both as the lookup key and as
            the file to rewrite.

    Raises:
        LaunchError: If the registered override is not a mapping, or if the
            file extension is not supported by the read/write helpers.
    """
    overrides = FileOverrides().overrides.get(path)
    if overrides is None:
        return
    # _update_dict iterates source.items(); fail with a clear LaunchError
    # instead of an AttributeError for e.g. a list or scalar override.
    if not isinstance(overrides, dict):
        raise LaunchError(
            f"Invalid overrides for {path}: expected a JSON object, "
            f"got {type(overrides).__name__}."
        )
    config = _read_config_file(path)
    # yaml.safe_load returns None for an empty document; treat that as an
    # empty mapping so the overrides can still be applied.
    if config is None:
        config = {}
    _update_dict(config, overrides)
    _write_config_file(path, config)


def _write_config_file(path: str, config: Any) -> None:
    """Write a config file to disk.

    Args:
        path (str): The path to the config file.
        config (Any): The contents of the config file as a Python object.

    Raises:
        LaunchError: If the file extension is not supported.
    """
    _, ext = os.path.splitext(path)
    if ext == ".json":
        with open(path, "w") as f:
            json.dump(config, f, indent=2)
    elif ext in [".yaml", ".yml"]:
        with open(path, "w") as f:
            yaml.safe_dump(config, f)
    else:
        raise LaunchError(f"Unsupported file extension: {ext}")


def _read_config_file(path: str) -> Any:
    """Read a config file from disk.

    Args:
        path (str): The path to the config file.

    Returns:
        Any: The contents of the config file as a Python object.
    """
    _, ext = os.path.splitext(path)
    if ext == ".json":
        with open(
            path,
        ) as f:
            return json.load(f)
    elif ext in [".yaml", ".yml"]:
        with open(
            path,
        ) as f:
            return yaml.safe_load(f)
    else:
        raise LaunchError(f"Unsupported file extension: {ext}")


def _update_dict(target: Dict, source: Dict) -> None:
    """Update a dictionary with the contents of another dictionary.

    Args:
        target (Dict): The dictionary to update.
        source (Dict): The dictionary to update from.
    """
    for key, value in source.items():
        if isinstance(value, dict):
            if key not in target:
                target[key] = {}
            _update_dict(target[key], value)
        else:
            target[key] = value