File size: 21,281 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 |
"""
Backends in `einops` are organized to meet the following requirements:
- backends are not imported unless they are actually needed, because
    - backends may not be installed
    - importing all available backends would lead to a significant memory footprint
    - backends may be present but installed with errors (and never used);
      importing them may lead to crashes
- a backend should be either symbolic or imperative
    - this determines which methods (from_numpy/to_numpy or create_symbol/eval_symbol) should be defined
- if a backend can't provide symbols for shape dimensions, UnknownSize objects are used
"""
import sys
__author__ = "Alex Rogozhnikov"

# framework_name -> backend instance; populated lazily by get_backend
_loaded_backends: dict = dict()
# exact tensor type -> backend that accepted it; lookup cache used by get_backend
_type2backend: dict = dict()
# when True, get_backend prints which backend classes it probes and instantiates
_debug_importing = False
def get_backend(tensor) -> "AbstractBackend":
    """
    Return the backend that handles ``tensor`` (e.g. numpy backend for numpy.ndarray).

    Lookup order:
    1. a cache keyed by the exact type of ``tensor``;
    2. backends that were already instantiated;
    3. every (transitive) subclass of AbstractBackend whose ``framework_name`` is
       already present in ``sys.modules``. Such a backend is instantiated, registered
       in ``_loaded_backends`` and then probed. Frameworks that were never imported
       by the user are not imported here.

    Subclasses that do not define ``framework_name`` (e.g. intermediate abstract
    bases) are skipped instead of raising AttributeError.

    :raises RuntimeError: if no backend recognizes the tensor type
    """
    tensor_type = type(tensor)
    cached = _type2backend.get(tensor_type, None)
    if cached is not None:
        return cached

    for backend in list(_loaded_backends.values()):
        if backend.is_appropriate_type(tensor):
            _type2backend[tensor_type] = backend
            return backend

    # Find backend subclasses recursively
    backend_subclasses = []
    pending = AbstractBackend.__subclasses__()
    while pending:
        candidate = pending.pop()
        pending += candidate.__subclasses__()
        backend_subclasses.append(candidate)

    for BackendSubclass in backend_subclasses:
        if _debug_importing:
            print("Testing for subclass of ", BackendSubclass)
        # AbstractBackend only annotates framework_name, so subclasses may lack it
        framework_name = getattr(BackendSubclass, "framework_name", None)
        if framework_name is None or framework_name in _loaded_backends:
            continue
        # check that module was already imported. Otherwise it can't be imported
        if framework_name not in sys.modules:
            continue
        if _debug_importing:
            print("Imported backend for ", framework_name)
        backend = BackendSubclass()
        _loaded_backends[backend.framework_name] = backend
        if backend.is_appropriate_type(tensor):
            _type2backend[tensor_type] = backend
            return backend

    raise RuntimeError("Tensor type unknown to einops {}".format(type(tensor)))
