from functools import lru_cache
from typing import List, Union, TypeVar, Tuple, Sequence
from einops import EinopsError
from einops._backends import get_backend
from einops.parsing import ParsedExpression
Tensor = TypeVar("Tensor")
Shape = Union[Tuple[int, ...], List[int]]
@lru_cache(maxsize=128)
def analyze_pattern(pattern: str, opname: str) -> Tuple[int, int, int]:
    # validate the axis names: no duplicates, a '*' axis is required, other names must be valid identifiers
axes = pattern.split()
axes_set = set(axes)
if len(axes) != len(axes_set):
raise EinopsError(f'Duplicates in axes names in {opname}(..., "{pattern}")')
if "*" not in axes_set:
raise EinopsError(f'No *-axis in {opname}(..., "{pattern}")')
for axis in axes:
if axis != "*":
            is_valid, reason = ParsedExpression.check_axis_name_return_reason(axis)
            if not is_valid:
                raise EinopsError(f'Invalid axis name {axis} in {opname}(..., "{pattern}"): {reason}')
n_axes_before = axes.index("*")
n_axes_after = len(axes) - n_axes_before - 1
min_axes = n_axes_before + n_axes_after
return n_axes_before, n_axes_after, min_axes
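# Worked example (illustrative, not part of the module API): analyze_pattern("i j * k", "pack")
# returns (2, 1, 3): two named axes before '*' and one after it, so every input
# must have at least 3 axes.
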
def pack(tensors: Sequence[Tensor], pattern: str) -> Tuple[Tensor, List[Shape]]:
"""
Packs several tensors into one.
    See the einops tutorial for an introduction to packing (and how it replaces stack and concatenation).
Parameters:
tensors: tensors to be packed, can be of different dimensionality
pattern: pattern that is shared for all inputs and output, e.g. "i j * k" or "batch seq *"
Returns:
(packed_tensor, packed_shapes aka PS)
Example:
```python
>>> from numpy import zeros as Z
>>> inputs = [Z([2, 3, 5]), Z([2, 3, 7, 5]), Z([2, 3, 7, 9, 5])]
>>> packed, ps = pack(inputs, 'i j * k')
>>> packed.shape, ps
((2, 3, 71, 5), [(), (7,), (7, 9)])
```
    In this example, axes were matched to i=2, j=3, k=5 based on their positions (first, second, and last).
    All other axes were 'packed' and concatenated.
    PS (packed shapes) contains the shapes that were matched to '*' in every input.
    The resulting tensor has as many elements as all inputs combined.
    Packing can be reversed with unpack, which additionally needs PS (packed shapes) to reconstruct the order.
```python
>>> inputs_unpacked = unpack(packed, ps, 'i j * k')
>>> [x.shape for x in inputs_unpacked]
[(2, 3, 5), (2, 3, 7, 5), (2, 3, 7, 9, 5)]
```
    Read the tutorial for an introduction and application scenarios.
"""
n_axes_before, n_axes_after, min_axes = analyze_pattern(pattern, "pack")
    # packing zero tensors is illegal; tensors[0] below will raise for an empty sequence
backend = get_backend(tensors[0])
reshaped_tensors: List[Tensor] = []
packed_shapes: List[Shape] = []
for i, tensor in enumerate(tensors):
shape = backend.shape(tensor)
if len(shape) < min_axes:
raise EinopsError(
f"packed tensor #{i} (enumeration starts with 0) has shape {shape}, "
f"while pattern {pattern} assumes at least {min_axes} axes"
)
axis_after_packed_axes = len(shape) - n_axes_after
packed_shapes.append(shape[n_axes_before:axis_after_packed_axes])
reshaped_tensors.append(backend.reshape(tensor, (*shape[:n_axes_before], -1, *shape[axis_after_packed_axes:])))
return backend.concat(reshaped_tensors, axis=n_axes_before), packed_shapes
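# Mechanics sketch for pack with pattern "i j * k" (mirrors the docstring example,
# assuming the numpy backend): a (2, 3, 7, 9, 5) input is reshaped to (2, 3, 63, 5),
# i.e. its '*' axes are flattened into one, and all inputs are then concatenated
# along axis 2 (= n_axes_before), giving the (2, 3, 71, 5) packed tensor.
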
def prod(x: Shape) -> int:
result = 1
for i in x:
result *= i
return result
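# e.g. prod((7, 9)) == 63 and prod(()) == 1, so an empty '*' still occupies one slice
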
def unpack(tensor: Tensor, packed_shapes: List[Shape], pattern: str) -> List[Tensor]:
"""
    Unpacks a single tensor into several by splitting over the selected axis.
    See the einops tutorial for an introduction to packing (and how it replaces stack and concatenation).
Parameters:
tensor: tensor to be unpacked
    packed_shapes: packed_shapes (aka PS) is a list of shapes that take the place of '*' in each output.
        The output will contain a single tensor for every provided shape.
pattern: pattern that is shared for input and all outputs, e.g. "i j * k" or "batch seq *",
where * designates an axis to be unpacked
Returns:
list of tensors
    If the framework supports views, the results are views of the original tensor.
Example:
```python
>>> from numpy import zeros as Z
>>> inputs = [Z([2, 3, 5]), Z([2, 3, 7, 5]), Z([2, 3, 7, 9, 5])]
>>> packed, ps = pack(inputs, 'i j * k')
>>> packed.shape, ps
((2, 3, 71, 5), [(), (7,), (7, 9)])
```
    In this example, axes were matched to i=2, j=3, k=5 based on their positions (first, second, and last).
    All other axes were 'packed' and concatenated.
    PS (packed shapes) contains the shapes that were matched to '*' in every input.
    The resulting tensor has as many elements as all inputs combined.
    Packing can be reversed with unpack, which additionally needs PS (packed shapes) to reconstruct the order.
```python
>>> inputs_unpacked = unpack(packed, ps, 'i j * k')
>>> [x.shape for x in inputs_unpacked]
[(2, 3, 5), (2, 3, 7, 5), (2, 3, 7, 9, 5)]
```
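    A single -1 in packed_shapes is inferred from the length of the unpacked axis,
    so the same unpacking can also be written as (illustrative):
    ```python
    >>> [x.shape for x in unpack(packed, [(), (7,), (-1, 9)], 'i j * k')]
    [(2, 3, 5), (2, 3, 7, 5), (2, 3, 7, 9, 5)]
    ```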
    Read the tutorial for an introduction and application scenarios.
"""
n_axes_before, n_axes_after, min_axes = analyze_pattern(pattern, opname="unpack")
backend = get_backend(tensor)
input_shape = backend.shape(tensor)
if len(input_shape) != n_axes_before + 1 + n_axes_after:
raise EinopsError(f"unpack(..., {pattern}) received input of wrong dim with shape {input_shape}")
unpacked_axis: int = n_axes_before
lengths_of_composed_axes: List[int] = [-1 if -1 in p_shape else prod(p_shape) for p_shape in packed_shapes]
n_unknown_composed_axes = sum(int(x == -1) for x in lengths_of_composed_axes)
if n_unknown_composed_axes > 1:
raise EinopsError(
f"unpack(..., {pattern}) received more than one -1 in {packed_shapes} and can't infer dimensions"
)
    # the following manipulations allow skipping some shape verifications
    # and leaving them to the backend
    # [[], [2, 3], [4], [-1, 5], [6]]  <- example packed_shapes
    # the split positions computed from them should be
    # [0,  1,      7,   11,      N-6,  N], where N = length of the unpacked axis
split_positions = [0] * len(packed_shapes) + [input_shape[unpacked_axis]]
if n_unknown_composed_axes == 0:
for i, x in enumerate(lengths_of_composed_axes[:-1]):
split_positions[i + 1] = split_positions[i] + x
else:
unknown_composed_axis: int = lengths_of_composed_axes.index(-1)
for i in range(unknown_composed_axis):
split_positions[i + 1] = split_positions[i] + lengths_of_composed_axes[i]
for j in range(unknown_composed_axis + 1, len(lengths_of_composed_axes))[::-1]:
split_positions[j] = split_positions[j + 1] - lengths_of_composed_axes[j]
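    # e.g. the docstring example's packed_shapes [(), (7,), (7, 9)] give lengths
    # [1, 7, 63] and split_positions [0, 1, 8, 71]; with a single -1 the positions
    # around the unknown slot are filled in from both ends instead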
shape_start = input_shape[:unpacked_axis]
shape_end = input_shape[unpacked_axis + 1 :]
slice_filler = (slice(None, None),) * unpacked_axis
try:
return [
backend.reshape(
                # shortest way to slice an arbitrary axis
tensor[(*slice_filler, slice(split_positions[i], split_positions[i + 1]))],
(*shape_start, *element_shape, *shape_end),
)
for i, element_shape in enumerate(packed_shapes)
]
except Exception:
        # this is hit if a reshape fails, which means the passed shapes were incorrect
raise RuntimeError(
f'Error during unpack(..., "{pattern}"): could not split axis of size {split_positions[-1]}'
f" into requested {packed_shapes}"
)