|
|
|
|
|
|
|
|
|
|
|
from collections import OrderedDict |
|
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union, cast |
|
|
|
import numpy as np |
|
import torch |
|
from torch.nn.utils.rnn import PackedSequence |
|
|
|
"""Useful functions to deal with tensor types with other python container types.""" |
|
|
|
|
|
def apply_to_type(
    type_fn: Callable, fn: Callable, container: Union[torch.Tensor, np.ndarray, Dict, List, Tuple, Set, NamedTuple]
) -> Any:
    """Recursively apply ``fn`` to every object in ``container`` that ``type_fn`` accepts.

    Args:
        type_fn: Predicate called on each visited object; a truthy result selects the object.
        fn: Function applied to each selected object; its return value replaces the object
            in the rebuilt container.
        container: A single object, or an arbitrarily nested combination of OrderedDict,
            dict, list, tuple, namedtuple and set. A PackedSequence is also accepted.

    Returns:
        ``fn(container)`` if ``type_fn`` accepts the top-level object. Otherwise a new
        container of the same kind with the selected objects replaced. Objects that are
        neither selected nor one of the handled container kinds are returned as-is.
    """

    def _apply(obj: Union[torch.Tensor, np.ndarray, Dict, List, Tuple, Set]) -> Any:
        # The predicate is consulted first, so a container that itself matches is
        # handed to fn whole instead of being traversed.
        if type_fn(obj):
            return fn(obj)

        # Tested ahead of the plain dict branch so that the result is built with the
        # same (ordered) class as the input.
        if isinstance(obj, OrderedDict):
            rebuilt = obj.__class__()
            for name, item in obj.items():
                rebuilt[name] = _apply(item)
            return rebuilt

        # The packed data is visited, but the result is discarded and the original
        # PackedSequence is returned: only in-place effects of fn are visible here.
        if isinstance(obj, PackedSequence):
            _apply(obj.data)
            return obj

        if isinstance(obj, dict):
            return {name: _apply(item) for name, item in obj.items()}

        if isinstance(obj, list):
            return [_apply(item) for item in obj]

        if isinstance(obj, tuple):
            field_names = getattr(obj, "_fields", None)
            if field_names is None:
                # Plain tuple.
                return tuple(_apply(item) for item in obj)

            assert isinstance(field_names, tuple), "This needs to be a namedtuple"

            # Rebuild the namedtuple through its own class, field by field.
            named = cast(NamedTuple, obj)
            converted: Dict[str, Any] = {name: _apply(item) for name, item in named._asdict().items()}
            return type(named)(**converted)

        if isinstance(obj, set):
            return {_apply(item) for item in obj}

        return obj

    return _apply(container)
|
|
|
|
|
def apply_to_tensors(fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
    """Recursively apply ``fn`` to every tensor found in ``container``.

    Tensors are identified with ``torch.is_tensor``; the traversal itself is
    delegated to ``apply_to_type``.
    """
    return apply_to_type(type_fn=torch.is_tensor, fn=fn, container=container)
|
|
|
|
|
def to_np(tensor_or_container: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
    """Convert a tensor, or every tensor nested inside a container, to a numpy array.

    Each tensor is detached from the autograd graph and moved to the CPU before the
    conversion, so tensors that require grad are accepted as well. The traversal of
    containers is delegated to ``apply_to_type``.

    Args:
        tensor_or_container: A tensor or a (nested) container holding tensors.

    Returns:
        The converted array, or the container rebuilt by ``apply_to_type`` with its
        tensors replaced by numpy arrays.
    """

    def _tensor_to_ndarray(tensor: torch.Tensor) -> np.ndarray:
        # Tensor.numpy() raises for tensors that require grad; detach() first so the
        # conversion also works on outputs that are still attached to a graph.
        return tensor.detach().cpu().numpy()

    return apply_to_type(torch.is_tensor, _tensor_to_ndarray, tensor_or_container)
|
|
|
|
|
def from_np(ndarray_or_container: Union[np.ndarray, Dict, List, Tuple, Set]) -> Any:
    """Convert an ndarray, or every ndarray nested inside a container, to a tensor.

    Arrays are identified with ``isinstance(..., np.ndarray)`` and converted with
    ``torch.from_numpy``; the traversal is delegated to ``apply_to_type``.
    """

    def _is_ndarray(candidate: Any) -> bool:
        return isinstance(candidate, np.ndarray)

    return apply_to_type(_is_ndarray, torch.from_numpy, ndarray_or_container)
|
|
|
|
|
def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, ...]]:
    """
    Flatten positional and keyword arguments into a tuple of keys and a tuple of values.

    The keyword names are returned in the order they were passed; the keyword values
    are placed after the positional arguments, in that same order.
    ``unpack_kwargs`` performs the reverse operation.

    Usage::

        kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4)
        assert kwarg_keys == ("a", "b")
        assert flat_args == (1, 2, 3, 4)
        args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
        assert args == (1, 2)
        assert kwargs == {"a": 3, "b": 4}
    """
    # Iterating a dict yields its keys; values() follows the same order.
    kwarg_keys = tuple(kwargs)
    flat_args = (*args, *kwargs.values())
    return kwarg_keys, flat_args
