File size: 5,247 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from typing import List, Tuple, Sequence
from .einops import Tensor, Reduction, EinopsError, _prepare_transformation_recipe, _apply_recipe_array_api
from .packing import analyze_pattern, prod


def reduce(tensor: Tensor, pattern: str, reduction: Reduction, **axes_lengths: int) -> Tensor:
    if isinstance(tensor, list):
        if len(tensor) == 0:
            raise TypeError("Einops can't be applied to an empty list")
        xp = tensor[0].__array_namespace__()
        tensor = xp.stack(tensor)
    else:
        xp = tensor.__array_namespace__()
    try:
        hashable_axes_lengths = tuple(axes_lengths.items())
        recipe = _prepare_transformation_recipe(pattern, reduction, axes_names=tuple(axes_lengths), ndim=tensor.ndim)
        return _apply_recipe_array_api(
            xp,
            recipe=recipe,
            tensor=tensor,
            reduction_type=reduction,
            axes_lengths=hashable_axes_lengths,
        )
    except EinopsError as e:
        message = ' Error while processing {}-reduction pattern "{}".'.format(reduction, pattern)
        if not isinstance(tensor, list):
            message += "\n Input tensor shape: {}. ".format(tensor.shape)
        else:
            message += "\n Input is list. "
        message += "Additional info: {}.".format(axes_lengths)
        raise EinopsError(message + "\n {}".format(e))


def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor:
    return reduce(tensor, pattern, reduction="repeat", **axes_lengths)


def rearrange(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor:
    return reduce(tensor, pattern, reduction="rearrange", **axes_lengths)


def asnumpy(tensor: Tensor):
    import numpy as np

    return np.from_dlpack(tensor)


Shape = Tuple


def pack(tensors: Sequence[Tensor], pattern: str) -> Tuple[Tensor, List[Shape]]:
    n_axes_before, n_axes_after, min_axes = analyze_pattern(pattern, "pack")
    xp = tensors[0].__array_namespace__()

    reshaped_tensors: List[Tensor] = []
    packed_shapes: List[Shape] = []
    for i, tensor in enumerate(tensors):
        shape = tensor.shape
        if len(shape) < min_axes:
            raise EinopsError(
                f"packed tensor #{i} (enumeration starts with 0) has shape {shape}, "
                f"while pattern {pattern} assumes at least {min_axes} axes"
            )
        axis_after_packed_axes = len(shape) - n_axes_after
        packed_shapes.append(shape[n_axes_before:axis_after_packed_axes])
        reshaped_tensors.append(xp.reshape(tensor, (*shape[:n_axes_before], -1, *shape[axis_after_packed_axes:])))

    return xp.concat(reshaped_tensors, axis=n_axes_before), packed_shapes


def unpack(tensor: Tensor, packed_shapes: List[Shape], pattern: str) -> List[Tensor]:
    xp = tensor.__array_namespace__()
    n_axes_before, n_axes_after, min_axes = analyze_pattern(pattern, opname="unpack")

    # backend = get_backend(tensor)
    input_shape = tensor.shape
    if len(input_shape) != n_axes_before + 1 + n_axes_after:
        raise EinopsError(f"unpack(..., {pattern}) received input of wrong dim with shape {input_shape}")

    unpacked_axis: int = n_axes_before

    lengths_of_composed_axes: List[int] = [-1 if -1 in p_shape else prod(p_shape) for p_shape in packed_shapes]

    n_unknown_composed_axes = sum(x == -1 for x in lengths_of_composed_axes)
    if n_unknown_composed_axes > 1:
        raise EinopsError(
            f"unpack(..., {pattern}) received more than one -1 in {packed_shapes} and can't infer dimensions"
        )

    # following manipulations allow to skip some shape verifications
    # and leave it to backends

    # [[], [2, 3], [4], [-1, 5], [6]] < examples of packed_axis
    # split positions when computed should be
    # [0,   1,      7,   11,      N-6 , N ], where N = length of axis
    split_positions = [0] * len(packed_shapes) + [input_shape[unpacked_axis]]
    if n_unknown_composed_axes == 0:
        for i, x in enumerate(lengths_of_composed_axes[:-1]):
            split_positions[i + 1] = split_positions[i] + x
    else:
        unknown_composed_axis: int = lengths_of_composed_axes.index(-1)
        for i in range(unknown_composed_axis):
            split_positions[i + 1] = split_positions[i] + lengths_of_composed_axes[i]
        for j in range(unknown_composed_axis + 1, len(lengths_of_composed_axes))[::-1]:
            split_positions[j] = split_positions[j + 1] - lengths_of_composed_axes[j]

    shape_start = input_shape[:unpacked_axis]
    shape_end = input_shape[unpacked_axis + 1 :]
    slice_filler = (slice(None, None),) * unpacked_axis
    try:
        return [
            xp.reshape(
                # shortest way slice arbitrary axis
                tensor[(*slice_filler, slice(split_positions[i], split_positions[i + 1]), ...)],
                (*shape_start, *element_shape, *shape_end),
            )
            for i, element_shape in enumerate(packed_shapes)
        ]
    except Exception:
        # this hits if there is an error during reshapes, which means passed shapes were incorrect
        raise RuntimeError(
            f'Error during unpack(..., "{pattern}"): could not split axis of size {split_positions[-1]}'
            f" into requested {packed_shapes}"
        )