import copy
import itertools
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    MutableSequence,
    Optional,
    Tuple,
    Type,
    Union,
)

from ._utils import (
    ValueKind,
    _is_missing_literal,
    _is_none,
    _resolve_optional,
    format_and_raise,
    get_value_kind,
    is_int,
    is_primitive_list,
    is_structured_config,
    type_str,
)
from .base import Box, ContainerMetadata, Node
from .basecontainer import BaseContainer
from .errors import (
    ConfigAttributeError,
    ConfigTypeError,
    ConfigValueError,
    KeyValidationError,
    MissingMandatoryValue,
    ReadonlyConfigError,
    ValidationError,
)


class ListConfig(BaseContainer, MutableSequence[Any]):
    """A config container that behaves like a mutable list of child nodes.

    ``_content`` is a list of ``Node`` objects for a regular list, ``None``
    when the config represents None, or a string (``MISSING`` or an
    interpolation) otherwise -- see ``_set_value_impl``.
    """

    _content: Union[List[Node], None, str]

    def __init__(
        self,
        content: Union[List[Any], Tuple[Any, ...], "ListConfig", str, None],
        key: Any = None,
        parent: Optional[Box] = None,
        element_type: Union[Type[Any], Any] = Any,
        is_optional: bool = True,
        ref_type: Union[Type[Any], Any] = Any,
        flags: Optional[Dict[str, bool]] = None,
    ) -> None:
        """Build a ListConfig from a list, tuple, ListConfig, string or None.

        When ``content`` is a ListConfig and ``flags`` is not given, its flags
        are reused, and a deep copy of its metadata is adopted with ``key``,
        ``ref_type``, ``is_optional`` and ``element_type`` overridden.
        """
        try:
            if isinstance(content, ListConfig):
                if flags is None:
                    flags = content._metadata.flags
            super().__init__(
                parent=parent,
                metadata=ContainerMetadata(
                    ref_type=ref_type,
                    object_type=list,
                    key=key,
                    optional=is_optional,
                    element_type=element_type,
                    key_type=int,
                    flags=flags,
                ),
            )

            if isinstance(content, ListConfig):
                metadata = copy.deepcopy(content._metadata)
                metadata.key = key
                metadata.ref_type = ref_type
                metadata.optional = is_optional
                metadata.element_type = element_type
                self.__dict__["_metadata"] = metadata
            self._set_value(value=content, flags=flags)
        except Exception as ex:
            format_and_raise(node=None, key=key, value=None, cause=ex, msg=str(ex))

    def _validate_get(self, key: Any, value: Any = None) -> None:
        """Raise KeyValidationError unless ``key`` is an int or a slice."""
        if not isinstance(key, (int, slice)):
            raise KeyValidationError(
                "ListConfig indices must be integers or slices, not $KEY_TYPE"
            )

    def _validate_set(self, key: Any, value: Any) -> None:
        """Raise if ``value`` may not be stored at ``key``.

        Rejects invalid key types, writes to a read-only list, None assigned
        over an existing non-optional node, None for a non-optional element
        type, and structured-config values that are not subclasses of the
        element type. MISSING values are always accepted.
        """
        from omegaconf import OmegaConf

        self._validate_get(key, value)

        if self._get_flag("readonly"):
            raise ReadonlyConfigError("ListConfig is read-only")

        if 0 <= key < self.__len__():
            target = self._get_node(key)
            if target is not None:
                assert isinstance(target, Node)
                if value is None and not target._is_optional():
                    raise ValidationError(
                        "$FULL_KEY is not optional and cannot be assigned None"
                    )

        vk = get_value_kind(value)
        if vk == ValueKind.MANDATORY_MISSING:
            return
        else:
            is_optional, target_type = _resolve_optional(self._metadata.element_type)
            value_type = OmegaConf.get_type(value)

            if (value_type is None and not is_optional) or (
                is_structured_config(target_type)
                and value_type is not None
                and not issubclass(value_type, target_type)
            ):
                msg = (
                    f"Invalid type assigned: {type_str(value_type)} is not a "
                    f"subclass of {type_str(target_type)}. value: {value}"
                )
                raise ValidationError(msg)

    def __deepcopy__(self, memo: Dict[int, Any]) -> "ListConfig":
        """Deep-copy metadata, flag cache and children.

        The copy is attached to the same parent as the original.
        """
        res = ListConfig(None)
        res.__dict__["_metadata"] = copy.deepcopy(self.__dict__["_metadata"], memo=memo)
        res.__dict__["_flags_cache"] = copy.deepcopy(
            self.__dict__["_flags_cache"], memo=memo
        )

        src_content = self.__dict__["_content"]
        if isinstance(src_content, list):
            content_copy: List[Optional[Node]] = []
            for v in src_content:
                # Detach the child while copying so deepcopy does not walk up
                # into (and duplicate) the whole parent tree.
                old_parent = v.__dict__["_parent"]
                try:
                    v.__dict__["_parent"] = None
                    vc = copy.deepcopy(v, memo=memo)
                    vc.__dict__["_parent"] = res
                    content_copy.append(vc)
                finally:
                    v.__dict__["_parent"] = old_parent
        else:
            # None and strings can be assigned as is
            content_copy = src_content

        res.__dict__["_content"] = content_copy
        res.__dict__["_parent"] = self.__dict__["_parent"]
        return res

    def copy(self) -> "ListConfig":
        """Return a shallow copy made with ``copy.copy``."""
        return copy.copy(self)

    # hide content while inspecting in debugger
    def __dir__(self) -> Iterable[str]:
        if self._is_missing() or self._is_none():
            return []
        return [str(x) for x in range(0, len(self))]

    def __setattr__(self, key: str, value: Any) -> None:
        """Attribute assignment is not supported on lists; always raises."""
        self._format_and_raise(
            key=key,
            value=value,
            cause=ConfigAttributeError("ListConfig does not support attribute access"),
        )
        assert False

    def __getattr__(self, key: str) -> Any:
        """Allow ``cfg.<int>`` as an alias for ``cfg[<int>]``; anything else raises."""
        # PyCharm is sometimes inspecting __members__, be sure to tell it we don't have that.
        if key == "__members__":
            raise AttributeError()

        if key == "__name__":
            raise AttributeError()

        if is_int(key):
            return self.__getitem__(int(key))
        else:
            self._format_and_raise(
                key=key,
                value=None,
                cause=ConfigAttributeError(
                    "ListConfig does not support attribute access"
                ),
            )

    def __getitem__(self, index: Union[int, slice]) -> Any:
        """Return the resolved item at ``index``, or a list of items for a slice.

        Slices follow Python list semantics (clamped bounds, negative steps).
        """
        try:
            if self._is_missing():
                raise MissingMandatoryValue("ListConfig is missing")
            self._validate_get(index, None)
            if self._is_none():
                raise TypeError(
                    "ListConfig object representing None is not subscriptable"
                )

            content = self.__dict__["_content"]
            assert isinstance(content, list)
            if isinstance(index, slice):
                # slice.indices() applies Python's own clamping and
                # negative-step rules, so the selection matches list slicing.
                return [
                    self._resolve_with_default(key=slice_idx, value=content[slice_idx])
                    for slice_idx in range(*index.indices(len(self)))
                ]
            else:
                return self._resolve_with_default(key=index, value=content[index])
        except Exception as e:
            self._format_and_raise(key=index, value=None, cause=e)

    def _correct_index_params(self, index: slice) -> Tuple[int, int, int]:
        """Legacy helper translating a slice into ``itertools.islice`` arguments.

        Kept for backward compatibility only; ``__getitem__`` now uses
        ``slice.indices``. NOTE(review): this mapping does not reproduce
        Python semantics for every negative-step slice (e.g. a step < -1).
        """
        start = index.start
        stop = index.stop
        step = index.step
        if index.start and index.start < 0:
            start = self.__len__() + index.start
        if index.stop and index.stop < 0:
            stop = self.__len__() + index.stop
        if index.step and index.step < 0:
            step = abs(step)
            if start and stop:
                if start > stop:
                    start, stop = stop + 1, start + 1
                else:
                    start = stop = 0
            elif not start and stop:
                start = list(range(self.__len__() - 1, stop, -step))[0]
                stop = None
            elif start and not stop:
                stop = start + 1
                start = (stop - 1) % step
            else:
                start = (self.__len__() - 1) % step
        return start, stop, step

    def _set_at_index(self, index: Union[int, slice], value: Any) -> None:
        self._set_item_impl(index, value)

    def __setitem__(self, index: Union[int, slice], value: Any) -> None:
        """Assign one item, or replace a slice following list slice-assignment rules."""
        try:
            if isinstance(index, slice):
                _ = iter(value)  # check iterable
                self_indices = index.indices(len(self))
                indexes = range(*self_indices)

                # Ensure lengths match for extended slice assignment
                if index.step not in (None, 1):
                    if len(indexes) != len(value):
                        raise ValueError(
                            f"attempt to assign sequence of size {len(value)}"
                            f" to extended slice of size {len(indexes)}"
                        )

                # Initialize insertion offsets for empty slices
                if len(indexes) == 0:
                    curr_index = self_indices[0] - 1
                    val_i = -1

                work_copy = self.copy()  # For atomicity manipulate a copy

                # Delete and optionally replace non empty slices
                only_removed = 0
                for val_i, i in enumerate(indexes):
                    curr_index = i - only_removed
                    del work_copy[curr_index]
                    if val_i < len(value):
                        work_copy.insert(curr_index, value[val_i])
                    else:
                        only_removed += 1

                # Insert any remaining input items
                for val_i in range(val_i + 1, len(value)):
                    curr_index += 1
                    work_copy.insert(curr_index, value[val_i])

                # Reinitialize self with work_copy
                self.clear()
                self.extend(work_copy)
            else:
                self._set_at_index(index, value)
        except Exception as e:
            self._format_and_raise(key=index, value=value, cause=e)

    def append(self, item: Any) -> None:
        """Append ``item`` to the end of the list.

        Raises (through ``_format_and_raise``) when this ListConfig represents
        None, is missing or holds an interpolation, or when ``item`` fails
        validation; a failed append leaves the list unchanged.
        """
        content = self.__dict__["_content"]
        if not isinstance(content, list):
            # Only a None / MISSING / interpolation ListConfig lacks a list, so
            # the common path pays nothing for this check.
            try:
                if self._is_none():
                    raise TypeError(
                        "Cannot append to a ListConfig object representing None"
                    )
                if self._is_missing():
                    raise MissingMandatoryValue("Cannot append to a missing ListConfig")
                raise TypeError(
                    "Cannot append to a ListConfig object representing an interpolation"
                )
            except Exception as e:
                self._format_and_raise(key=None, value=item, cause=e)
                assert False

        index = len(content)
        # Reserve the slot first so _set_item_impl sees a valid index.
        content.append(None)
        try:
            self._set_item_impl(index, item)
        except Exception as e:
            del content[index]
            self._format_and_raise(key=index, value=item, cause=e)
            assert False

    def _update_keys(self) -> None:
        """Renumber each child's metadata key to match its current position."""
        for i in range(len(self)):
            node = self._get_node(i)
            if node is not None:
                assert isinstance(node, Node)
                node._metadata.key = i

    def insert(self, index: int, item: Any) -> None:
        """Insert ``item`` before ``index`` with ``list.insert`` semantics.

        Negative indices count from the end and out-of-range indices are
        clamped. On failure the list is left unchanged.
        """
        from omegaconf.omegaconf import _maybe_wrap

        try:
            if self._get_flag("readonly"):
                raise ReadonlyConfigError("Cannot insert into a read-only ListConfig")
            if self._is_none():
                raise TypeError(
                    "Cannot insert into ListConfig object representing None"
                )
            if self._is_missing():
                raise MissingMandatoryValue("Cannot insert into missing ListConfig")

            content = self.__dict__["_content"]
            assert isinstance(content, list)

            # Convert to an absolute position like list.insert does. The
            # placeholder, the assignment and the rollback must all address
            # the same slot; a raw negative index would not.
            length = len(content)
            if index < 0:
                index = max(index + length, 0)
            elif index > length:
                index = length

            # insert place holder
            content.insert(index, None)
            try:
                is_optional, ref_type = _resolve_optional(self._metadata.element_type)
                node = _maybe_wrap(
                    ref_type=ref_type,
                    key=index,
                    value=item,
                    is_optional=is_optional,
                    parent=self,
                )
                self._validate_set(key=index, value=node)
                self._set_at_index(index, node)
                self._update_keys()
            except Exception:
                del self.__dict__["_content"][index]
                self._update_keys()
                raise
        except Exception as e:
            self._format_and_raise(key=index, value=item, cause=e)
            assert False

    def extend(self, lst: Iterable[Any]) -> None:
        """Append every element of ``lst`` (a tuple, list or ListConfig)."""
        assert isinstance(lst, (tuple, list, ListConfig))
        if lst is self:
            # Iterating our own content while appending to it would never
            # terminate; take a snapshot first (so `cfg += cfg` works).
            lst = list(lst)
        for x in lst:
            self.append(x)

    def remove(self, x: Any) -> None:
        """Remove the first item equal to ``x``; raises if it is not present."""
        del self[self.index(x)]

    def __delitem__(self, key: Union[int, slice]) -> None:
        """Delete an item or slice and renumber the remaining children."""
        if self._get_flag("readonly"):
            self._format_and_raise(
                key=key,
                value=None,
                cause=ReadonlyConfigError(
                    "Cannot delete item from read-only ListConfig"
                ),
            )
        del self.__dict__["_content"][key]
        self._update_keys()

    def clear(self) -> None:
        """Remove all items."""
        del self[:]

    def index(
        self, x: Any, start: Optional[int] = None, end: Optional[int] = None
    ) -> int:
        """Return the position of the first item equal to ``x``.

        ``start``/``end`` are interpreted like slice bounds (negative values
        count from the end, out-of-range values are clamped), as in
        ``list.index``. Raises ConfigValueError if no item matches.
        """
        for idx in range(*slice(start, end).indices(len(self))):
            if x == self[idx]:
                return idx

        self._format_and_raise(
            key=None,
            value=None,
            cause=ConfigValueError("Item not found in ListConfig"),
        )
        assert False

    def count(self, x: Any) -> int:
        """Return how many (resolved) items compare equal to ``x``."""
        return sum(1 for item in self if item == x)

    def _get_node(
        self,
        key: Union[int, slice],
        validate_access: bool = True,
        validate_key: bool = True,
        throw_on_missing_value: bool = False,
        throw_on_missing_key: bool = False,
    ) -> Union[Optional[Node], List[Optional[Node]]]:
        """Return the child node (or list of nodes for a slice) at ``key``.

        With ``validate_access`` False, lookup failures return None instead of
        raising. ``throw_on_missing_value`` raises for MISSING children.
        """
        try:
            if self._is_none():
                raise TypeError(
                    "Cannot get_node from a ListConfig object representing None"
                )
            if self._is_missing():
                raise MissingMandatoryValue("Cannot get_node from a missing ListConfig")
            assert isinstance(self.__dict__["_content"], list)
            if validate_access:
                self._validate_get(key)

            value = self.__dict__["_content"][key]
            if value is not None:
                if isinstance(key, slice):
                    assert isinstance(value, list)
                    for v in value:
                        if throw_on_missing_value and v._is_missing():
                            raise MissingMandatoryValue("Missing mandatory value")
                else:
                    assert isinstance(value, Node)
                    if throw_on_missing_value and value._is_missing():
                        raise MissingMandatoryValue("Missing mandatory value: $KEY")
            return value
        except (IndexError, TypeError, MissingMandatoryValue, KeyValidationError) as e:
            if isinstance(e, MissingMandatoryValue) and throw_on_missing_value:
                raise
            if validate_access:
                self._format_and_raise(key=key, value=None, cause=e)
                assert False
            else:
                return None

    def get(self, index: int, default_value: Any = None) -> Any:
        """Return the resolved item at ``index``, passing ``default_value`` to the resolver."""
        try:
            if self._is_none():
                raise TypeError("Cannot get from a ListConfig object representing None")
            if self._is_missing():
                raise MissingMandatoryValue("Cannot get from a missing ListConfig")
            self._validate_get(index, None)
            assert isinstance(self.__dict__["_content"], list)
            return self._resolve_with_default(
                key=index,
                value=self.__dict__["_content"][index],
                default_value=default_value,
            )
        except Exception as e:
            self._format_and_raise(key=index, value=None, cause=e)
            assert False

    def pop(self, index: int = -1) -> Any:
        """Remove and return the resolved item at ``index`` (default: last)."""
        try:
            if self._get_flag("readonly"):
                raise ReadonlyConfigError("Cannot pop from read-only ListConfig")
            if self._is_none():
                raise TypeError("Cannot pop from a ListConfig object representing None")
            if self._is_missing():
                raise MissingMandatoryValue("Cannot pop from a missing ListConfig")

            assert isinstance(self.__dict__["_content"], list)
            node = self._get_child(index)
            assert isinstance(node, Node)
            ret = self._resolve_with_default(key=index, value=node, default_value=None)
            del self.__dict__["_content"][index]
            self._update_keys()
            return ret
        except KeyValidationError as e:
            self._format_and_raise(
                key=index, value=None, cause=e, type_override=ConfigTypeError
            )
            assert False
        except Exception as e:
            self._format_and_raise(key=index, value=None, cause=e)
            assert False

    def sort(
        self, key: Optional[Callable[[Any], Any]] = None, reverse: bool = False
    ) -> None:
        """Sort the children in place by their values (optionally via ``key``)."""
        try:
            if self._get_flag("readonly"):
                raise ReadonlyConfigError("Cannot sort a read-only ListConfig")
            if self._is_none():
                raise TypeError("Cannot sort a ListConfig object representing None")
            if self._is_missing():
                raise MissingMandatoryValue("Cannot sort a missing ListConfig")

            # The content holds nodes, so compare their wrapped values.
            if key is None:

                def key1(x: Any) -> Any:
                    return x._value()

            else:

                def key1(x: Any) -> Any:
                    return key(x._value())  # type: ignore

            assert isinstance(self.__dict__["_content"], list)
            self.__dict__["_content"].sort(key=key1, reverse=reverse)

        except Exception as e:
            self._format_and_raise(key=None, value=None, cause=e)
            assert False

    def __eq__(self, other: Any) -> bool:
        """Compare item-wise with a list, tuple, None or another ListConfig."""
        if isinstance(other, (list, tuple)) or other is None:
            other = ListConfig(other, flags={"allow_objects": True})
            return ListConfig._list_eq(self, other)
        if isinstance(other, ListConfig):
            return ListConfig._list_eq(self, other)
        if self._is_missing():
            return _is_missing_literal(other)
        return NotImplemented

    def __ne__(self, other: Any) -> bool:
        x = self.__eq__(other)
        if x is not NotImplemented:
            return not x
        return NotImplemented

    def __hash__(self) -> int:
        return hash(str(self))

    def __iter__(self) -> Iterator[Any]:
        return self._iter_ex(resolve=True)

    class ListIterator(Iterator[Any]):
        """Iterator over a ListConfig's children, optionally dereferencing them.

        Value nodes are unwrapped to their values; container children are
        returned as nodes (or None if they represent None).
        """

        def __init__(self, lst: Any, resolve: bool) -> None:
            self.resolve = resolve
            self.iterator = iter(lst.__dict__["_content"])
            self.index = 0
            from .nodes import ValueNode

            self.ValueNode = ValueNode

        def __next__(self) -> Any:
            x = next(self.iterator)
            if self.resolve:
                x = x._dereference_node()
                if x._is_missing():
                    raise MissingMandatoryValue(f"Missing value at index {self.index}")

            self.index = self.index + 1
            if isinstance(x, self.ValueNode):
                return x._value()
            else:
                # Must be omegaconf.Container. not checking for perf reasons.
                if x._is_none():
                    return None
                return x

        def __repr__(self) -> str:  # pragma: no cover
            return f"ListConfig.ListIterator(resolve={self.resolve})"

    def _iter_ex(self, resolve: bool) -> Iterator[Any]:
        """Return a ListIterator; raises for None or missing ListConfigs."""
        try:
            if self._is_none():
                raise TypeError("Cannot iterate a ListConfig object representing None")
            if self._is_missing():
                raise MissingMandatoryValue("Cannot iterate a missing ListConfig")

            return ListConfig.ListIterator(self, resolve)
        except (TypeError, MissingMandatoryValue) as e:
            self._format_and_raise(key=None, value=None, cause=e)
            assert False

    def __add__(self, other: Union[List[Any], "ListConfig"]) -> "ListConfig":
        # res is sharing this list's parent to allow interpolation to work as expected
        res = ListConfig(parent=self._get_parent(), content=[])
        res.extend(self)
        res.extend(other)
        return res

    def __radd__(self, other: Union[List[Any], "ListConfig"]) -> "ListConfig":
        # res is sharing this list's parent to allow interpolation to work as expected
        res = ListConfig(parent=self._get_parent(), content=[])
        res.extend(other)
        res.extend(self)
        return res

    def __iadd__(self, other: Iterable[Any]) -> "ListConfig":
        self.extend(other)
        return self

    def __contains__(self, item: Any) -> bool:
        """Return True if any dereferenced child compares equal to ``item``."""
        if self._is_none():
            raise TypeError(
                "Cannot check if an item is in a ListConfig object representing None"
            )
        if self._is_missing():
            raise MissingMandatoryValue(
                "Cannot check if an item is in missing ListConfig"
            )

        lst = self.__dict__["_content"]
        for x in lst:
            x = x._dereference_node()
            if x == item:
                return True
        return False

    def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
        """Replace this node's value; on failure restore the previous content.

        NOTE(review): the metadata restored here is the same object that
        ``_set_value_impl`` mutates, so metadata changes are not undone.
        """
        try:
            previous_content = self.__dict__["_content"]
            previous_metadata = self.__dict__["_metadata"]
            self._set_value_impl(value, flags)
        except Exception as e:
            self.__dict__["_content"] = previous_content
            self.__dict__["_metadata"] = previous_metadata
            raise e

    def _set_value_impl(
        self, value: Any, flags: Optional[Dict[str, bool]] = None
    ) -> None:
        """Store ``value`` as None, MISSING, an interpolation string or a list of nodes."""
        from omegaconf import MISSING, flag_override

        if flags is None:
            flags = {}

        vk = get_value_kind(value, strict_interpolation_validation=True)
        if _is_none(value):
            if not self._is_optional():
                raise ValidationError(
                    "Non optional ListConfig cannot be constructed from None"
                )
            self.__dict__["_content"] = None
            self._metadata.object_type = None
        elif vk is ValueKind.MANDATORY_MISSING:
            self.__dict__["_content"] = MISSING
            self._metadata.object_type = None
        elif vk == ValueKind.INTERPOLATION:
            self.__dict__["_content"] = value
            self._metadata.object_type = None
        else:
            if not (is_primitive_list(value) or isinstance(value, ListConfig)):
                type_ = type(value)
                msg = f"Invalid value assigned: {type_.__name__} is not a ListConfig, list or tuple."
                raise ValidationError(msg)

            self.__dict__["_content"] = []
            if isinstance(value, ListConfig):
                self._metadata.flags = copy.deepcopy(flags)
                # disable struct and readonly for the construction phase
                # retaining other flags like allow_objects. The real flags are restored at the end of this function
                with flag_override(self, ["struct", "readonly"], False):
                    for item in value._iter_ex(resolve=False):
                        self.append(item)
            elif is_primitive_list(value):
                with flag_override(self, ["struct", "readonly"], False):
                    for item in value:
                        self.append(item)
            self._metadata.object_type = list

    @staticmethod
    def _list_eq(l1: Optional["ListConfig"], l2: Optional["ListConfig"]) -> bool:
        """Item-wise equality of two ListConfigs; two None-valued lists are equal."""
        l1_none = l1.__dict__["_content"] is None
        l2_none = l2.__dict__["_content"] is None
        if l1_none and l2_none:
            return True
        if l1_none != l2_none:
            return False

        assert isinstance(l1, ListConfig)
        assert isinstance(l2, ListConfig)
        if len(l1) != len(l2):
            return False
        for i in range(len(l1)):
            if not BaseContainer._item_eq(l1, i, l2, i):
                return False

        return True