# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Useful functions to deal with tensor types mixed with other Python container types."""

from collections import OrderedDict
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union, cast

import numpy as np
import torch
from torch.nn.utils.rnn import PackedSequence


def apply_to_type(
    type_fn: Callable, fn: Callable, container: Union[torch.Tensor, np.ndarray, Dict, List, Tuple, Set, NamedTuple]
) -> Any:
    """Recursively apply to all objects in different kinds of container types that match a type function."""

    def _apply(x: Union[torch.Tensor, np.ndarray, Dict, List, Tuple, Set]) -> Any:
        if type_fn(x):
            return fn(x)
        elif isinstance(x, OrderedDict):
            od = x.__class__()
            for key, value in x.items():
                od[key] = _apply(value)
            return od
        elif isinstance(x, PackedSequence):
            # A PackedSequence cannot be rebuilt field by field, so apply to its data
            # tensor for any side effects of fn and return the sequence unchanged.
            _apply(x.data)
            return x
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(value) for value in x]
        elif isinstance(x, tuple):
            f = getattr(x, "_fields", None)
            if f is None:
                return tuple(_apply(value) for value in x)
            else:
                assert isinstance(f, tuple), "This needs to be a namedtuple"
                # Convert the namedtuple to a dict, _apply() each value, then rebuild.
                x = cast(NamedTuple, x)
                _dict: Dict[str, Any] = x._asdict()
                _dict = {key: _apply(value) for key, value in _dict.items()}
                return type(x)(**_dict)  # make a copy of the namedtuple
        elif isinstance(x, set):
            return {_apply(value) for value in x}
        else:
            return x

    return _apply(container)
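

# A minimal usage sketch (my addition, not part of the original API): a hypothetical
# helper showing apply_to_type doubling every ndarray nested inside plain containers.
def _demo_apply_to_type() -> None:
    nested = {"a": np.ones(2), "b": [np.zeros(3), ("tag", np.arange(4))]}
    doubled = apply_to_type(lambda x: isinstance(x, np.ndarray), lambda x: x * 2, nested)
    assert (doubled["b"][1][1] == np.arange(4) * 2).all()
    assert doubled["b"][1][0] == "tag"  # non-matching leaves pass through unchanged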


def apply_to_tensors(fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
    """Recursively apply to all tensors in different kinds of container types."""
    return apply_to_type(torch.is_tensor, fn, container)
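

# A minimal usage sketch (my addition, not part of the original API): a hypothetical
# helper showing apply_to_tensors detaching every tensor in a nested structure, e.g.
# when recording outputs without holding on to the autograd graph.
def _demo_apply_to_tensors() -> None:
    x = torch.ones(2, requires_grad=True)
    out = apply_to_tensors(lambda t: t.detach(), {"loss": x * 2, "meta": [1, "tag"]})
    assert not out["loss"].requires_grad
    assert out["meta"] == [1, "tag"]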


def to_np(tensor_or_container: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
    """Convert a tensor or a container to numpy."""
    return apply_to_type(torch.is_tensor, lambda x: x.cpu().numpy(), tensor_or_container)


def from_np(ndarray_or_container: Union[np.ndarray, Dict, List, Tuple, Set]) -> Any:
    """Convert an ndarray or a container to tensor."""
    return apply_to_type(lambda x: isinstance(x, np.ndarray), torch.from_numpy, ndarray_or_container)
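

# A minimal round-trip sketch (my addition, not part of the original API): to_np and
# from_np should invert each other on containers, leaving non-tensor leaves untouched.
def _demo_np_round_trip() -> None:
    original = {"w": torch.arange(3), "nested": (torch.zeros(2), "label")}
    as_np = to_np(original)
    assert isinstance(as_np["w"], np.ndarray)
    back = from_np(as_np)
    assert torch.equal(back["w"], original["w"])
    assert back["nested"][1] == "label"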


def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, ...]]:
    """
    Turn an argument list into separate key list and value list (unpack_kwargs does the opposite).

    Usage::

        kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4)
        assert kwarg_keys == ("a", "b")
        assert flat_args == (1, 2, 3, 4)
        args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
        assert args == (1, 2)
        assert kwargs == {"a": 3, "b": 4}
    """
    kwarg_keys: List[str] = []
    flat_args: List[Any] = list(args)
    for k, v in kwargs.items():
        kwarg_keys.append(k)
        flat_args.append(v)
    return tuple(kwarg_keys), tuple(flat_args)


def unpack_kwargs(kwarg_keys: Tuple[str, ...], flat_args: Tuple[Any, ...]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """See pack_kwargs."""
    assert len(kwarg_keys) <= len(flat_args), f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}"
    if len(kwarg_keys) == 0:
        return flat_args, {}
    args = flat_args[: -len(kwarg_keys)]
    kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :]))
    return args, kwargs
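

# A minimal sketch of the motivating pattern (my illustration, not from the original
# file): some call boundaries only pass flat positional arguments, so kwargs are
# packed into the flat list on the way in and restored on the way out.
def _demo_pack_unpack(fn: Callable, *args: Any, **kwargs: Any) -> Any:
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    # ... flat_args could now cross a positional-only boundary unchanged ...
    inner_args, inner_kwargs = unpack_kwargs(kwarg_keys, flat_args)
    return fn(*inner_args, **inner_kwargs)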


def split_non_tensors(
    mixed: Union[torch.Tensor, Tuple[Any, ...]]
) -> Tuple[Tuple[torch.Tensor, ...], Optional[Dict[str, List[Any]]]]:
    """
    Split a tuple into a tuple of tensors and the rest, with information
    for later reconstruction.

    When called with a tensor ``x``, returns ``((x,), None)``.

    Usage::

        x = torch.Tensor([1])
        y = torch.Tensor([2])
        tensors, packed_non_tensors = split_non_tensors((x, y, None, 3))
        assert tensors == (x, y)
        assert packed_non_tensors == {
            "is_tensor": [True, True, False, False],
            "objects": [None, 3],
        }
        recon = unpack_non_tensors(tensors, packed_non_tensors)
        assert recon == (x, y, None, 3)
    """
    if isinstance(mixed, torch.Tensor):
        return (mixed,), None
    tensors: List[torch.Tensor] = []
    packed_non_tensors: Dict[str, List[Any]] = {"is_tensor": [], "objects": []}
    for o in mixed:
        if isinstance(o, torch.Tensor):
            packed_non_tensors["is_tensor"].append(True)
            tensors.append(o)
        else:
            packed_non_tensors["is_tensor"].append(False)
            packed_non_tensors["objects"].append(o)
    return tuple(tensors), packed_non_tensors


def unpack_non_tensors(
    tensors: Tuple[torch.Tensor, ...], packed_non_tensors: Optional[Dict[str, List[Any]]]
) -> Tuple[Any, ...]:
    """See split_non_tensors."""
    if packed_non_tensors is None:
        return tensors
    assert isinstance(packed_non_tensors, dict), type(packed_non_tensors)
    mixed: List[Any] = []
    is_tensor_list = packed_non_tensors["is_tensor"]
    objects = packed_non_tensors["objects"]
    assert len(tensors) + len(objects) == len(is_tensor_list), (
        f"len(tensors) {len(tensors)} len(objects) {len(objects)} len(is_tensor_list) {len(is_tensor_list)}"
    )
    obj_i = tnsr_i = 0
    for is_tensor in is_tensor_list:
        if is_tensor:
            mixed.append(tensors[tnsr_i])
            tnsr_i += 1
        else:
            mixed.append(objects[obj_i])
            obj_i += 1
    return tuple(mixed)
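

# A minimal round-trip sketch (my addition, not part of the original API): mixed
# outputs are split so that only tensors cross a tensors-only boundary, then the
# original ordering is reassembled from the packed bookkeeping.
def _demo_split_round_trip() -> None:
    mixed = (torch.ones(1), None, 3, torch.zeros(2))
    tensors, packed = split_non_tensors(mixed)
    assert len(tensors) == 2
    recon = unpack_non_tensors(tensors, packed)
    assert recon[1] is None and recon[2] == 3
    assert torch.equal(recon[3], mixed[3])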