"""
Specialization of einops for torch.

Unfortunately, torch's jit scripting mechanism isn't strong enough,
and to have scripting supported at least for layers,
a number of additional moves are needed.

Design of main operations (dynamic resolution by lookup) is unlikely
to be implemented by torch.jit.script,
but torch.compile seems to work with operations just fine.
"""

import warnings
from typing import Dict, List, Tuple

import torch
from einops.einops import TransformRecipe, _reconstruct_from_shape_uncached


class TorchJitBackend:
    """
    Completely static backend that mimics part of the normal backend functionality,
    but is restricted to what can be expressed within torchscript.
    """

    @staticmethod
    def reduce(x: torch.Tensor, operation: str, reduced_axes: List[int]):
        if operation == "min":
            return x.amin(dim=reduced_axes)
        elif operation == "max":
            return x.amax(dim=reduced_axes)
        elif operation == "sum":
            return x.sum(dim=reduced_axes)
        elif operation == "mean":
            return x.mean(dim=reduced_axes)
        elif operation == "prod":
            # torch.prod reduces only one dim at a time; go from the last
            # axis to the first so the remaining indices stay valid
            for i in list(sorted(reduced_axes))[::-1]:
                x = x.prod(dim=i)
            return x
        else:
            raise NotImplementedError("Unknown reduction " + operation)
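
    # For example (illustrative): reduce(x, "sum", reduced_axes=[1, 2]) on a
    # tensor of shape (2, 3, 4) sums over the last two axes, returning shape (2,).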

    @staticmethod
    def transpose(x, axes: List[int]):
        return x.permute(axes)

    @staticmethod
    def stack_on_zeroth_dimension(tensors: List[torch.Tensor]):
        return torch.stack(tensors)

    @staticmethod
    def tile(x, repeats: List[int]):
        return x.repeat(repeats)
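
    # For example (illustrative): tile(x, [2, 3]) on a tensor of shape (4, 5)
    # repeats the data along both axes, producing shape (8, 15).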

    @staticmethod
    def add_axes(x, n_axes: int, pos2len: Dict[int, int]):
        # insert singleton axes at the requested positions, then broadcast them
        # to the target lengths; -1 in expand leaves existing axes unchanged
        repeats = [-1] * n_axes
        for axis_position, axis_length in pos2len.items():
            x = torch.unsqueeze(x, axis_position)
            repeats[axis_position] = axis_length
        return x.expand(repeats)
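
    # For example (illustrative): add_axes(x, n_axes=3, pos2len={0: 5}) on a
    # tensor of shape (2, 3) returns a broadcast view of shape (5, 2, 3).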

    @staticmethod
    def is_float_type(x):
        return x.dtype in [torch.float16, torch.float32, torch.float64, torch.bfloat16]

    @staticmethod
    def shape(x):
        return x.shape

    @staticmethod
    def reshape(x, shape: List[int]):
        return x.reshape(shape)
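
# A minimal scripting sketch (illustrative): the methods above stick to
# constructs torch.jit.script can compile, so they can be scripted directly:
#
#   scripted_reduce = torch.jit.script(TorchJitBackend.reduce)
#   scripted_reduce(torch.ones(2, 3), "sum", [1])  # tensor([3., 3.])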


def apply_for_scriptable_torch(
    recipe: TransformRecipe, tensor: torch.Tensor, reduction_type: str, axes_dims: List[Tuple[str, int]]
) -> torch.Tensor:
    backend = TorchJitBackend
    (
        init_shapes,
        axes_reordering,
        reduced_axes,
        added_axes,
        final_shapes,
        n_axes_w_added,
    ) = _reconstruct_from_shape_uncached(recipe, backend.shape(tensor), axes_dims=axes_dims)
    if init_shapes is not None:
        tensor = backend.reshape(tensor, init_shapes)
    if axes_reordering is not None:
        tensor = backend.transpose(tensor, axes_reordering)
    if len(reduced_axes) > 0:
        tensor = backend.reduce(tensor, operation=reduction_type, reduced_axes=reduced_axes)
    if len(added_axes) > 0:
        tensor = backend.add_axes(tensor, n_axes=n_axes_w_added, pos2len=added_axes)
    if final_shapes is not None:
        tensor = backend.reshape(tensor, final_shapes)
    return tensor
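
# A walkthrough of the pipeline above (illustrative) for reducing a tensor of
# shape (4, 10) with the pattern "b (h h2) -> b h", h2=2, reduction "max":
#   reshape   (4, 10)   -> (4, 5, 2)   # init_shapes splits the composed axis
#   reduce    (4, 5, 2) -> (4, 5)      # max over the h2 axis
# no axes need reordering or adding for this pattern.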


def allow_ops_in_compiled_graph():
    if hasattr(torch, "__version__") and torch.__version__[0] < "2":
        # torch._dynamo and torch.compile appeared in pytorch 2.0
        return
    try:
        from torch._dynamo import allow_in_graph
    except ImportError:
        warnings.warn(
            "allow_ops_in_compiled_graph failed to import torch._dynamo.allow_in_graph: ensure pytorch >= 2.0",
            ImportWarning,
        )
        return

    from .einops import rearrange, reduce, repeat, einsum
    from .packing import pack, unpack

    # register einops functions as opaque nodes: torch.compile includes them
    # in the graph without tracing into their dynamic dispatch
    allow_in_graph(rearrange)
    allow_in_graph(reduce)
    allow_in_graph(repeat)
    allow_in_graph(einsum)
    allow_in_graph(pack)
    allow_in_graph(unpack)

    # module-level flag recording that registration has happened
    global _ops_were_registered_in_torchdynamo
    _ops_were_registered_in_torchdynamo = True


# registration runs on import, so torch.compile sees einops operations
allow_ops_in_compiled_graph()
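
# A minimal torch.compile sketch (illustrative, requires pytorch >= 2.0):
# thanks to the registration above, einops calls stay single graph nodes.
#
#   import torch
#   from einops import rearrange
#
#   @torch.compile
#   def to_patches(x):
#       return rearrange(x, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=8, p2=8)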