|
""" |
|
Backends in `einops` are organized to meet the following requirements:

- backends are not imported unless they are actually needed, because
    - a backend may not be installed
    - importing all available backends would lead to a significant memory footprint
    - a backend may be present but installed with errors (and never used);
      importing it could lead to crashes
- a backend is either symbolic or imperative
    - this determines which methods (from_numpy/to_numpy or create_symbol/eval_symbol) should be defined
- if a backend can't provide symbols for shape dimensions, UnknownSize objects are used
|
""" |
|
|
|
import sys |
|
|
|
__author__ = "Alex Rogozhnikov" |
|
|
|
# backends that have already been instantiated, keyed by framework name
_loaded_backends: dict = {}
# cache mapping a tensor type to its backend, so repeated lookups are cheap
_type2backend: dict = {}
_debug_importing = False
|
|
|
|
|
def get_backend(tensor) -> "AbstractBackend": |
|
""" |
|
    Returns the correct backend for a tensor (e.g. the numpy backend if tensor is a numpy.ndarray).
    If needed, imports the framework's package and creates the backend.
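
    A minimal usage sketch (assuming numpy is installed; the array below is only illustrative):

        import numpy as np
        x = np.zeros([10, 20])
        backend = get_backend(x)  # returns the NumpyBackend instance
        assert backend.framework_name == "numpy"
        assert backend.shape(x) == (10, 20)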
|
""" |
|
_type = type(tensor) |
|
_result = _type2backend.get(_type, None) |
|
if _result is not None: |
|
return _result |
|
|
|
for framework_name, backend in list(_loaded_backends.items()): |
|
if backend.is_appropriate_type(tensor): |
|
_type2backend[_type] = backend |
|
return backend |
|
|
|
|
|
    # discover backend subclasses recursively; backends whose frameworks are already
    # imported (present in sys.modules) are instantiated lazily below
    backend_subclasses = []
|
backends = AbstractBackend.__subclasses__() |
|
while backends: |
|
backend = backends.pop() |
|
backends += backend.__subclasses__() |
|
backend_subclasses.append(backend) |
|
|
|
for BackendSubclass in backend_subclasses: |
|
if _debug_importing: |
|
print("Testing for subclass of ", BackendSubclass) |
|
if BackendSubclass.framework_name not in _loaded_backends: |
|
|
|
if BackendSubclass.framework_name in sys.modules: |
|
if _debug_importing: |
|
print("Imported backend for ", BackendSubclass.framework_name) |
|
backend = BackendSubclass() |
|
_loaded_backends[backend.framework_name] = backend |
|
if backend.is_appropriate_type(tensor): |
|
_type2backend[_type] = backend |
|
return backend |
|
|
|
raise RuntimeError("Tensor type unknown to einops {}".format(type(tensor))) |
|
|
|
|
|
class AbstractBackend: |
|
"""Base backend class, major part of methods are only for debugging purposes.""" |
|
|
|
framework_name: str |
|
|
|
def is_appropriate_type(self, tensor): |
|
"""helper method should recognize tensors it can handle""" |
|
raise NotImplementedError() |
|
|
|
def from_numpy(self, x): |
|
raise NotImplementedError("framework doesn't support imperative execution") |
|
|
|
def to_numpy(self, x): |
|
raise NotImplementedError("framework doesn't support imperative execution") |
|
|
|
def create_symbol(self, shape): |
|
raise NotImplementedError("framework doesn't support symbolic computations") |
|
|
|
def eval_symbol(self, symbol, symbol_value_pairs): |
|
|
|
raise NotImplementedError("framework doesn't support symbolic computations") |
|
|
|
def arange(self, start, stop): |
|
|
|
raise NotImplementedError("framework doesn't implement arange") |
|
|
|
def shape(self, x): |
|
"""shape should return a tuple with integers or "shape symbols" (which will evaluate to actual size)""" |
|
return x.shape |
|
|
|
def reshape(self, x, shape): |
|
return x.reshape(shape) |
|
|
|
def transpose(self, x, axes): |
|
return x.transpose(axes) |
|
|
|
def reduce(self, x, operation, axes): |
|
return getattr(x, operation)(axis=axes) |
|
|
|
def stack_on_zeroth_dimension(self, tensors: list): |
|
raise NotImplementedError() |
|
|
|
def add_axis(self, x, new_position): |
|
raise NotImplementedError() |
|
|
|
def add_axes(self, x, n_axes, pos2len): |
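        # Default implementation: insert each new axis, then tile it to the requested length.
        # Illustration with the numpy backend (hypothetical values): for x of shape (5,),
        # n_axes=2 and pos2len={0: 3}, add_axis gives shape (1, 5) and tiling with
        # repeats (3, 1) gives shape (3, 5).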
|
repeats = [1] * n_axes |
|
for axis_position, axis_length in pos2len.items(): |
|
x = self.add_axis(x, axis_position) |
|
repeats[axis_position] = axis_length |
|
return self.tile(x, tuple(repeats)) |
|
|
|
def tile(self, x, repeats): |
|
"""repeats - same lengths as x.shape""" |
|
raise NotImplementedError() |
|
|
|
def concat(self, tensors, axis: int): |
|
"""concatenates tensors along axis. |
|
Assume identical across tensors: devices, dtypes and shapes except selected axis.""" |
|
raise NotImplementedError() |
|
|
|
def is_float_type(self, x): |
|
|
|
|
|
raise NotImplementedError() |
|
|
|
def layers(self): |
|
raise NotImplementedError("backend does not provide layers") |
|
|
|
def __repr__(self): |
|
return "<einops backend for {}>".format(self.framework_name) |
|
|
|
def einsum(self, pattern, *x): |
|
raise NotImplementedError("backend does not support einsum") |
|
|
|
|
|
class UnknownSize: |
|
"""pseudo-symbol for symbolic frameworks which do not provide symbols for shape elements""" |
|
|
|
def __floordiv__(self, other): |
|
return self |
|
|
|
def __eq__(self, other): |
|
return True |
|
|
|
def __mul__(self, other): |
|
return self |
|
|
|
def __rmul__(self, other): |
|
return self |
|
|
|
def __hash__(self): |
|
return hash(None) |
|
|
|
|
|
class NumpyBackend(AbstractBackend): |
|
framework_name = "numpy" |
|
|
|
def __init__(self): |
|
import numpy |
|
|
|
self.np = numpy |
|
|
|
def is_appropriate_type(self, tensor): |
|
return isinstance(tensor, self.np.ndarray) |
|
|
|
def from_numpy(self, x): |
|
return x |
|
|
|
def to_numpy(self, x): |
|
return x |
|
|
|
def arange(self, start, stop): |
|
return self.np.arange(start, stop) |
|
|
|
def stack_on_zeroth_dimension(self, tensors: list): |
|
return self.np.stack(tensors) |
|
|
|
def tile(self, x, repeats): |
|
return self.np.tile(x, repeats) |
|
|
|
def concat(self, tensors, axis: int): |
|
return self.np.concatenate(tensors, axis=axis) |
|
|
|
def is_float_type(self, x): |
|
return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16") |
|
|
|
def add_axis(self, x, new_position): |
|
return self.np.expand_dims(x, new_position) |
|
|
|
def einsum(self, pattern, *x): |
|
return self.np.einsum(pattern, *x) |
|
|
|
|
|
class JaxBackend(NumpyBackend): |
|
framework_name = "jax" |
|
|
|
def __init__(self): |
|
super(JaxBackend, self).__init__() |
|
self.onp = self.np |
|
|
|
import jax.numpy |
|
|
|
self.np = jax.numpy |
|
|
|
def from_numpy(self, x): |
|
return self.np.asarray(x) |
|
|
|
def to_numpy(self, x): |
|
return self.onp.asarray(x) |
|
|
|
|
|
class TorchBackend(AbstractBackend): |
|
framework_name = "torch" |
|
|
|
def __init__(self): |
|
import torch |
|
|
|
self.torch = torch |
|
|
|
from . import _torch_specific |
|
|
|
def is_appropriate_type(self, tensor): |
|
return isinstance(tensor, self.torch.Tensor) |
|
|
|
def from_numpy(self, x): |
|
variable = self.torch.from_numpy(x) |
|
if self.is_float_type(variable): |
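            # gradients can only be attached to floating-point tensors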
|
|
|
variable.requires_grad = True |
|
return variable |
|
|
|
def to_numpy(self, x): |
|
return x.detach().cpu().numpy() |
|
|
|
def arange(self, start, stop): |
|
return self.torch.arange(start, stop, dtype=self.torch.int64) |
|
|
|
def reduce(self, x, operation, reduced_axes): |
|
if operation == "min": |
|
return x.amin(dim=reduced_axes) |
|
elif operation == "max": |
|
return x.amax(dim=reduced_axes) |
|
elif operation == "sum": |
|
return x.sum(dim=reduced_axes) |
|
elif operation == "mean": |
|
return x.mean(dim=reduced_axes) |
|
elif operation in ("any", "all", "prod"): |
|
|
|
for i in list(sorted(reduced_axes))[::-1]: |
|
x = getattr(x, operation)(dim=i) |
|
return x |
|
else: |
|
raise NotImplementedError("Unknown reduction ", operation) |
|
|
|
def transpose(self, x, axes): |
|
return x.permute(axes) |
|
|
|
def stack_on_zeroth_dimension(self, tensors: list): |
|
return self.torch.stack(tensors) |
|
|
|
def add_axes(self, x, n_axes, pos2len): |
|
repeats = [-1] * n_axes |
|
for axis_position, axis_length in pos2len.items(): |
|
x = self.add_axis(x, axis_position) |
|
repeats[axis_position] = axis_length |
|
return x.expand(repeats) |
|
|
|
def tile(self, x, repeats): |
|
return x.repeat(repeats) |
|
|
|
def concat(self, tensors, axis: int): |
|
return self.torch.cat(tensors, dim=axis) |
|
|
|
def add_axis(self, x, new_position): |
|
return self.torch.unsqueeze(x, new_position) |
|
|
|
def is_float_type(self, x): |
|
return x.dtype in [self.torch.float16, self.torch.float32, self.torch.float64, self.torch.bfloat16] |
|
|
|
def layers(self): |
|
from .layers import torch |
|
|
|
return torch |
|
|
|
def einsum(self, pattern, *x): |
|
return self.torch.einsum(pattern, *x) |
|
|
|
|
|
class CupyBackend(AbstractBackend): |
|
framework_name = "cupy" |
|
|
|
def __init__(self): |
|
import cupy |
|
|
|
self.cupy = cupy |
|
|
|
def is_appropriate_type(self, tensor): |
|
return isinstance(tensor, self.cupy.ndarray) |
|
|
|
def from_numpy(self, x): |
|
return self.cupy.asarray(x) |
|
|
|
def to_numpy(self, x): |
|
return self.cupy.asnumpy(x) |
|
|
|
def arange(self, start, stop): |
|
return self.cupy.arange(start, stop) |
|
|
|
def stack_on_zeroth_dimension(self, tensors: list): |
|
return self.cupy.stack(tensors) |
|
|
|
def tile(self, x, repeats): |
|
return self.cupy.tile(x, repeats) |
|
|
|
def concat(self, tensors, axis: int): |
|
return self.cupy.concatenate(tensors, axis=axis) |
|
|
|
def add_axis(self, x, new_position): |
|
return self.cupy.expand_dims(x, new_position) |
|
|
|
def is_float_type(self, x): |
|
return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16") |
|
|
|
def einsum(self, pattern, *x): |
|
return self.cupy.einsum(pattern, *x) |
|
|
|
|
|
class HashableTuple: |
|
"""Overcomes non-hashability of symbolic elements""" |
|
|
|
def __init__(self, elements: tuple): |
|
self.elements = elements |
|
|
|
def __iter__(self): |
|
for x in self.elements: |
|
yield x |
|
|
|
def __len__(self): |
|
return len(self.elements) |
|
|
|
def __getitem__(self, item): |
|
return self.elements[item] |
|
|
|
|
|
|
|
|
|
class TensorflowBackend(AbstractBackend): |
|
framework_name = "tensorflow" |
|
|
|
def __init__(self): |
|
import tensorflow |
|
|
|
self.tf = tensorflow |
|
|
|
def is_appropriate_type(self, tensor): |
|
return isinstance(tensor, (self.tf.Tensor, self.tf.Variable)) |
|
|
|
def from_numpy(self, x): |
|
assert self.tf.executing_eagerly() |
|
return self.tf.convert_to_tensor(x) |
|
|
|
def to_numpy(self, x): |
|
assert self.tf.executing_eagerly() |
|
return x.numpy() |
|
|
|
def arange(self, start, stop): |
|
return self.tf.range(start, stop) |
|
|
|
def shape(self, x): |
|
if self.tf.executing_eagerly(): |
|
return tuple(UnknownSize() if d is None else int(d) for d in x.shape) |
|
else: |
|
static_shape = x.shape.as_list() |
|
tf_shape = self.tf.shape(x) |
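            # prefer statically known sizes; fall back to the dynamic shape tensor for
            # dimensions whose static size is None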
|
|
|
shape = tuple([s or tf_shape[dim] for dim, s in enumerate(static_shape)]) |
|
try: |
|
hash(shape) |
|
return shape |
|
except BaseException: |
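                # the shape contains symbolic (unhashable) entries; wrap it so it can still
                # be used where a hashable shape is expected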
|
|
|
return HashableTuple(shape) |
|
|
|
def reduce(self, x, operation, axes): |
|
return getattr(self.tf, "reduce_" + operation)(x, axis=axes) |
|
|
|
def reshape(self, x, shape): |
|
return self.tf.reshape(x, shape) |
|
|
|
def transpose(self, x, axes): |
|
return self.tf.transpose(x, axes) |
|
|
|
def stack_on_zeroth_dimension(self, tensors: list): |
|
return self.tf.stack(tensors) |
|
|
|
def tile(self, x, repeats): |
|
return self.tf.tile(x, repeats) |
|
|
|
def concat(self, tensors, axis: int): |
|
return self.tf.concat(tensors, axis=axis) |
|
|
|
def add_axis(self, x, new_position): |
|
return self.tf.expand_dims(x, new_position) |
|
|
|
def is_float_type(self, x): |
|
return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16") |
|
|
|
def layers(self): |
|
from .layers import tensorflow |
|
|
|
return tensorflow |
|
|
|
def einsum(self, pattern, *x): |
|
return self.tf.einsum(pattern, *x) |
|
|
|
|
|
class TFKerasBackend(AbstractBackend): |
|
framework_name = "tensorflow.keras" |
|
|
|
def __init__(self): |
|
import tensorflow as tf |
|
|
|
self.tf = tf |
|
self.keras = tf.keras |
|
self.K = tf.keras.backend |
|
|
|
def is_appropriate_type(self, tensor): |
|
return self.tf.is_tensor(tensor) and self.K.is_keras_tensor(tensor) |
|
|
|
def create_symbol(self, shape): |
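        # a symbol is a Keras Input placeholder with the given batch-including shape;
        # eval_symbol below wraps a symbol into a Model and feeds values via predict_on_batch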
|
return self.keras.Input(batch_shape=shape) |
|
|
|
def eval_symbol(self, symbol, symbol_value_pairs): |
|
model = self.keras.models.Model([var for (var, _) in symbol_value_pairs], symbol) |
|
return model.predict_on_batch([val for (_, val) in symbol_value_pairs]) |
|
|
|
def arange(self, start, stop): |
|
return self.K.arange(start, stop) |
|
|
|
def shape(self, x): |
|
shape = self.K.shape(x) |
|
return HashableTuple(tuple(shape)) |
|
|
|
def reduce(self, x, operation, axes): |
|
return getattr(self.K, operation)(x, axis=axes) |
|
|
|
def reshape(self, x, shape): |
|
return self.K.reshape(x, shape) |
|
|
|
def transpose(self, x, axes): |
|
return self.K.permute_dimensions(x, axes) |
|
|
|
def stack_on_zeroth_dimension(self, tensors: list): |
|
return self.K.stack(tensors) |
|
|
|
def tile(self, x, repeats): |
|
return self.K.tile(x, repeats) |
|
|
|
def concat(self, tensors, axis: int): |
|
return self.K.concatenate(tensors, axis=axis) |
|
|
|
def add_axis(self, x, new_position): |
|
return self.K.expand_dims(x, new_position) |
|
|
|
def is_float_type(self, x): |
|
return "float" in self.K.dtype(x) |
|
|
|
def layers(self): |
|
from .layers import keras |
|
|
|
return keras |
|
|
|
|
|
class OneFlowBackend(AbstractBackend): |
|
framework_name = "oneflow" |
|
|
|
def __init__(self): |
|
import oneflow as flow |
|
|
|
self.flow = flow |
|
|
|
def is_appropriate_type(self, tensor): |
|
return isinstance(tensor, self.flow.Tensor) |
|
|
|
def from_numpy(self, x): |
|
variable = self.flow.from_numpy(x) |
|
if self.is_float_type(variable): |
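            # as in the torch backend, gradients are attached only to floating-point tensors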
|
|
|
variable.requires_grad = True |
|
return variable |
|
|
|
def to_numpy(self, x): |
|
return x.detach().cpu().numpy() |
|
|
|
def arange(self, start, stop): |
|
return self.flow.arange(start, stop, dtype=self.flow.int64) |
|
|
|
def reduce(self, x, operation, reduced_axes): |
|
for axis in sorted(reduced_axes, reverse=True): |
|
if operation == "min": |
|
x, _ = x.min(dim=axis) |
|
elif operation == "max": |
|
x, _ = x.max(dim=axis) |
|
elif operation in ["sum", "mean", "prod", "any", "all"]: |
|
x = getattr(x, operation)(dim=axis) |
|
else: |
|
raise NotImplementedError("Unknown reduction ", operation) |
|
return x |
|
|
|
def transpose(self, x, axes): |
|
return x.permute(axes) |
|
|
|
def stack_on_zeroth_dimension(self, tensors: list): |
|
return self.flow.stack(tensors) |
|
|
|
def add_axes(self, x, n_axes, pos2len): |
|
repeats = [-1] * n_axes |
|
for axis_position, axis_length in pos2len.items(): |
|
x = self.add_axis(x, axis_position) |
|
repeats[axis_position] = axis_length |
|
return x.expand(*repeats) |
|
|
|
def tile(self, x, repeats): |
|
return x.repeat(repeats) |
|
|
|
def concat(self, tensors, axis: int): |
|
return self.flow.concat(tensors, dim=axis) |
|
|
|
def add_axis(self, x, new_position): |
|
return self.flow.unsqueeze(x, new_position) |
|
|
|
def is_float_type(self, x): |
|
return x.dtype in [self.flow.float16, self.flow.float32, self.flow.float64] |
|
|
|
def layers(self): |
|
from .layers import oneflow |
|
|
|
return oneflow |
|
|
|
def einsum(self, pattern, *x): |
|
return self.flow.einsum(pattern, *x) |
|
|
|
|
|
class PaddleBackend(AbstractBackend): |
|
framework_name = "paddle" |
|
|
|
def __init__(self): |
|
import paddle |
|
|
|
self.paddle = paddle |
|
|
|
def is_appropriate_type(self, tensor): |
|
return self.paddle.is_tensor(tensor) |
|
|
|
def from_numpy(self, x): |
|
tensor = self.paddle.to_tensor(x) |
|
tensor.stop_gradient = False |
|
return tensor |
|
|
|
def to_numpy(self, x): |
|
return x.detach().numpy() |
|
|
|
def arange(self, start, stop): |
|
return self.paddle.arange(start, stop, dtype=self.paddle.int64) |
|
|
|
def reduce(self, x, operation, axes): |
|
if len(axes) == x.ndim: |
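            # when reducing over all axes, paddle keeps a leftover 1-d axis instead of
            # returning a scalar, so squeeze it out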
|
|
|
return super().reduce(x, operation, axes).squeeze(0) |
|
else: |
|
return super().reduce(x, operation, axes) |
|
|
|
def transpose(self, x, axes): |
|
return x.transpose(axes) |
|
|
|
def add_axes(self, x, n_axes, pos2len): |
|
repeats = [-1] * n_axes |
|
for axis_position, axis_length in pos2len.items(): |
|
x = self.add_axis(x, axis_position) |
|
repeats[axis_position] = axis_length |
|
return x.expand(repeats) |
|
|
|
def stack_on_zeroth_dimension(self, tensors: list): |
|
return self.paddle.stack(tensors) |
|
|
|
def reshape(self, x, shape): |
|
return x.reshape(shape) |
|
|
|
def tile(self, x, repeats): |
|
return x.tile(repeats) |
|
|
|
def concat(self, tensors, axis: int): |
|
return self.paddle.concat(tensors, axis=axis) |
|
|
|
def add_axis(self, x, new_position): |
|
return x.unsqueeze(new_position) |
|
|
|
def is_float_type(self, x): |
|
return x.dtype in [self.paddle.float16, self.paddle.float32, self.paddle.float64] |
|
|
|
def layers(self): |
|
from .layers import paddle |
|
|
|
return paddle |
|
|
|
def einsum(self, pattern, *x): |
|
return self.paddle.einsum(pattern, *x) |
|
|
|
def shape(self, x): |
|
return tuple(x.shape) |
|
|
|
|
|
class TinygradBackend(AbstractBackend): |
|
framework_name = "tinygrad" |
|
|
|
def __init__(self): |
|
import tinygrad |
|
|
|
self.tinygrad = tinygrad |
|
|
|
def is_appropriate_type(self, tensor): |
|
return isinstance(tensor, self.tinygrad.Tensor) |
|
|
|
def from_numpy(self, x): |
|
return self.tinygrad.Tensor(x) |
|
|
|
def to_numpy(self, x): |
|
return x.numpy() |
|
|
|
def arange(self, start, stop): |
|
return self.tinygrad.Tensor.arange(start, stop) |
|
|
|
def shape(self, x): |
|
return x.shape |
|
|
|
def reshape(self, x, shape): |
|
return x.reshape(shape) |
|
|
|
def transpose(self, x, axes): |
|
return x.permute(axes) |
|
|
|
def reduce(self, x, operation, axes): |
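        # apply the reduction one axis at a time, from the last axis to the first, so the
        # remaining axis indices stay valid after each step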
|
for axis in sorted(axes, reverse=True): |
|
x = getattr(x, operation)(axis=axis) |
|
return x |
|
|
|
def stack_on_zeroth_dimension(self, tensors: list): |
|
return self.tinygrad.Tensor.stack(tensors) |
|
|
|
def add_axis(self, x, new_position): |
|
return x.unsqueeze(new_position) |
|
|
|
def tile(self, x, repeats): |
|
return x.repeat(repeats) |
|
|
|
def concat(self, tensors, axis: int): |
|
return tensors[0].cat(*tensors[1:], dim=axis) if len(tensors) > 1 else tensors[0] |
|
|
|
def is_float_type(self, x): |
|
return self.tinygrad.dtypes.is_float(x.dtype) |
|
|
|
def einsum(self, pattern, *x): |
|
return self.tinygrad.Tensor.einsum(pattern, *x) |
|
|
|
|
|
class PyTensorBackend(AbstractBackend): |
|
framework_name = "pytensor" |
|
|
|
def __init__(self): |
|
from pytensor import tensor |
|
|
|
self.pt = tensor |
|
|
|
def is_appropriate_type(self, tensor): |
|
return isinstance(tensor, self.pt.TensorVariable) |
|
|
|
def is_float_type(self, x): |
|
return x.dtype in self.pt.type.float_dtypes |
|
|
|
def from_numpy(self, x): |
|
return self.pt.as_tensor(x) |
|
|
|
def to_numpy(self, x): |
|
return x.eval() |
|
|
|
def create_symbol(self, shape): |
|
if not isinstance(shape, tuple | list): |
|
shape = (shape,) |
|
return self.pt.tensor(shape=shape) |
|
|
|
def eval_symbol(self, symbol, symbol_value_pairs): |
|
return symbol.eval(dict(symbol_value_pairs)) |
|
|
|
def arange(self, start, stop): |
|
return self.pt.arange(start, stop) |
|
|
|
def shape(self, x): |
|
|
|
return tuple( |
|
static_dim if static_dim is not None else symbolic_dim |
|
for static_dim, symbolic_dim in zip(x.type.shape, x.shape) |
|
) |
|
|
|
def stack_on_zeroth_dimension(self, tensors: list): |
|
return self.pt.stack(tensors) |
|
|
|
def tile(self, x, repeats): |
|
return self.pt.tile(x, repeats) |
|
|
|
def concat(self, tensors, axis: int): |
|
return self.pt.concatenate(tensors, axis=axis) |
|
|
|
def add_axis(self, x, new_position): |
|
return self.pt.expand_dims(x, new_position) |
|
|
|
def einsum(self, pattern, *x): |
|
return self.pt.einsum(pattern, *x) |
|
|