import copy
from enum import Enum
from typing import (
    Any,
    Dict,
    ItemsView,
    Iterable,
    Iterator,
    KeysView,
    List,
    MutableMapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
)

from ._utils import (
    _DEFAULT_MARKER_,
    ValueKind,
    _get_value,
    _is_interpolation,
    _is_missing_literal,
    _is_missing_value,
    _is_none,
    _resolve_optional,
    _valid_dict_key_annotation_type,
    format_and_raise,
    get_structured_config_data,
    get_structured_config_init_field_names,
    get_type_of,
    get_value_kind,
    is_container_annotation,
    is_dict,
    is_primitive_dict,
    is_structured_config,
    is_structured_config_frozen,
    type_str,
)
from .base import Box, Container, ContainerMetadata, DictKeyType, Node
from .basecontainer import BaseContainer
from .errors import (
    ConfigAttributeError,
    ConfigKeyError,
    ConfigTypeError,
    InterpolationResolutionError,
    KeyValidationError,
    MissingMandatoryValue,
    OmegaConfBaseException,
    ReadonlyConfigError,
    ValidationError,
)
from .nodes import EnumNode, ValueNode


class DictConfig(BaseContainer, MutableMapping[Any, Any]):
    """
    Mapping-style config container with both item and attribute access.

    The payload lives in ``self.__dict__["_content"]`` and is one of:

    * a ``dict`` mapping normalized keys to ``Node`` objects,
    * ``None`` (the config represents None),
    * a string (an interpolation, or the missing marker), as assigned
      by ``_set_value_impl``.

    Instance state is read and written through ``self.__dict__``
    directly because ``__setattr__`` / ``__getattr__`` are overridden to
    operate on the config's keys.
    """

    _metadata: ContainerMetadata
    _content: Union[Dict[DictKeyType, Node], None, str]

    def __init__(
        self,
        content: Union[Dict[DictKeyType, Any], "DictConfig", Any],
        key: Any = None,
        parent: Optional[Box] = None,
        ref_type: Union[Any, Type[Any]] = Any,
        key_type: Union[Any, Type[Any]] = Any,
        element_type: Union[Any, Type[Any]] = Any,
        is_optional: bool = True,
        flags: Optional[Dict[str, bool]] = None,
    ) -> None:
        """
        Create a DictConfig from ``content``.

        :param content: dict, structured config, DictConfig, or any other
            value accepted by ``_set_value_impl``
        :param key: key of this node, stored in the metadata
        :param parent: parent node, forwarded to the base class
        :param ref_type: reference type stored in the metadata
        :param key_type: type that keys are validated against
        :param element_type: element type stored in the metadata
        :param is_optional: stored as ``optional`` in the metadata
        :param flags: node flags; when None and ``content`` is a
            DictConfig, the flags of ``content`` are used

        Any exception raised during construction is passed to
        ``format_and_raise``.
        """
        try:
            if isinstance(content, DictConfig):
                if flags is None:
                    flags = content._metadata.flags
            super().__init__(
                parent=parent,
                metadata=ContainerMetadata(
                    key=key,
                    optional=is_optional,
                    ref_type=ref_type,
                    object_type=dict,
                    key_type=key_type,
                    element_type=element_type,
                    flags=flags,
                ),
            )
            if not _valid_dict_key_annotation_type(key_type):
                raise KeyValidationError(f"Unsupported key type {key_type}")

            if is_structured_config(content) or is_structured_config(ref_type):
                self._set_value(content, flags=flags)
                # A frozen structured config makes the node read-only.
                if is_structured_config_frozen(content) or is_structured_config_frozen(
                    ref_type
                ):
                    self._set_flag("readonly", True)

            else:
                if isinstance(content, DictConfig):
                    # Start from a copy of the source metadata, then apply
                    # the values passed to this constructor on top of it.
                    metadata = copy.deepcopy(content._metadata)
                    metadata.key = key
                    metadata.ref_type = ref_type
                    metadata.optional = is_optional
                    metadata.element_type = element_type
                    metadata.key_type = key_type
                    self.__dict__["_metadata"] = metadata
                self._set_value(content, flags=flags)
        except Exception as ex:
            format_and_raise(node=None, key=key, value=None, cause=ex, msg=str(ex))

    def __deepcopy__(self, memo: Dict[int, Any]) -> "DictConfig":
        """
        Return a deep copy of this node.

        Metadata, flags cache and every child node are deep-copied. The
        parent reference is shared with the original, not copied.
        """
        res = DictConfig(None)
        res.__dict__["_metadata"] = copy.deepcopy(self.__dict__["_metadata"], memo=memo)
        res.__dict__["_flags_cache"] = copy.deepcopy(
            self.__dict__["_flags_cache"], memo=memo
        )

        src_content = self.__dict__["_content"]
        if isinstance(src_content, dict):
            content_copy = {}
            for k, v in src_content.items():
                old_parent = v.__dict__["_parent"]
                try:
                    # Detach the child while copying so the deep copy does
                    # not follow the parent link; it is restored in finally.
                    v.__dict__["_parent"] = None
                    vc = copy.deepcopy(v, memo=memo)
                    vc.__dict__["_parent"] = res
                    content_copy[k] = vc
                finally:
                    v.__dict__["_parent"] = old_parent
        else:
            # None and strings can be assigned as is
            content_copy = src_content

        res.__dict__["_content"] = content_copy
        # parent is retained, but not copied
        res.__dict__["_parent"] = self.__dict__["_parent"]
        return res

    def copy(self) -> "DictConfig":
        """Return ``copy.copy(self)``."""
        return copy.copy(self)

    def _is_typed(self) -> bool:
        """Return True if the object type is set and is not a dict type."""
        return self._metadata.object_type not in (Any, None) and not is_dict(
            self._metadata.object_type
        )

    def _validate_get(self, key: Any, value: Any = None) -> None:
        """
        Reject access to a key that is absent from the content when this
        node is typed or has the ``struct`` flag set.

        A typed node whose own ``struct`` flag is explicitly False is
        exempt. Errors are raised as ConfigAttributeError through
        ``_format_and_raise``.
        """
        is_typed = self._is_typed()

        is_struct = self._get_flag("struct") is True
        if key not in self.__dict__["_content"]:
            if is_typed:
                # do not raise an exception if struct is explicitly set to False
                if self._get_node_flag("struct") is False:
                    return
            if is_typed or is_struct:
                if is_typed:
                    assert self._metadata.object_type not in (dict, None)
                    msg = f"Key '{key}' not in '{self._metadata.object_type.__name__}'"
                else:
                    msg = f"Key '{key}' is not in struct"
                self._format_and_raise(
                    key=key, value=value, cause=ConfigAttributeError(msg)
                )

    def _validate_set(self, key: Any, value: Any) -> None:
        """
        Validate assigning ``value`` to ``key`` (or to this node itself
        when ``key`` is None).

        Interpolations and missing values are accepted without further
        checks; None values are only checked for optionality. Otherwise
        the value's type is compared with the ref_type of the target node
        when that target is a DictConfig with a ref_type other than
        ``Any`` / ``dict``.

        :raises ValidationError: when the value type is incompatible with
            the target's ref_type
        """
        from omegaconf import OmegaConf

        vk = get_value_kind(value)
        if vk == ValueKind.INTERPOLATION:
            return
        if _is_none(value):
            self._validate_non_optional(key, value)
            return
        if vk == ValueKind.MANDATORY_MISSING or value is None:
            return

        target = self._get_node(key) if key is not None else self

        target_has_ref_type = isinstance(
            target, DictConfig
        ) and target._metadata.ref_type not in (Any, dict)
        is_valid_target = target is None or not target_has_ref_type

        if is_valid_target:
            return

        assert isinstance(target, Node)

        target_type = target._metadata.ref_type
        value_type = OmegaConf.get_type(value)

        if is_dict(value_type) and is_dict(target_type):
            return
        if is_container_annotation(target_type) and not is_container_annotation(
            value_type
        ):
            raise ValidationError(
                f"Cannot assign {type_str(value_type)} to {type_str(target_type)}"
            )

        if target_type is not None and value_type is not None:
            # For subscripted annotations compare against __origin__ when
            # present, since issubclass needs a plain class.
            origin = getattr(target_type, "__origin__", target_type)
            if not issubclass(value_type, origin):
                self._raise_invalid_value(value, value_type, target_type)

    def _validate_merge(self, value: Any) -> None:
        """
        Validate that ``value`` (the merge source) may be merged into self.

        :raises ValidationError: when self has a structured config type and
            the source type is neither a dict type nor a subclass of it
        """
        from omegaconf import OmegaConf

        dest = self
        src = value

        self._validate_non_optional(None, src)

        dest_obj_type = OmegaConf.get_type(dest)
        src_obj_type = OmegaConf.get_type(src)

        if dest._is_missing() and src._metadata.object_type not in (dict, None):
            self._validate_set(key=None, value=_get_value(src))

        if src._is_missing():
            return

        validation_error = (
            dest_obj_type is not None
            and src_obj_type is not None
            and is_structured_config(dest_obj_type)
            and not src._is_none()
            and not is_dict(src_obj_type)
            and not issubclass(src_obj_type, dest_obj_type)
        )
        if validation_error:
            msg = (
                f"Merge error: {type_str(src_obj_type)} is not a "
                f"subclass of {type_str(dest_obj_type)}. value: {src}"
            )
            raise ValidationError(msg)

    def _validate_non_optional(self, key: Optional[DictKeyType], value: Any) -> None:
        """
        Raise a ValidationError (via ``_format_and_raise``) if ``value`` is
        None and the target does not accept None.

        The target is the existing child at ``key``; if there is no such
        child, optionality is derived from the element type. With
        ``key`` None the target is this node itself.
        """
        if _is_none(value, resolve=True, throw_on_resolution_failure=False):

            if key is not None:
                child = self._get_node(key)
                if child is not None:
                    assert isinstance(child, Node)
                    field_is_optional = child._is_optional()
                else:
                    field_is_optional, _ = _resolve_optional(
                        self._metadata.element_type
                    )
            else:
                field_is_optional = self._is_optional()

            if not field_is_optional:
                self._format_and_raise(
                    key=key,
                    value=value,
                    cause=ValidationError("field '$FULL_KEY' is not Optional"),
                )

    def _raise_invalid_value(
        self, value: Any, value_type: Any, target_type: Any
    ) -> None:
        """Raise a ValidationError describing a value/target type mismatch."""
        assert value_type is not None
        assert target_type is not None
        msg = (
            f"Invalid type assigned: {type_str(value_type)} is not a "
            f"subclass of {type_str(target_type)}. value: {value}"
        )
        raise ValidationError(msg)

    def _validate_and_normalize_key(self, key: Any) -> DictKeyType:
        """Validate and normalize ``key`` against this node's key type."""
        return self._s_validate_and_normalize_key(self._metadata.key_type, key)

    def _s_validate_and_normalize_key(self, key_type: Any, key: Any) -> DictKeyType:
        """
        Validate ``key`` against ``key_type`` and return its normalized form.

        * ``Any``: the key must be an instance of one of the types in
          ``DictKeyType``; it is returned unchanged.
        * ``bool``: 0 and 1 are converted to ``bool``.
        * ``str`` / ``bytes`` / ``int`` / ``float`` / ``bool``: the key must
          be an instance of that type.
        * ``Enum`` subclass: the key is converted through
          ``EnumNode.validate_and_convert_to_enum``.

        :raises KeyValidationError: when the key is incompatible
        """
        if key_type is Any:
            for t in DictKeyType.__args__:  # type: ignore
                if isinstance(key, t):
                    return key  # type: ignore
            raise KeyValidationError("Incompatible key type '$KEY_TYPE'")
        elif key_type is bool and key in [0, 1]:
            # Python treats True as 1 and False as 0 when used as dict keys
            # assert hash(0) == hash(False)
            # assert hash(1) == hash(True)
            return bool(key)
        elif key_type in (str, bytes, int, float, bool):  # primitive type
            if not isinstance(key, key_type):
                raise KeyValidationError(
                    f"Key $KEY ($KEY_TYPE) is incompatible with ({key_type.__name__})"
                )

            return key  # type: ignore
        elif issubclass(key_type, Enum):
            try:
                return EnumNode.validate_and_convert_to_enum(key_type, key)
            except ValidationError:
                # Iterating __members__ yields the member names.
                valid = ", ".join(key_type.__members__)
                raise KeyValidationError(
                    f"Key '$KEY' is incompatible with the enum type '{key_type.__name__}', valid: [{valid}]"
                )
        else:
            assert False, f"Unsupported key type {key_type}"

    def __setitem__(self, key: DictKeyType, value: Any) -> None:
        """
        Allow map style assignment.

        An AttributeError raised while setting is reported as a
        ConfigKeyError; any other exception goes through
        ``_format_and_raise`` unchanged in type.
        """
        try:
            self.__set_impl(key=key, value=value)
        except AttributeError as e:
            self._format_and_raise(
                key=key, value=value, type_override=ConfigKeyError, cause=e
            )
        except Exception as e:
            self._format_and_raise(key=key, value=value, cause=e)

    def __set_impl(self, key: DictKeyType, value: Any) -> None:
        """Normalize ``key`` and delegate to ``_set_item_impl``."""
        key = self._validate_and_normalize_key(key)
        self._set_item_impl(key, value)

    # hide content while inspecting in debugger
    def __dir__(self) -> Iterable[str]:
        """Return the content keys, or an empty list if missing or None."""
        if self._is_missing() or self._is_none():
            return []
        return self.__dict__["_content"].keys()  # type: ignore

    def __setattr__(self, key: str, value: Any) -> None:
        """
        Allow assigning attributes to DictConfig
        :param key: attribute name, used as the config key
        :param value: value to assign
        :return:
        """
        try:
            self.__set_impl(key, value)
        except Exception as e:
            # An already-initialized OmegaConf exception is propagated as is.
            if isinstance(e, OmegaConfBaseException) and e._initialized:
                raise
            self._format_and_raise(key=key, value=value, cause=e)
            # Only reached if _format_and_raise returns instead of raising.
            assert False

    def __getattr__(self, key: str) -> Any:
        """
        Allow accessing dictionary values as attributes
        :param key: attribute name, used as the config key
        :return: the value stored under ``key``
        """
        if key == "__name__":
            raise AttributeError()

        try:
            return self._get_impl(
                key=key, default_value=_DEFAULT_MARKER_, validate_key=False
            )
        except ConfigKeyError as e:
            # Attribute access reports a missing key as an attribute error.
            self._format_and_raise(
                key=key, value=None, cause=e, type_override=ConfigAttributeError
            )
        except Exception as e:
            self._format_and_raise(key=key, value=None, cause=e)

    def __getitem__(self, key: DictKeyType) -> Any:
        """
        Allow map style access
        :param key: config key
        :return: the value stored under ``key``
        """
        try:
            return self._get_impl(key=key, default_value=_DEFAULT_MARKER_)
        except AttributeError as e:
            # Item access reports an attribute error as a key error.
            self._format_and_raise(
                key=key, value=None, cause=e, type_override=ConfigKeyError
            )
        except Exception as e:
            self._format_and_raise(key=key, value=None, cause=e)

    def __delattr__(self, key: str) -> None:
        """
        Allow deleting dictionary values as attributes
        :param key: attribute name, used as the config key
        :return:
        """
        if self._get_flag("readonly"):
            self._format_and_raise(
                key=key,
                value=None,
                cause=ReadonlyConfigError(
                    "DictConfig in read-only mode does not support deletion"
                ),
            )
        try:
            del self.__dict__["_content"][key]
        except KeyError:
            msg = "Attribute not found: '$KEY'"
            self._format_and_raise(key=key, value=None, cause=ConfigAttributeError(msg))

    def __delitem__(self, key: DictKeyType) -> None:
        """
        Allow map style deletion.

        Deletion is refused when the node is read-only, is in struct mode,
        or is typed (unless its own ``struct`` flag is explicitly False).
        A missing key is reported as a ConfigKeyError.
        """
        key = self._validate_and_normalize_key(key)
        if self._get_flag("readonly"):
            self._format_and_raise(
                key=key,
                value=None,
                cause=ReadonlyConfigError(
                    "DictConfig in read-only mode does not support deletion"
                ),
            )
        if self._get_flag("struct"):
            self._format_and_raise(
                key=key,
                value=None,
                cause=ConfigTypeError(
                    "DictConfig in struct mode does not support deletion"
                ),
            )
        if self._is_typed() and self._get_node_flag("struct") is not False:
            self._format_and_raise(
                key=key,
                value=None,
                cause=ConfigTypeError(
                    f"{type_str(self._metadata.object_type)} (DictConfig) does not support deletion"
                ),
            )

        try:
            del self.__dict__["_content"][key]
        except KeyError:
            msg = "Key not found: '$KEY'"
            self._format_and_raise(key=key, value=None, cause=ConfigKeyError(msg))

    def get(self, key: DictKeyType, default_value: Any = None) -> Any:
        """Return the value for `key` if `key` is in the dictionary, else
        `default_value` (defaulting to `None`)."""
        try:
            return self._get_impl(key=key, default_value=default_value)
        except KeyValidationError as e:
            self._format_and_raise(key=key, value=None, cause=e)

    def _get_impl(
        self, key: DictKeyType, default_value: Any, validate_key: bool = True
    ) -> Any:
        """
        Shared lookup used by ``get``, ``__getitem__`` and ``__getattr__``.

        If the child lookup raises ConfigAttributeError or ConfigKeyError,
        ``default_value`` is returned unless it is the ``_DEFAULT_MARKER_``
        sentinel, in which case the error propagates.
        """
        try:
            node = self._get_child(
                key=key, throw_on_missing_key=True, validate_key=validate_key
            )
        except (ConfigAttributeError, ConfigKeyError):
            if default_value is not _DEFAULT_MARKER_:
                return default_value
            else:
                raise
        assert isinstance(node, Node)
        return self._resolve_with_default(
            key=key, value=node, default_value=default_value
        )

    def _get_node(
        self,
        key: DictKeyType,
        validate_access: bool = True,
        validate_key: bool = True,
        throw_on_missing_value: bool = False,
        throw_on_missing_key: bool = False,
    ) -> Optional[Node]:
        """
        Return the child node stored under ``key``, or None if absent.

        :param key: key to look up; normalized before use
        :param validate_access: when True, run ``_validate_get`` first
        :param validate_key: together with ``validate_access``, decides
            whether a KeyValidationError propagates
        :param throw_on_missing_value: raise MissingMandatoryValue when the
            node exists but is missing
        :param throw_on_missing_key: raise instead of returning None when
            there is no node for ``key``
        """
        try:
            key = self._validate_and_normalize_key(key)
        except KeyValidationError:
            if validate_access and validate_key:
                raise
            else:
                if throw_on_missing_key:
                    raise ConfigAttributeError
                else:
                    return None

        if validate_access:
            self._validate_get(key)

        value: Optional[Node] = self.__dict__["_content"].get(key)
        if value is None:
            if throw_on_missing_key:
                raise ConfigKeyError(f"Missing key {key!s}")
        elif throw_on_missing_value and value._is_missing():
            raise MissingMandatoryValue("Missing mandatory value: $KEY")
        return value

    def pop(self, key: DictKeyType, default: Any = _DEFAULT_MARKER_) -> Any:
        """
        Remove ``key`` and return its resolved value.

        If the key is absent, ``default`` is returned when given;
        otherwise a ConfigKeyError is raised. Popping is refused for
        read-only, struct-mode and typed nodes (a typed node whose own
        ``struct`` flag is explicitly False is allowed). All errors go
        through ``_format_and_raise``.
        """
        try:
            if self._get_flag("readonly"):
                raise ReadonlyConfigError("Cannot pop from read-only node")
            if self._get_flag("struct"):
                raise ConfigTypeError("DictConfig in struct mode does not support pop")
            if self._is_typed() and self._get_node_flag("struct") is not False:
                raise ConfigTypeError(
                    f"{type_str(self._metadata.object_type)} (DictConfig) does not support pop"
                )
            key = self._validate_and_normalize_key(key)
            node = self._get_child(key=key, validate_access=False)
            if node is not None:
                assert isinstance(node, Node)
                # Resolve before deleting, while the node is still attached.
                value = self._resolve_with_default(
                    key=key, value=node, default_value=default
                )

                del self[key]
                return value
            else:
                if default is not _DEFAULT_MARKER_:
                    return default
                else:
                    full = self._get_full_key(key=key)
                    if full != key:
                        raise ConfigKeyError(
                            f"Key not found: '{key!s}' (path: '{full}')"
                        )
                    else:
                        raise ConfigKeyError(f"Key not found: '{key!s}'")
        except Exception as e:
            self._format_and_raise(key=key, value=None, cause=e)

    def keys(self) -> KeysView[DictKeyType]:
        """Return a view of the keys; empty if missing, interpolation or None."""
        if self._is_missing() or self._is_interpolation() or self._is_none():
            return {}.keys()
        ret = self.__dict__["_content"].keys()
        assert isinstance(ret, KeysView)
        return ret

    def __contains__(self, key: object) -> bool:
        """
        A key is contained in a DictConfig if there is an associated value and
        it is not a mandatory missing value ('???').
        :param key: key to test; an invalid key yields False
        :return: True if the key is contained, False otherwise
        """

        try:
            key = self._validate_and_normalize_key(key)
        except KeyValidationError:
            return False

        try:
            node = self._get_child(key)
            assert node is None or isinstance(node, Node)
        except (KeyError, AttributeError):
            node = None

        if node is None:
            return False
        else:
            try:
                self._resolve_with_default(key=key, value=node)
                return True
            except InterpolationResolutionError:
                # Interpolations that fail count as existing.
                return True
            except MissingMandatoryValue:
                # Missing values count as *not* existing.
                return False

    def __iter__(self) -> Iterator[DictKeyType]:
        """Iterate over the keys."""
        return iter(self.keys())

    def items(self) -> ItemsView[DictKeyType, Any]:
        """Return an items view over a new dict of resolved ``(key, value)`` pairs."""
        return dict(self.items_ex(resolve=True, keys=None)).items()

    def setdefault(self, key: DictKeyType, default: Any = None) -> Any:
        """
        Return the value for ``key`` if it is contained in this config;
        otherwise store ``default`` under ``key`` and return ``default``.
        """
        if key in self:
            ret = self.__getitem__(key)
        else:
            ret = default
            self.__setitem__(key, default)
        return ret

    def items_ex(
        self, resolve: bool = True, keys: Optional[Sequence[DictKeyType]] = None
    ) -> List[Tuple[DictKeyType, Any]]:
        """
        Return ``(key, value)`` pairs as a list.

        :param resolve: when True, values are read through ``self[key]``;
            when False, raw content is returned and ValueNode instances are
            unwrapped with ``_value()``
        :param keys: when not None, only keys contained in it are returned
        :raises MissingMandatoryValue: if this node is missing

        A node representing None is reported as a TypeError through
        ``_format_and_raise``.
        """
        items: List[Tuple[DictKeyType, Any]] = []

        if self._is_none():
            self._format_and_raise(
                key=None,
                value=None,
                cause=TypeError("Cannot iterate a DictConfig object representing None"),
            )
        if self._is_missing():
            raise MissingMandatoryValue("Cannot iterate a missing DictConfig")

        for key in self.keys():
            if resolve:
                value = self[key]
            else:
                value = self.__dict__["_content"][key]
                if isinstance(value, ValueNode):
                    value = value._value()
            if keys is None or key in keys:
                items.append((key, value))

        return items

    def __eq__(self, other: Any) -> bool:
        """
        Compare with None, a primitive dict, a structured config or
        another DictConfig. A missing node equals a missing literal.
        Anything else yields ``NotImplemented``.
        """
        if other is None:
            return self.__dict__["_content"] is None
        if is_primitive_dict(other) or is_structured_config(other):
            # Wrap the plain value so both sides compare as DictConfig.
            other = DictConfig(other, flags={"allow_objects": True})
            return DictConfig._dict_conf_eq(self, other)
        if isinstance(other, DictConfig):
            return DictConfig._dict_conf_eq(self, other)
        if self._is_missing():
            return _is_missing_literal(other)
        return NotImplemented

    def __ne__(self, other: Any) -> bool:
        """Negate ``__eq__``, passing ``NotImplemented`` through."""
        x = self.__eq__(other)
        if x is not NotImplemented:
            return not x
        return NotImplemented

    def __hash__(self) -> int:
        """Hash of ``str(self)``."""
        return hash(str(self))

    def _promote(self, type_or_prototype: Optional[Type[Any]]) -> None:
        """
        Retypes a node.
        This should only be used in rare circumstances, where you want to dynamically change
        the runtime structured-type of a DictConfig.
        It will change the type and add the additional fields based on the input class or object
        """
        if type_or_prototype is None:
            return
        if not is_structured_config(type_or_prototype):
            raise ValueError(f"Expected structured config class: {type_or_prototype}")

        from omegaconf import OmegaConf

        proto: DictConfig = OmegaConf.structured(type_or_prototype)
        object_type = proto._metadata.object_type
        # remove the type to prevent assignment validation from rejecting the promotion.
        proto._metadata.object_type = None
        self.merge_with(proto)
        # restore the type.
        self._metadata.object_type = object_type

    def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
        """
        Replace the content with ``value``; on failure the previous
        content is restored and the exception is re-raised.
        """
        # Take the snapshot before the try block so the handler below can
        # never reference an unbound name.
        previous_content = self.__dict__["_content"]
        try:
            self._set_value_impl(value, flags)
        except Exception:
            self.__dict__["_content"] = previous_content
            raise

    def _set_value_impl(
        self, value: Any, flags: Optional[Dict[str, bool]] = None
    ) -> None:
        """
        Assign ``value`` as the content of this node.

        None, interpolations and missing values are stored directly and
        clear the object type. Structured configs, DictConfig instances
        and plain dicts are copied item by item through ``__setitem__``
        with the ``struct`` and ``readonly`` flags overridden to False.

        :raises ValidationError: for any other kind of value
        """
        from omegaconf import MISSING, flag_override

        if flags is None:
            flags = {}

        assert not isinstance(value, ValueNode)
        self._validate_set(key=None, value=value)

        if _is_none(value, resolve=True):
            self.__dict__["_content"] = None
            self._metadata.object_type = None
        elif _is_interpolation(value, strict_interpolation_validation=True):
            self.__dict__["_content"] = value
            self._metadata.object_type = None
        elif _is_missing_value(value):
            self.__dict__["_content"] = MISSING
            self._metadata.object_type = None
        else:
            self.__dict__["_content"] = {}
            if is_structured_config(value):
                # The object type is cleared while the fields are populated
                # and set to the value's type only afterwards.
                self._metadata.object_type = None
                ao = self._get_flag("allow_objects")
                data = get_structured_config_data(value, allow_objects=ao)
                with flag_override(self, ["struct", "readonly"], False):
                    for k, v in data.items():
                        self.__setitem__(k, v)
                self._metadata.object_type = get_type_of(value)

            elif isinstance(value, DictConfig):
                self._metadata.flags = copy.deepcopy(flags)
                with flag_override(self, ["struct", "readonly"], False):
                    for k, v in value.__dict__["_content"].items():
                        self.__setitem__(k, v)
                self._metadata.object_type = value._metadata.object_type

            elif isinstance(value, dict):
                with flag_override(self, ["struct", "readonly"], False):
                    for k, v in value.items():
                        self.__setitem__(k, v)
                self._metadata.object_type = dict

            else:  # pragma: no cover
                msg = f"Unsupported value type: {value}"
                raise ValidationError(msg)

    @staticmethod
    def _dict_conf_eq(d1: "DictConfig", d2: "DictConfig") -> bool:
        """
        Return True if two DictConfig objects are equal.

        Two None configs are equal; a None config never equals a non-None
        one. Otherwise lengths, missing state and each item (compared with
        ``BaseContainer._item_eq``) must match.
        """

        d1_none = d1.__dict__["_content"] is None
        d2_none = d2.__dict__["_content"] is None
        if d1_none and d2_none:
            return True
        if d1_none != d2_none:
            return False

        assert isinstance(d1, DictConfig)
        assert isinstance(d2, DictConfig)
        if len(d1) != len(d2):
            return False
        if d1._is_missing() or d2._is_missing():
            return d1._is_missing() is d2._is_missing()

        for k, v in d1.items_ex(resolve=False):
            if k not in d2.__dict__["_content"]:
                return False
            if not BaseContainer._item_eq(d1, k, d2, k):
                return False

        return True

    def _to_object(self) -> Any:
        """
        Instantiate an instance of `self._metadata.object_type`.
        This requires `self` to be a structured config.
        Nested subconfigs are converted by calling `OmegaConf.to_object`.
        """
        from omegaconf import OmegaConf

        object_type = self._metadata.object_type
        assert is_structured_config(object_type)
        init_field_names = set(get_structured_config_init_field_names(object_type))

        init_field_items: Dict[str, Any] = {}
        non_init_field_items: Dict[str, Any] = {}
        for k in self.keys():
            assert isinstance(k, str)
            node = self._get_child(k)
            assert isinstance(node, Node)
            try:
                node = node._dereference_node()
            except InterpolationResolutionError as e:
                self._format_and_raise(key=k, value=None, cause=e)
            if node._is_missing():
                if k not in init_field_names:
                    continue  # MISSING is ignored for init=False fields
                self._format_and_raise(
                    key=k,
                    value=None,
                    cause=MissingMandatoryValue(
                        "Structured config of type `$OBJECT_TYPE` has missing mandatory value: $KEY"
                    ),
                )
            if isinstance(node, Container):
                v = OmegaConf.to_object(node)
            else:
                v = node._value()

            if k in init_field_names:
                init_field_items[k] = v
            else:
                non_init_field_items[k] = v

        try:
            result = object_type(**init_field_items)
        except TypeError as exc:
            self._format_and_raise(
                key=None,
                value=None,
                cause=exc,
                msg="Could not create instance of `$OBJECT_TYPE`: " + str(exc),
            )

        # Fields that are not constructor arguments are set after creation.
        for k, v in non_init_field_items.items():
            setattr(result, k, v)
        return result