import string
import warnings
from typing import Any, Dict, List, Optional

from einops import EinopsError
from einops.parsing import ParsedExpression, _ellipsis

from ..einops import _product


def _report_axes(axes: set, report_message: str):
    if len(axes) > 0:
        raise EinopsError(report_message.format(axes))


class _EinmixMixin:
    def __init__(self, pattern: str, weight_shape: str, bias_shape: Optional[str] = None, **axes_lengths: Any):
        """
        EinMix - Einstein summation with automated tensor management and axis packing/unpacking.

        EinMix is a combination of einops and MLP, see tutorial:
        https://github.com/arogozhnikov/einops/blob/main/docs/3-einmix-layer.ipynb

        Imagine taking einsum with two arguments: one is the input, and the other is a tensor with weights
        >>> einsum('time batch channel_in, channel_in channel_out -> time batch channel_out', input, weight)

        This layer manages the weights for you, and the syntax highlights the special role of the weight matrix
        >>> EinMix('time batch channel_in -> time batch channel_out', weight_shape='channel_in channel_out')
        Otherwise it is the same einsum under the hood, plus an einops-rearrange.

        A simple linear layer with a bias term (your framework already provides one like this)
        >>> EinMix('t b cin -> t b cout', weight_shape='cin cout', bias_shape='cout', cin=10, cout=20)
        You are not restricted to mixing only the last axis. Let's mix along height
        >>> EinMix('h w c -> hout w c', weight_shape='h hout', bias_shape='hout', h=32, hout=32)
        Example of channel-wise multiplication (like the one used in normalizations)
        >>> EinMix('t b c -> t b c', weight_shape='c', c=128)
        Multi-head linear layer (each head is its own linear layer):
        >>> EinMix('t b (head cin) -> t b (head cout)', weight_shape='head cin cout', ...)

        ... and yes, you need to specify the lengths of all axes of the weight/bias shapes as parameters.

        Use cases:
        - when the channel dimension is not last, use EinMix instead of transposition
        - patch/segment embeddings
        - when only within-group connections are needed, to reduce the number of weights and computations
        - next-gen MLPs (follow the tutorial link above to learn more!)
        - in general, any time you want to combine a linear layer and einops.rearrange

        Uniform He initialization is applied to the weight tensor
        and accounts for the number of elements mixed (fan-in).

        Parameters
        :param pattern: transformation pattern; the left side lists the dimensions of the input, the right side those of the output
        :param weight_shape: axes of the weight tensor. A tensor of this shape is created, stored, and optimized by the layer.
        :param bias_shape: axes of the bias added to the output. A tensor of this shape is created and stored. If `None` (the default), no bias is added.
        :param axes_lengths: lengths of the axes in the weight (and bias) tensors
        """
        super().__init__()
        self.pattern = pattern
        self.weight_shape = weight_shape
        self.bias_shape = bias_shape
        self.axes_lengths = axes_lengths
        self.initialize_einmix(
            pattern=pattern, weight_shape=weight_shape, bias_shape=bias_shape, axes_lengths=axes_lengths
        )

    def initialize_einmix(self, pattern: str, weight_shape: str, bias_shape: Optional[str], axes_lengths: dict):
        left_pattern, right_pattern = pattern.split("->")
        left = ParsedExpression(left_pattern)
        right = ParsedExpression(right_pattern)
        weight = ParsedExpression(weight_shape)
        _report_axes(
            set.difference(right.identifiers, {*left.identifiers, *weight.identifiers}),
            "Unrecognized identifiers on the right side of EinMix {}",
        )
        if weight.has_ellipsis:
            raise EinopsError("Ellipsis is not supported in weight, as its shape should be fully specified")
        if left.has_ellipsis or right.has_ellipsis:
            if not (left.has_ellipsis and right.has_ellipsis):
                raise EinopsError(f"Ellipsis in EinMix should be on both sides, {pattern}")
            if left.has_ellipsis_parenthesized:
                raise EinopsError(f"Ellipsis on left side can't be in parenthesis, got {pattern}")
        if any(x.has_non_unitary_anonymous_axes for x in [left, right, weight]):
            raise EinopsError("Anonymous axes (numbers) are not allowed in EinMix")
        if "(" in weight_shape or ")" in weight_shape:
            raise EinopsError(f"Parenthesis is not allowed in weight shape: {weight_shape}")

        pre_reshape_pattern = None
        pre_reshape_lengths = None
        post_reshape_pattern = None
        if any(len(group) != 1 for group in left.composition):
            names: List[str] = []
            for group in left.composition:
                names += group
            names = [name if name != _ellipsis else "..." for name in names]
            composition = " ".join(names)
            pre_reshape_pattern = f"{left_pattern}-> {composition}"
            pre_reshape_lengths = {name: length for name, length in axes_lengths.items() if name in names}

        if any(len(group) != 1 for group in right.composition) or right.has_ellipsis_parenthesized:
            names = []
            for group in right.composition:
                names += group
            names = [name if name != _ellipsis else "..." for name in names]
            composition = " ".join(names)
            post_reshape_pattern = f"{composition} ->{right_pattern}"

        self._create_rearrange_layers(pre_reshape_pattern, pre_reshape_lengths, post_reshape_pattern, {})

        for axis in weight.identifiers:
            if axis not in axes_lengths:
                raise EinopsError("Dimension {} of weight should be specified".format(axis))
        _report_axes(
            set.difference(set(axes_lengths), {*left.identifiers, *weight.identifiers}),
            "Axes {} are not used in pattern",
        )
        _report_axes(
            set.difference(weight.identifiers, {*left.identifiers, *right.identifiers}), "Weight axes {} are redundant"
        )
        if len(weight.identifiers) == 0:
            warnings.warn("EinMix: weight has no dimensions (means multiplication by a number)")

        _weight_shape = [axes_lengths[axis] for (axis,) in weight.composition]
        # single output element is a combination of fan_in input elements
        _fan_in = _product([axes_lengths[axis] for (axis,) in weight.composition if axis not in right.identifiers])
        if bias_shape is not None:
            # maybe I should put ellipsis in the beginning for simplicity?
            if not isinstance(bias_shape, str):
                raise EinopsError("bias shape should be string specifying which axes bias depends on")
            bias = ParsedExpression(bias_shape)
            _report_axes(
                set.difference(bias.identifiers, right.identifiers),
                "Bias axes {} not present in output",
            )
            _report_axes(
                set.difference(bias.identifiers, set(axes_lengths)),
                "Sizes not provided for bias axes {}",
            )

            _bias_shape = []
            used_non_trivial_size = False
            for axes in right.composition:
                if axes == _ellipsis:
                    if used_non_trivial_size:
                        raise EinopsError("all bias dimensions should go after ellipsis in the output")
                else:
                    # handles ellipsis correctly
                    for axis in axes:
                        if axis == _ellipsis:
                            if used_non_trivial_size:
                                raise EinopsError("all bias dimensions should go after ellipsis in the output")
                        elif axis in bias.identifiers:
                            _bias_shape.append(axes_lengths[axis])
                            used_non_trivial_size = True
                        else:
                            _bias_shape.append(1)
        else:
            _bias_shape = None

        weight_bound = (3 / _fan_in) ** 0.5
        bias_bound = (1 / _fan_in) ** 0.5
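        # Illustration (not executed here): for EinMix('t b cin -> t b cout', weight_shape='cin cout',
        # cin=10, cout=20), only 'cin' is absent from the output, so _fan_in = 10 and the weight is
        # sampled from U(-sqrt(3/10), sqrt(3/10)) by the framework-specific _create_parameters.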
        self._create_parameters(_weight_shape, weight_bound, _bias_shape, bias_bound)

        # Rewrite the einsum expression with single-letter Latin identifiers
        # so that the expression is understood by any framework.
        mapped_identifiers = {*left.identifiers, *right.identifiers, *weight.identifiers}
        if _ellipsis in mapped_identifiers:
            mapped_identifiers.remove(_ellipsis)
        mapped_identifiers = list(sorted(mapped_identifiers))
        mapping2letters = {k: letter for letter, k in zip(string.ascii_lowercase, mapped_identifiers)}
        mapping2letters[_ellipsis] = "..."  # preserve ellipsis

        def write_flat_remapped(axes: ParsedExpression):
            result = []
            for composed_axis in axes.composition:
                if isinstance(composed_axis, list):
                    result.extend([mapping2letters[axis] for axis in composed_axis])
                else:
                    assert composed_axis == _ellipsis
                    result.append("...")
            return "".join(result)

        self.einsum_pattern: str = "{},{}->{}".format(
            write_flat_remapped(left),
            write_flat_remapped(weight),
            write_flat_remapped(right),
        )
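        # Illustration (not executed here): for EinMix('t b cin -> t b cout', weight_shape='cin cout'),
        # identifiers are remapped as {'b': 'a', 'cin': 'b', 'cout': 'c', 't': 'd'},
        # giving einsum_pattern == 'dab,bc->dac'.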

    def _create_rearrange_layers(
        self,
        pre_reshape_pattern: Optional[str],
        pre_reshape_lengths: Optional[Dict],
        post_reshape_pattern: Optional[str],
        post_reshape_lengths: Optional[Dict],
    ):
        raise NotImplementedError("Should be defined in framework implementations")

    def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
        """Shape and implementations"""
        raise NotImplementedError("Should be defined in framework implementations")

    def __repr__(self):
        params = repr(self.pattern)
        params += f", '{self.weight_shape}'"
        if self.bias_shape is not None:
            params += f", '{self.bias_shape}'"
        for axis, length in self.axes_lengths.items():
            params += ", {}={}".format(axis, length)
        return "{}({})".format(self.__class__.__name__, params)


class _EinmixDebugger(_EinmixMixin):
    """Used only to test mixin"""

    def _create_rearrange_layers(
        self,
        pre_reshape_pattern: Optional[str],
        pre_reshape_lengths: Optional[Dict],
        post_reshape_pattern: Optional[str],
        post_reshape_lengths: Optional[Dict],
    ):
        self.pre_reshape_pattern = pre_reshape_pattern
        self.pre_reshape_lengths = pre_reshape_lengths
        self.post_reshape_pattern = post_reshape_pattern
        self.post_reshape_lengths = post_reshape_lengths

    def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
        self.saved_weight_shape = weight_shape
        self.saved_bias_shape = bias_shape
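

# A minimal sketch (not part of this module) of how a framework implementation is expected
# to fill in the two hooks above. The class below is hypothetical and assumes PyTorch;
# the maintained implementation lives in einops.layers.torch.
#
#   import torch
#   from torch import nn
#   from einops.layers.torch import Rearrange
#
#   class EinMixSketch(_EinmixMixin, nn.Module):
#       def _create_rearrange_layers(self, pre_reshape_pattern, pre_reshape_lengths,
#                                    post_reshape_pattern, post_reshape_lengths):
#           # optional Rearrange layers applied before/after the einsum
#           self.pre_rearrange = (
#               Rearrange(pre_reshape_pattern, **pre_reshape_lengths)
#               if pre_reshape_pattern is not None else None
#           )
#           self.post_rearrange = (
#               Rearrange(post_reshape_pattern, **post_reshape_lengths)
#               if post_reshape_pattern is not None else None
#           )
#
#       def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
#           # parameters are drawn uniformly within the bounds computed from fan-in
#           self.weight = nn.Parameter(torch.empty(weight_shape).uniform_(-weight_bound, weight_bound))
#           self.bias = (
#               nn.Parameter(torch.empty(bias_shape).uniform_(-bias_bound, bias_bound))
#               if bias_shape is not None else None
#           )
#
#       def forward(self, x):
#           if self.pre_rearrange is not None:
#               x = self.pre_rearrange(x)
#           out = torch.einsum(self.einsum_pattern, x, self.weight)
#           if self.bias is not None:
#               out = out + self.bias
#           if self.post_rearrange is not None:
#               out = self.post_rearrange(out)
#           return out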