|
from typing import Any, List, Optional, Dict |
|
|
|
from einops import EinopsError |
|
from einops.parsing import ParsedExpression, _ellipsis |
|
import warnings |
|
import string |
|
from ..einops import _product |
|
|
|
|
|
def _report_axes(axes: set, report_message: str):
    """Raise an EinopsError listing the offending axes if the set is non-empty."""

    if len(axes) > 0:

        raise EinopsError(report_message.format(axes))
|
|
|
|
|
class _EinmixMixin: |
|
def __init__(self, pattern: str, weight_shape: str, bias_shape: Optional[str] = None, **axes_lengths: Any): |
|
""" |
|
EinMix - Einstein summation with automated tensor management and axis packing/unpacking. |
|
|
|
        EinMix is a combination of einops and MLP; see the tutorial:
|
https://github.com/arogozhnikov/einops/blob/main/docs/3-einmix-layer.ipynb |
|
|
|
        Imagine taking einsum with two arguments: one is the input, and the other is a tensor of weights
|
>>> einsum('time batch channel_in, channel_in channel_out -> time batch channel_out', input, weight) |
|
|
|
        This layer manages the weights for you, and the syntax highlights the special role of the weight matrix
|
>>> EinMix('time batch channel_in -> time batch channel_out', weight_shape='channel_in channel_out') |
|
But otherwise it is the same einsum under the hood. Plus einops-rearrange. |
|
|
|
        Simple linear layer with a bias term (your framework already has one like this):
|
>>> EinMix('t b cin -> t b cout', weight_shape='cin cout', bias_shape='cout', cin=10, cout=20) |
|
        You are not restricted to mixing the last axis. Let's mix along height:
|
        >>> EinMix('h w c -> hout w c', weight_shape='h hout', bias_shape='hout', h=32, hout=32)
|
        Example of channel-wise multiplication (like the one used in normalizations):
|
>>> EinMix('t b c -> t b c', weight_shape='c', c=128) |
|
        Multi-head linear layer (each head has its own linear layer):
|
>>> EinMix('t b (head cin) -> t b (head cout)', weight_shape='head cin cout', ...) |
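
        For instance, with illustrative axis lengths:

        >>> EinMix('t b (head cin) -> t b (head cout)', weight_shape='head cin cout', head=8, cin=16, cout=16)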
|
|
|
        ... and yes, you need to specify the length of every axis used in weight_shape/bias_shape as a keyword argument.
|
|
|
Use cases: |
|
        - when the channel dimension is not last, use EinMix rather than transposition
|
- patch/segment embeddings |
|
        - when only within-group connections are needed, to reduce the number of weights and computations
|
- next-gen MLPs (follow tutorial link above to learn more!) |
|
- in general, any time you want to combine linear layer and einops.rearrange |
|
|
|
        Uniform He initialization is applied to the weight tensor.
|
This accounts for the number of elements mixed and produced. |
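
        Concretely, the weight is initialized uniformly in [-bound, bound] with bound = sqrt(3 / fan_in),
        where fan_in is the product of the lengths of the weight axes that are absent from the output.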
|
|
|
Parameters |
|
        :param pattern: transformation pattern; the left side lists input dimensions, the right side output dimensions

        :param weight_shape: axes of weight. A tensor of this shape is created, stored, and optimized by the layer

        :param bias_shape: axes of bias added to the output. A tensor of this shape is created and stored. If `None` (the default), no bias is added

        :param axes_lengths: lengths of all axes in weight_shape and bias_shape, plus any input axes whose length is needed to ungroup the input
|
""" |
|
super().__init__() |
|
self.pattern = pattern |
|
self.weight_shape = weight_shape |
|
self.bias_shape = bias_shape |
|
self.axes_lengths = axes_lengths |
|
self.initialize_einmix( |
|
pattern=pattern, weight_shape=weight_shape, bias_shape=bias_shape, axes_lengths=axes_lengths |
|
) |
|
|
|
def initialize_einmix(self, pattern: str, weight_shape: str, bias_shape: Optional[str], axes_lengths: dict): |
|
left_pattern, right_pattern = pattern.split("->") |
|
left = ParsedExpression(left_pattern) |
|
right = ParsedExpression(right_pattern) |
|
weight = ParsedExpression(weight_shape) |
|
_report_axes( |
|
set.difference(right.identifiers, {*left.identifiers, *weight.identifiers}), |
|
"Unrecognized identifiers on the right side of EinMix {}", |
|
) |
|
if weight.has_ellipsis: |
|
raise EinopsError("Ellipsis is not supported in weight, as its shape should be fully specified") |
|
if left.has_ellipsis or right.has_ellipsis: |
|
if not (left.has_ellipsis and right.has_ellipsis): |
|
raise EinopsError(f"Ellipsis in EinMix should be on both sides, {pattern}") |
|
if left.has_ellipsis_parenthesized: |
|
raise EinopsError(f"Ellipsis on left side can't be in parenthesis, got {pattern}") |
|
if any(x.has_non_unitary_anonymous_axes for x in [left, right, weight]): |
|
raise EinopsError("Anonymous axes (numbers) are not allowed in EinMix") |
|
if "(" in weight_shape or ")" in weight_shape: |
|
raise EinopsError(f"Parenthesis is not allowed in weight shape: {weight_shape}") |
|
|
|
pre_reshape_pattern = None |
|
pre_reshape_lengths = None |
|
post_reshape_pattern = None |
|
if any(len(group) != 1 for group in left.composition): |
|
names: List[str] = [] |
|
for group in left.composition: |
|
names += group |
|
names = [name if name != _ellipsis else "..." for name in names] |
|
composition = " ".join(names) |
|
pre_reshape_pattern = f"{left_pattern}-> {composition}" |
|
pre_reshape_lengths = {name: length for name, length in axes_lengths.items() if name in names} |
|
|
|
        # regroup after the einsum if the output has grouped axes; a parenthesized ellipsis also needs regrouping
        if any(len(group) != 1 for group in right.composition) or right.has_ellipsis_parenthesized:
|
names = [] |
|
for group in right.composition: |
|
names += group |
|
names = [name if name != _ellipsis else "..." for name in names] |
|
composition = " ".join(names) |
|
post_reshape_pattern = f"{composition} ->{right_pattern}" |
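            # e.g. for 't b (h d) -> t b (h d2)' this yields post_reshape_pattern 't b h d2 -> t b (h d2)'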
|
|
|
self._create_rearrange_layers(pre_reshape_pattern, pre_reshape_lengths, post_reshape_pattern, {}) |
|
|
|
for axis in weight.identifiers: |
|
if axis not in axes_lengths: |
|
raise EinopsError("Dimension {} of weight should be specified".format(axis)) |
|
_report_axes( |
|
set.difference(set(axes_lengths), {*left.identifiers, *weight.identifiers}), |
|
"Axes {} are not used in pattern", |
|
) |
|
_report_axes( |
|
set.difference(weight.identifiers, {*left.identifiers, *right.identifiers}), "Weight axes {} are redundant" |
|
) |
|
if len(weight.identifiers) == 0: |
|
warnings.warn("EinMix: weight has no dimensions (means multiplication by a number)") |
|
|
|
        # every weight group is a single axis (parentheses are rejected above), hence the (axis,) unpacking
        _weight_shape = [axes_lengths[axis] for (axis,) in weight.composition]
|
|
|
        # fan_in = number of input elements mixed into each output element:
        # the product of the weight axes that do not appear in the output
        _fan_in = _product([axes_lengths[axis] for (axis,) in weight.composition if axis not in right.identifiers])
|
if bias_shape is not None: |
|
|
|
if not isinstance(bias_shape, str): |
|
raise EinopsError("bias shape should be string specifying which axes bias depends on") |
|
bias = ParsedExpression(bias_shape) |
|
_report_axes( |
|
set.difference(bias.identifiers, right.identifiers), |
|
"Bias axes {} not present in output", |
|
) |
|
_report_axes( |
|
set.difference(bias.identifiers, set(axes_lengths)), |
|
"Sizes not provided for bias axes {}", |
|
) |
|
|
|
            # build a bias shape broadcastable against the output:
            # axes the bias depends on keep their length, the remaining trailing axes get size 1
            _bias_shape = []

            used_non_trivial_size = False
|
for axes in right.composition: |
|
if axes == _ellipsis: |
|
if used_non_trivial_size: |
|
raise EinopsError("all bias dimensions should go after ellipsis in the output") |
|
else: |
|
|
|
for axis in axes: |
|
if axis == _ellipsis: |
|
if used_non_trivial_size: |
|
raise EinopsError("all bias dimensions should go after ellipsis in the output") |
|
elif axis in bias.identifiers: |
|
_bias_shape.append(axes_lengths[axis]) |
|
used_non_trivial_size = True |
|
else: |
|
_bias_shape.append(1) |
|
else: |
|
_bias_shape = None |
|
|
|
        # uniform He initialization: U(-b, b) with b = sqrt(3 / fan_in) has variance 1 / fan_in
        weight_bound = (3 / _fan_in) ** 0.5

        bias_bound = (1 / _fan_in) ** 0.5
|
self._create_parameters(_weight_shape, weight_bound, _bias_shape, bias_bound) |
|
|
|
|
|
|
|
        # plain einsum needs single-letter axis names: map every identifier to a latin letter
        mapped_identifiers = {*left.identifiers, *right.identifiers, *weight.identifiers}
|
if _ellipsis in mapped_identifiers: |
|
mapped_identifiers.remove(_ellipsis) |
|
        mapped_identifiers = sorted(mapped_identifiers)
|
mapping2letters = {k: letter for letter, k in zip(string.ascii_lowercase, mapped_identifiers)} |
|
mapping2letters[_ellipsis] = "..." |
|
|
|
        def write_flat_remapped(axes: ParsedExpression):
            """Serialize a parsed expression into flat single-letter einsum notation."""

            result = []
|
for composed_axis in axes.composition: |
|
if isinstance(composed_axis, list): |
|
result.extend([mapping2letters[axis] for axis in composed_axis]) |
|
else: |
|
assert composed_axis == _ellipsis |
|
result.append("...") |
|
return "".join(result) |
|
|
|
self.einsum_pattern: str = "{},{}->{}".format( |
|
write_flat_remapped(left), |
|
write_flat_remapped(weight), |
|
write_flat_remapped(right), |
|
) |
|
|
|
def _create_rearrange_layers( |
|
self, |
|
pre_reshape_pattern: Optional[str], |
|
pre_reshape_lengths: Optional[Dict], |
|
post_reshape_pattern: Optional[str], |
|
post_reshape_lengths: Optional[Dict], |
|
): |
|
raise NotImplementedError("Should be defined in framework implementations") |
|
|
|
def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound): |
|
"""Shape and implementations""" |
|
raise NotImplementedError("Should be defined in framework implementations") |
|
|
|
def __repr__(self): |
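        # reconstructs the constructor call, e.g. EinMix('t b cin -> t b cout', 'cin cout', 'cout', cin=10, cout=20)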
|
params = repr(self.pattern) |
|
params += f", '{self.weight_shape}'" |
|
if self.bias_shape is not None: |
|
params += f", '{self.bias_shape}'" |
|
for axis, length in self.axes_lengths.items(): |
|
params += ", {}={}".format(axis, length) |
|
return "{}({})".format(self.__class__.__name__, params) |
|
|
|
|
|
class _EinmixDebugger(_EinmixMixin): |
|
"""Used only to test mixin""" |
|
|
|
def _create_rearrange_layers( |
|
self, |
|
pre_reshape_pattern: Optional[str], |
|
pre_reshape_lengths: Optional[Dict], |
|
post_reshape_pattern: Optional[str], |
|
post_reshape_lengths: Optional[Dict], |
|
): |
|
self.pre_reshape_pattern = pre_reshape_pattern |
|
self.pre_reshape_lengths = pre_reshape_lengths |
|
self.post_reshape_pattern = post_reshape_pattern |
|
self.post_reshape_lengths = post_reshape_lengths |
|
|
|
def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound): |
|
self.saved_weight_shape = weight_shape |
|
self.saved_bias_shape = bias_shape |
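

# Illustrative sketch (not part of the library API): what the mixin derives for a plain
# linear layer; axis lengths here are arbitrary.
#
#   layer = _EinmixDebugger("t b cin -> t b cout", weight_shape="cin cout", cin=4, cout=8)
#   layer.einsum_pattern       # 'dab,bc->dac' (letters assigned alphabetically: b->a, cin->b, cout->c, t->d)
#   layer.saved_weight_shape   # [4, 8]
#   layer.pre_reshape_pattern  # None -- both sides are flat, so no rearrange is needed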
|
|