class AbstractBackend:
    """Interface implemented by every einops backend.

    Several methods are only needed for debugging/testing. The defaults either
    delegate to the tensor's own methods (shape, reshape, transpose, reduce) or
    raise NotImplementedError, to be overridden by concrete backends.
    """

    framework_name: str

    def is_appropriate_type(self, tensor):
        """helper method should recognize tensors it can handle"""
        raise NotImplementedError()

    # --- imperative frameworks: conversion from/to numpy ---
    def from_numpy(self, x):
        raise NotImplementedError("framework doesn't support imperative execution")

    def to_numpy(self, x):
        raise NotImplementedError("framework doesn't support imperative execution")

    # --- symbolic frameworks: placeholders and their evaluation ---
    def create_symbol(self, shape):
        raise NotImplementedError("framework doesn't support symbolic computations")

    def eval_symbol(self, symbol, symbol_value_pairs):
        # symbol_value_pairs: list of (symbol, value-tensor) tuples
        raise NotImplementedError("framework doesn't support symbolic computations")

    def arange(self, start, stop):
        # only used by tests, so a CPU implementation is sufficient
        raise NotImplementedError("framework doesn't implement arange")

    def shape(self, x):
        """Return a tuple of ints or "shape symbols" (which evaluate to actual sizes)."""
        dims = x.shape
        return dims

    def reshape(self, x, shape):
        reshaped = x.reshape(shape)
        return reshaped

    def transpose(self, x, axes):
        permuted = x.transpose(axes)
        return permuted

    def reduce(self, x, operation, axes):
        """Call the tensor method named ``operation`` with ``axis=axes``."""
        reducer = getattr(x, operation)
        return reducer(axis=axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        raise NotImplementedError()

    def add_axis(self, x, new_position):
        raise NotImplementedError()

    def add_axes(self, x, n_axes, pos2len):
        """Insert an axis at each key of ``pos2len``, then tile it to the mapped length."""
        repeats = [1 for _ in range(n_axes)]
        for position, length in pos2len.items():
            x = self.add_axis(x, position)
            repeats[position] = length
        return self.tile(x, tuple(repeats))

    def tile(self, x, repeats):
        """repeats - same lengths as x.shape"""
        raise NotImplementedError()

    def concat(self, tensors, axis: int):
        """Concatenate ``tensors`` along ``axis``.

        Devices, dtypes and all dimensions other than ``axis`` are assumed to match.
        """
        raise NotImplementedError()

    def is_float_type(self, x):
        # some backends (torch) can't average non-floating types, so einops drops
        # mean-reduction for non-float tensors on every backend
        raise NotImplementedError()

    def layers(self):
        raise NotImplementedError("backend does not provide layers")

    def __repr__(self):
        return f"<einops backend for {self.framework_name}>"

    def einsum(self, pattern, *x):
        raise NotImplementedError("backend does not support einsum")
class UnknownSize:
    """Stand-in for a dimension size that a symbolic framework does not expose.

    ``*`` (from either side) and ``//`` absorb the other operand and return this
    same object. It compares equal to anything, and all instances share one hash.
    """

    def __mul__(self, other):
        return self

    # right-multiplication and floor division absorb the operand the same way
    __rmul__ = __mul__
    __floordiv__ = __mul__

    def __eq__(self, other):
        # the real size is unknown, so never report a mismatch
        return True

    def __hash__(self):
        return hash(None)
class NumpyBackend(AbstractBackend):
    """Imperative backend for ``numpy.ndarray``.

    All array operations go through ``self.np``, so a subclass can swap in a
    numpy-compatible module (JaxBackend does this).
    """

    framework_name = "numpy"

    def __init__(self):
        import numpy

        self.np = numpy

    def is_appropriate_type(self, tensor):
        array_type = self.np.ndarray
        return isinstance(tensor, array_type)

    # numpy arrays are already numpy arrays: both conversions are the identity
    def from_numpy(self, x):
        return x

    def to_numpy(self, x):
        return x

    def is_float_type(self, x):
        float_names = ("float16", "float32", "float64", "float128", "bfloat16")
        return x.dtype in float_names

    def arange(self, start, stop):
        return self.np.arange(start, stop)

    def add_axis(self, x, new_position):
        return self.np.expand_dims(x, new_position)

    def tile(self, x, repeats):
        return self.np.tile(x, repeats)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.np.stack(tensors)

    def concat(self, tensors, axis: int):
        return self.np.concatenate(tensors, axis=axis)

    def einsum(self, pattern, *x):
        return self.np.einsum(pattern, *x)
class JaxBackend(NumpyBackend):
    """NumpyBackend variant that routes all array operations through ``jax.numpy``."""

    framework_name = "jax"

    def __init__(self):
        super().__init__()
        # keep plain numpy (set up by the parent) for to_numpy, then swap self.np
        # so every inherited method operates via jax.numpy
        self.onp = self.np
        import jax.numpy

        self.np = jax.numpy

    def to_numpy(self, x):
        plain_numpy = self.onp
        return plain_numpy.asarray(x)

    def from_numpy(self, x):
        jax_numpy = self.np
        return jax_numpy.asarray(x)
class TorchBackend(AbstractBackend):
    """Imperative backend for ``torch.Tensor``."""

    framework_name = "torch"

    def __init__(self):
        import torch

        self.torch = torch
        # importing would register operations in torch._dynamo for torch.compile
        from . import _torch_specific  # noqa

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.torch.Tensor)

    def from_numpy(self, x):
        variable = self.torch.from_numpy(x)
        if self.is_float_type(variable):
            # attach grad only to floating types
            variable.requires_grad = True
        return variable

    def to_numpy(self, x):
        return x.detach().cpu().numpy()

    def arange(self, start, stop):
        return self.torch.arange(start, stop, dtype=self.torch.int64)

    def reduce(self, x, operation, reduced_axes):
        """Reduce ``x`` over ``reduced_axes`` with the named ``operation``.

        Supported: min, max, sum, mean (all axes in one call) and any, all, prod
        (one axis per call).

        :raises NotImplementedError: for any other operation name
        """
        if operation == "min":
            return x.amin(dim=reduced_axes)
        elif operation == "max":
            return x.amax(dim=reduced_axes)
        elif operation == "sum":
            return x.sum(dim=reduced_axes)
        elif operation == "mean":
            return x.mean(dim=reduced_axes)
        elif operation in ("any", "all", "prod"):
            # these are applied one axis at a time (pytorch reduces only a single
            # dim for them); going from the highest axis down keeps the indices of
            # the axes still to be reduced valid
            for i in sorted(reduced_axes, reverse=True):
                x = getattr(x, operation)(dim=i)
            return x
        else:
            raise NotImplementedError("Unknown reduction {}".format(operation))

    def transpose(self, x, axes):
        return x.permute(axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.torch.stack(tensors)

    def add_axes(self, x, n_axes, pos2len):
        # -1 in expand keeps an existing dimension's size; new axes are broadcast
        # (expanded) rather than copied as AbstractBackend.add_axes would via tile
        repeats = [-1] * n_axes
        for axis_position, axis_length in pos2len.items():
            x = self.add_axis(x, axis_position)
            repeats[axis_position] = axis_length
        return x.expand(repeats)

    def tile(self, x, repeats):
        return x.repeat(repeats)

    def concat(self, tensors, axis: int):
        return self.torch.cat(tensors, dim=axis)

    def add_axis(self, x, new_position):
        return self.torch.unsqueeze(x, new_position)

    def is_float_type(self, x):
        return x.dtype in [self.torch.float16, self.torch.float32, self.torch.float64, self.torch.bfloat16]

    def layers(self):
        from .layers import torch

        return torch

    def einsum(self, pattern, *x):
        return self.torch.einsum(pattern, *x)
class CupyBackend(AbstractBackend):
    """Imperative backend for ``cupy.ndarray`` (numpy-like API)."""

    framework_name = "cupy"

    def __init__(self):
        import cupy

        self.cupy = cupy

    def is_appropriate_type(self, tensor):
        array_type = self.cupy.ndarray
        return isinstance(tensor, array_type)

    # conversions between host numpy arrays and cupy arrays
    def to_numpy(self, x):
        return self.cupy.asnumpy(x)

    def from_numpy(self, x):
        return self.cupy.asarray(x)

    def is_float_type(self, x):
        float_names = ("float16", "float32", "float64", "float128", "bfloat16")
        return x.dtype in float_names

    def arange(self, start, stop):
        return self.cupy.arange(start, stop)

    def add_axis(self, x, new_position):
        return self.cupy.expand_dims(x, new_position)

    def tile(self, x, repeats):
        return self.cupy.tile(x, repeats)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.cupy.stack(tensors)

    def concat(self, tensors, axis: int):
        return self.cupy.concatenate(tensors, axis=axis)

    def einsum(self, pattern, *x):
        return self.cupy.einsum(pattern, *x)
class HashableTuple:
    """Tuple-like wrapper for shapes whose elements may be unhashable symbols.

    Equality and hashing are intentionally inherited from ``object``: an instance
    equals only itself and hashes by identity, so it is always usable as a key.
    """

    def __init__(self, elements: tuple):
        self.elements = elements

    def __len__(self):
        return len(self.elements)

    def __getitem__(self, item):
        return self.elements[item]

    def __iter__(self):
        yield from self.elements
class TensorflowBackend(AbstractBackend):
    """Backend for ``tf.Tensor`` / ``tf.Variable`` (eager mode and graph mode)."""

    framework_name = "tensorflow"

    def __init__(self):
        import tensorflow

        self.tf = tensorflow

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, (self.tf.Tensor, self.tf.Variable))

    def from_numpy(self, x):
        assert self.tf.executing_eagerly()
        return self.tf.convert_to_tensor(x)

    def to_numpy(self, x):
        assert self.tf.executing_eagerly()
        return x.numpy()

    def arange(self, start, stop):
        return self.tf.range(start, stop)

    def shape(self, x):
        """Return the shape of ``x``.

        Eager mode: a tuple of ints, with UnknownSize for dimensions reported as None.
        Graph mode: static sizes where known, otherwise components of ``tf.shape(x)``.
        The result is wrapped into HashableTuple if it cannot be hashed.
        """
        if self.tf.executing_eagerly():
            return tuple(UnknownSize() if d is None else int(d) for d in x.shape)

        static_shape = x.shape.as_list()
        tf_shape = self.tf.shape(x)
        # use the static shape where known (a legitimate size of 0 included),
        # otherwise use the TF shape components
        shape = tuple(s if s is not None else tf_shape[dim] for dim, s in enumerate(static_shape))
        try:
            hash(shape)
        except TypeError:
            # unhashable symbols in shape. Wrap tuple to be hashable.
            return HashableTuple(shape)
        return shape

    def reduce(self, x, operation, axes):
        # maps e.g. "sum" -> tf.reduce_sum
        return getattr(self.tf, "reduce_" + operation)(x, axis=axes)

    def reshape(self, x, shape):
        return self.tf.reshape(x, shape)

    def transpose(self, x, axes):
        return self.tf.transpose(x, axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.tf.stack(tensors)

    def tile(self, x, repeats):
        return self.tf.tile(x, repeats)

    def concat(self, tensors, axis: int):
        return self.tf.concat(tensors, axis=axis)

    def add_axis(self, x, new_position):
        return self.tf.expand_dims(x, new_position)

    def is_float_type(self, x):
        return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16")

    def layers(self):
        from .layers import tensorflow

        return tensorflow

    def einsum(self, pattern, *x):
        return self.tf.einsum(pattern, *x)
class TFKerasBackend(AbstractBackend):
    """Symbolic backend for Keras tensors, using ``tf.keras.backend`` operations."""

    framework_name = "tensorflow.keras"

    def __init__(self):
        import tensorflow as tf

        self.tf = tf
        self.keras = tf.keras
        self.K = tf.keras.backend

    def is_appropriate_type(self, tensor):
        if not self.tf.is_tensor(tensor):
            return False
        return self.K.is_keras_tensor(tensor)

    def create_symbol(self, shape):
        return self.keras.Input(batch_shape=shape)

    def eval_symbol(self, symbol, symbol_value_pairs):
        """Build a Keras model from the placeholder inputs to ``symbol`` and run it once."""
        inputs = [placeholder for (placeholder, _) in symbol_value_pairs]
        model = self.keras.models.Model(inputs, symbol)
        values = [value for (_, value) in symbol_value_pairs]
        return model.predict_on_batch(values)

    def arange(self, start, stop):
        return self.K.arange(start, stop)

    def shape(self, x):
        # K.shape returns a tf tensor; its components are wrapped so the result is hashable
        components = tuple(self.K.shape(x))
        return HashableTuple(components)

    def reduce(self, x, operation, axes):
        reducer = getattr(self.K, operation)
        return reducer(x, axis=axes)

    def reshape(self, x, shape):
        return self.K.reshape(x, shape)

    def transpose(self, x, axes):
        return self.K.permute_dimensions(x, axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.K.stack(tensors)

    def tile(self, x, repeats):
        return self.K.tile(x, repeats)

    def concat(self, tensors, axis: int):
        return self.K.concatenate(tensors, axis=axis)

    def add_axis(self, x, new_position):
        return self.K.expand_dims(x, new_position)

    def is_float_type(self, x):
        dtype_name = self.K.dtype(x)
        return "float" in dtype_name

    def layers(self):
        from .layers import keras

        return keras
class OneFlowBackend(AbstractBackend):
    """Imperative backend for ``oneflow.Tensor``."""

    framework_name = "oneflow"

    def __init__(self):
        import oneflow as flow

        self.flow = flow

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.flow.Tensor)

    def from_numpy(self, x):
        variable = self.flow.from_numpy(x)
        if self.is_float_type(variable):
            # attach grad only to floating types
            variable.requires_grad = True
        return variable

    def to_numpy(self, x):
        return x.detach().cpu().numpy()

    def arange(self, start, stop):
        return self.flow.arange(start, stop, dtype=self.flow.int64)

    def reduce(self, x, operation, reduced_axes):
        """Reduce ``x`` over ``reduced_axes`` one axis at a time.

        :raises NotImplementedError: for an unsupported operation name, checked
            before any reduction (also when ``reduced_axes`` is empty)
        """
        if operation not in ("min", "max", "sum", "mean", "prod", "any", "all"):
            raise NotImplementedError("Unknown reduction {}".format(operation))
        # reduce from the highest axis down so the remaining axis indices stay valid
        for axis in sorted(reduced_axes, reverse=True):
            if operation in ("min", "max"):
                # min/max with a dim return (values, indices); keep the values
                x, _ = getattr(x, operation)(dim=axis)
            else:
                x = getattr(x, operation)(dim=axis)
        return x

    def transpose(self, x, axes):
        return x.permute(axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.flow.stack(tensors)

    def add_axes(self, x, n_axes, pos2len):
        # -1 in expand keeps an existing dimension's size; new axes are broadcast
        repeats = [-1] * n_axes
        for axis_position, axis_length in pos2len.items():
            x = self.add_axis(x, axis_position)
            repeats[axis_position] = axis_length
        return x.expand(*repeats)

    def tile(self, x, repeats):
        return x.repeat(repeats)

    def concat(self, tensors, axis: int):
        return self.flow.concat(tensors, dim=axis)

    def add_axis(self, x, new_position):
        return self.flow.unsqueeze(x, new_position)

    def is_float_type(self, x):
        return x.dtype in [self.flow.float16, self.flow.float32, self.flow.float64]

    def layers(self):
        from .layers import oneflow

        return oneflow

    def einsum(self, pattern, *x):
        return self.flow.einsum(pattern, *x)
class PaddleBackend(AbstractBackend):
    """Imperative backend for paddle tensors."""

    framework_name = "paddle"

    def __init__(self):
        import paddle

        self.paddle = paddle

    def is_appropriate_type(self, tensor):
        return self.paddle.is_tensor(tensor)

    def from_numpy(self, x):
        converted = self.paddle.to_tensor(x)
        # enable gradient tracking for the created tensor
        converted.stop_gradient = False
        return converted

    def to_numpy(self, x):
        return x.detach().numpy()

    def arange(self, start, stop):
        return self.paddle.arange(start, stop, dtype=self.paddle.int64)

    def shape(self, x):
        return tuple(x.shape)

    def reduce(self, x, operation, axes):
        reduces_everything = len(axes) == x.ndim
        result = super().reduce(x, operation, axes)
        # currently paddle returns 1d tensor instead of 0d
        return result.squeeze(0) if reduces_everything else result

    def reshape(self, x, shape):
        return x.reshape(shape)

    def transpose(self, x, axes):
        return x.transpose(axes)

    def add_axis(self, x, new_position):
        return x.unsqueeze(new_position)

    def add_axes(self, x, n_axes, pos2len):
        # -1 in expand keeps an existing dimension's size; new axes are broadcast
        expand_to = [-1 for _ in range(n_axes)]
        for position, length in pos2len.items():
            x = self.add_axis(x, position)
            expand_to[position] = length
        return x.expand(expand_to)

    def tile(self, x, repeats):
        return x.tile(repeats)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.paddle.stack(tensors)

    def concat(self, tensors, axis: int):
        return self.paddle.concat(tensors, axis=axis)

    def is_float_type(self, x):
        float_dtypes = [self.paddle.float16, self.paddle.float32, self.paddle.float64]
        return x.dtype in float_dtypes

    def layers(self):
        from .layers import paddle

        return paddle

    def einsum(self, pattern, *x):
        return self.paddle.einsum(pattern, *x)
class TinygradBackend(AbstractBackend):
    """Imperative backend for ``tinygrad.Tensor``."""

    framework_name = "tinygrad"

    def __init__(self):
        import tinygrad

        self.tinygrad = tinygrad

    def is_appropriate_type(self, tensor):
        tensor_type = self.tinygrad.Tensor
        return isinstance(tensor, tensor_type)

    def from_numpy(self, x):
        return self.tinygrad.Tensor(x)

    def to_numpy(self, x):
        return x.numpy()

    def arange(self, start, stop):
        return self.tinygrad.Tensor.arange(start, stop)

    def shape(self, x):
        return x.shape

    def reshape(self, x, shape):
        return x.reshape(shape)

    def transpose(self, x, axes):
        return x.permute(axes)

    def reduce(self, x, operation, axes):
        # one axis per call, highest axis first so remaining indices stay valid
        for axis in sorted(axes, reverse=True):
            reduce_fn = getattr(x, operation)
            x = reduce_fn(axis=axis)
        return x

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.tinygrad.Tensor.stack(tensors)

    def add_axis(self, x, new_position):
        return x.unsqueeze(new_position)

    def tile(self, x, repeats):
        return x.repeat(repeats)

    def concat(self, tensors, axis: int):
        if len(tensors) <= 1:
            return tensors[0]
        head = tensors[0]
        return head.cat(*tensors[1:], dim=axis)

    def is_float_type(self, x):
        return self.tinygrad.dtypes.is_float(x.dtype)

    def einsum(self, pattern, *x):
        return self.tinygrad.Tensor.einsum(pattern, *x)
class PyTensorBackend(AbstractBackend):
    """Symbolic backend for ``pytensor`` TensorVariables."""

    framework_name = "pytensor"

    def __init__(self):
        from pytensor import tensor

        self.pt = tensor

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.pt.TensorVariable)

    def is_float_type(self, x):
        return x.dtype in self.pt.type.float_dtypes

    def from_numpy(self, x):
        return self.pt.as_tensor(x)

    def to_numpy(self, x):
        return x.eval()  # Will only work if there are no symbolic inputs

    def create_symbol(self, shape):
        # a bare (non-sequence) shape is treated as a 1-element shape
        dims = shape if isinstance(shape, (tuple, list)) else (shape,)
        return self.pt.tensor(shape=dims)

    def eval_symbol(self, symbol, symbol_value_pairs):
        givens = dict(symbol_value_pairs)
        return symbol.eval(givens)

    def arange(self, start, stop):
        return self.pt.arange(start, stop)

    def shape(self, x):
        # prefer the static dimension from the type; fall back to the symbolic one
        merged = []
        for static_dim, symbolic_dim in zip(x.type.shape, x.shape):
            merged.append(symbolic_dim if static_dim is None else static_dim)
        return tuple(merged)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.pt.stack(tensors)

    def tile(self, x, repeats):
        return self.pt.tile(x, repeats)

    def concat(self, tensors, axis: int):
        return self.pt.concatenate(tensors, axis=axis)

    def add_axis(self, x, new_position):
        return self.pt.expand_dims(x, new_position)

    def einsum(self, pattern, *x):
        return self.pt.einsum(pattern, *x)
|