|
|
|
|
|
def unpack_kwargs(kwarg_keys: Tuple[str, ...], flat_args: Tuple[Any, ...]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Rebuild ``(args, kwargs)`` from the output of ``pack_kwargs``.

    The last ``len(kwarg_keys)`` entries of ``flat_args`` are paired with the keys to
    form the keyword dict; everything before them is returned as positional arguments.
    """
    num_keys = len(kwarg_keys)
    assert num_keys <= len(flat_args), f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}"

    # Without keys there is nothing to split off; a "-0" slice would also be wrong.
    if num_keys == 0:
        return flat_args, {}

    positional = flat_args[:-num_keys]
    keyword = dict(zip(kwarg_keys, flat_args[-num_keys:]))
    return positional, keyword
|
|
|
|
|
def split_non_tensors(
    mixed: Union[torch.Tensor, Tuple[Any, ...]]
) -> Tuple[Tuple[torch.Tensor, ...], Optional[Dict[str, List[Any]]]]:
    """
    Separate the tensors of a tuple from its other items, keeping enough
    bookkeeping to rebuild the original tuple later.

    A bare tensor ``x`` is returned as ``(x,), None``.

    Usage::

        x = torch.Tensor([1])
        y = torch.Tensor([2])
        tensors, packed_non_tensors = split_non_tensors((x, y, None, 3))
        assert tensors == (x, y)
        assert packed_non_tensors == {
            "is_tensor": [True, True, False, False],
            "objects": [None, 3],
        }
        recon = unpack_non_tensors(tensors, packed_non_tensors)
        assert recon == (x, y, None, 3)
    """
    if isinstance(mixed, torch.Tensor):
        return (mixed,), None

    flags: List[Any] = []
    found_tensors: List[torch.Tensor] = []
    others: List[Any] = []
    for item in mixed:
        # One flag per input position records which output list received the item.
        item_is_tensor = isinstance(item, torch.Tensor)
        flags.append(item_is_tensor)
        if item_is_tensor:
            found_tensors.append(item)
        else:
            others.append(item)

    packed_non_tensors: Dict[str, List[Any]] = {"is_tensor": flags, "objects": others}
    return tuple(found_tensors), packed_non_tensors
|
|
|
|
|
def unpack_non_tensors(
    tensors: Tuple[torch.Tensor, ...], packed_non_tensors: Optional[Dict[str, List[Any]]]
) -> Tuple[Any, ...]:
    """Interleave ``tensors`` with the packed non-tensor objects (inverse of split_non_tensors).

    When ``packed_non_tensors`` is None, ``tensors`` is returned unchanged. Otherwise
    the ``"is_tensor"`` flags decide, position by position, whether the next item is
    taken from ``tensors`` or from the ``"objects"`` list.
    """
    if packed_non_tensors is None:
        return tensors
    assert isinstance(packed_non_tensors, dict), type(packed_non_tensors)

    flags = packed_non_tensors["is_tensor"]
    others = packed_non_tensors["objects"]
    assert len(tensors) + len(others) == len(flags), (
        f"len(tensors) {len(tensors)} len(objects) {len(others)} " f"len(is_tensor_list) {len(flags)}"
    )

    # One read cursor per source, keyed by the truth value of the flag.
    sources: Dict[bool, Any] = {True: tensors, False: others}
    cursors: Dict[bool, int] = {True: 0, False: 0}
    restored: List[Any] = []
    for flag in flags:
        which = bool(flag)
        restored.append(sources[which][cursors[which]])
        cursors[which] += 1
    return tuple(restored)
|
|