Spaces:
Running
on
Zero
Running
on
Zero
| import copy | |
| from typing import Type, TypeVar | |
| import omegaconf | |
| from omegaconf import DictConfig, OmegaConf | |
| from pydantic import BaseModel | |
def parse_file_config(path: str) -> DictConfig:
    """Load the config file at ``path`` and insist on a ``DictConfig`` result.

    Args:
        path: Location passed straight to ``OmegaConf.load``.

    Returns:
        The loaded configuration.

    Raises:
        ValueError: If the loaded object is not a ``DictConfig``.
    """
    loaded = OmegaConf.load(path)
    if isinstance(loaded, DictConfig):
        return loaded
    # Any other loaded type is rejected; the message reports what was found.
    raise ValueError(
        f"File paths must parse to DictConfig, but it was: {type(loaded)}"
    )
def recursively_parse_config(cfg: DictConfig) -> list[DictConfig]:
    """Expand the ``config`` include key of ``cfg`` into an ordered config list.

    A config without a ``"config"`` key is returned as-is in a one-element
    list. Otherwise the key is removed from a deep copy of ``cfg`` (the
    caller's object is not modified), every referenced file is loaded with
    ``parse_file_config`` and expanded recursively, and the stripped copy is
    placed last.

    Args:
        cfg: Config that may contain a ``"config"`` entry holding either one
            string path or a list of string paths.

    Returns:
        The configs from the included files, in the order they are listed,
        followed by ``cfg`` itself without its ``"config"`` key.

    Raises:
        ValueError: If ``"config"`` is neither a string nor a list config of
            strings, or if an included file does not load to a ``DictConfig``.
    """
    if "config" not in cfg:
        return [cfg]

    # Work on a copy so deleting the include key does not touch the input.
    own_cfg = copy.deepcopy(cfg)
    config_arg = own_cfg["config"]
    del own_cfg["config"]

    # Single include: its expansion goes first, this config goes last.
    if isinstance(config_arg, str):
        return recursively_parse_config(parse_file_config(config_arg)) + [own_cfg]

    if not isinstance(config_arg, omegaconf.listconfig.ListConfig):
        raise ValueError(
            f'If "config" is specified, it must be either a string path or a list of string paths, it was config={config_arg}'
        )

    # Several includes: expand each in listed order, validating as we go.
    included = []
    for entry in config_arg:
        if not isinstance(entry, str):
            raise ValueError(
                f'If "config" is specified, it must be either a string path or a list of string paths. It was config={config_arg}'
            )
        included.extend(recursively_parse_config(parse_file_config(entry)))
    return included + [own_cfg]
def parse_args_with_default(
    *, default_cfg: DictConfig | None = None, cli_args: DictConfig | None = None
):
    """Merge defaults, included config files and CLI arguments into one container.

    Args:
        default_cfg: Optional base config, merged first.
        cli_args: Config to use as the command-line layer. When ``None`` it is
            read with ``OmegaConf.from_cli()``.

    Returns:
        The merged config converted by ``OmegaConf.to_container`` with
        ``resolve=True`` and ``throw_on_missing=True``.
    """
    if cli_args is None:
        cli_args = OmegaConf.from_cli()
        assert isinstance(
            cli_args, DictConfig
        ), f"CLI Args must be a DictConfig, not {type(cli_args)}"

    layers = recursively_parse_config(cli_args)
    if default_cfg is not None:
        # Defaults go first in the merge order.
        layers = [default_cfg, *layers]
    merged = OmegaConf.merge(*layers)

    # TODO: Change sources to list[tuple[str, float]] so that this special case isn't needed
    # data.sources is copied verbatim from the last layer that defines it,
    # replacing whatever the merge produced for that key.
    sources_layer = next(
        (
            layer
            for layer in reversed(layers)
            if "data" in layer and "sources" in layer["data"]
        ),
        None,
    )
    if sources_layer is not None:
        merged["data"]["sources"] = sources_layer["data"]["sources"]

    return OmegaConf.to_container(merged, resolve=True, throw_on_missing=True)
# Any pydantic model class; ties the return type to the class passed in.
T = TypeVar("T", bound=BaseModel)


def parse_args_to_pydantic_model(
    args_cls: Type[T], cli_args: DictConfig | None = None
) -> T:
    """Build an ``args_cls`` instance from its defaults plus config overrides.

    The defaults come from dumping a no-argument instance of ``args_cls``;
    they are merged with ``cli_args`` via ``parse_args_with_default`` and the
    result is validated back into the model.

    Args:
        args_cls: Pydantic model class; it is instantiated with no arguments
            to obtain the defaults.
        cli_args: Optional config forwarded to ``parse_args_with_default``.

    Returns:
        The validated ``args_cls`` instance.
    """
    defaults = args_cls().model_dump()
    merged = parse_args_with_default(
        default_cfg=OmegaConf.create(defaults), cli_args=cli_args
    )
    return args_cls.model_validate(merged)