Spaces:
Running
on
Zero
Running
on
Zero
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates | |
# // | |
# // Licensed under the Apache License, Version 2.0 (the "License"); | |
# // you may not use this file except in compliance with the License. | |
# // You may obtain a copy of the License at | |
# // | |
# // http://www.apache.org/licenses/LICENSE-2.0 | |
# // | |
# // Unless required by applicable law or agreed to in writing, software | |
# // distributed under the License is distributed on an "AS IS" BASIS, | |
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# // See the License for the specific language governing permissions and | |
# // limitations under the License. | |
""" | |
Configuration utility functions | |
""" | |
import importlib | |
from typing import Any, Callable, List, Union | |
from omegaconf import DictConfig, ListConfig, OmegaConf | |
# Expose Python's built-in ``eval`` to configs via the ``${eval:'...'}`` syntax.
# SECURITY NOTE(review): this evaluates arbitrary Python taken from any loaded
# YAML file or dotlist override -- only load trusted configs.
# ``replace=True`` keeps the registration idempotent: importing this module a
# second time (or reloading it) would otherwise raise ValueError because the
# resolver name is already registered.
OmegaConf.register_new_resolver("eval", eval, replace=True)
def load_config(
    path: str, argv: Union[List[str], None] = None
) -> Union[DictConfig, ListConfig]:
    """
    Load a configuration file and resolve ``__inherit__`` directives.

    Args:
        path: Path of the configuration file, passed to ``OmegaConf.load``.
        argv: Optional dotlist overrides (e.g. ``["model.lr=0.1"]``). They are
            merged on top of the loaded file *before* inheritance is resolved,
            so they also take precedence over values inherited from parents.

    Returns:
        The loaded configuration with every ``__inherit__`` key (at any depth)
        resolved.
    """
    config = OmegaConf.load(path)
    if argv is not None:
        config_argv = OmegaConf.from_dotlist(argv)
        config = OmegaConf.merge(config, config_argv)
    # Inheritance may appear on any nested node, so walk the whole tree.
    config = resolve_recursive(config, resolve_inheritance)
    return config
def resolve_recursive(
    config: Any,
    resolver: Callable[[Union[DictConfig, ListConfig]], Union[DictConfig, ListConfig]],
) -> Any:
    """
    Apply ``resolver`` to ``config``, then recurse into its container children.

    The resolver runs on a node before that node's children are visited, so any
    children the resolver introduces are visited too. Every nested
    DictConfig/ListConfig child is replaced in place by its resolved version.
    A resolver result that is not a DictConfig/ListConfig is returned as-is.
    """
    node = resolver(config)

    # Pick the slots to visit: mapping keys or list indices.
    if isinstance(node, DictConfig):
        slots = node.keys()
    elif isinstance(node, ListConfig):
        slots = range(len(node))
    else:
        return node

    for slot in slots:
        child = node.get(slot)
        if isinstance(child, (DictConfig, ListConfig)):
            node[slot] = resolve_recursive(child, resolver)
    return node
def resolve_inheritance(config: Union[DictConfig, ListConfig]) -> Any:
    """
    Resolve an ``__inherit__`` directive on a single config node.

    If ``config`` is a DictConfig containing
    ``__inherit__: path/to/parent.yaml`` (or a list of such paths):

    - the key is removed from ``config``;
    - each parent is loaded with ``load_config``, which resolves the parents'
      own inheritance;
    - parents are merged in list order, so later parents override earlier ones;
    - the remaining keys of ``config`` are merged on top, so the child wins.

    Any other node, or a DictConfig without a truthy ``__inherit__``, is
    returned unchanged (apart from the key being popped). Nested nodes are not
    visited here; ``resolve_recursive`` does that.

    NOTE: parent paths are passed to ``load_config`` as-is, so relative paths
    resolve against the current working directory, not the including file.

    Raises:
        TypeError: if an ``__inherit__`` entry is not a string path.
    """
    if not isinstance(config, DictConfig):
        return config

    inherit = config.pop("__inherit__", None)
    if not inherit:
        return config

    # A single path is shorthand for a one-element list.
    inherit_list = inherit if isinstance(inherit, ListConfig) else [inherit]
    parent_config = None
    for parent_path in inherit_list:
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not isinstance(parent_path, str):
            raise TypeError(
                "__inherit__ entries must be string paths, "
                f"got {type(parent_path).__name__}: {parent_path!r}"
            )
        parent = load_config(parent_path)
        parent_config = (
            parent if parent_config is None else OmegaConf.merge(parent_config, parent)
        )

    # Child values override inherited ones; an otherwise-empty child is just
    # the merged parents.
    if len(config) > 0:
        return OmegaConf.merge(parent_config, config)
    return parent_config
def import_item(path: str, name: str) -> Any:
    """
    Import the module at dotted ``path`` and return its attribute ``name``.

    Example: import_item("path.to.file", "MyClass") -> MyClass
    """
    module = importlib.import_module(path)
    item = getattr(module, name)
    return item
def create_object(config: DictConfig) -> Any:
    """
    Instantiate the callable described by ``config.__object__``.

    The config is expected to contain:
        __object__:
            path: path.to.module
            name: MyClass
            args: as_config | as_params (defaults to as_config)

    With ``as_config`` the target is called with the whole config as its only
    argument. With ``as_params`` the config is converted to a plain Python
    object, its ``__object__`` entry is removed, and the remaining entries are
    passed as keyword arguments.

    Raises:
        NotImplementedError: if ``args`` is neither ``as_config`` nor
            ``as_params``.
    """
    spec = config.__object__
    target = import_item(path=spec.path, name=spec.name)
    mode = spec.get("args", "as_config")

    if mode == "as_config":
        return target(config)
    if mode == "as_params":
        kwargs = OmegaConf.to_object(config)
        kwargs.pop("__object__")
        return target(**kwargs)
    raise NotImplementedError(f"Unknown args type: {mode}")