|
import pickle |
|
from collections import namedtuple |
|
|
|
import numpy |
|
import pytest |
|
|
|
from einops import rearrange, reduce, EinopsError |
|
from einops.tests import collect_test_backends, is_backend_tested, FLOAT_REDUCTIONS as REDUCTIONS |
|
|
|
__author__ = "Alex Rogozhnikov" |
|
|
|
testcase = namedtuple("testcase", ["pattern", "axes_lengths", "input_shape", "wrong_shapes"]) |
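# each testcase: an einops pattern, known axes_lengths, one valid input shape, and shapes the layer must reject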
|
|
|
rearrangement_patterns = [ |
|
testcase( |
|
"b c h w -> b (c h w)", |
|
dict(c=20), |
|
(10, 20, 30, 40), |
|
        [(), (10,), (10, 10, 10), (10, 21, 30, 40), (1, 20, 1, 1, 1)],
|
), |
|
testcase( |
|
"b c (h1 h2) (w1 w2) -> b (c h2 w2) h1 w1", |
|
dict(h2=2, w2=2), |
|
(10, 20, 30, 40), |
|
[(), (1, 1, 1, 1), (1, 10, 3), ()], |
|
), |
|
testcase( |
|
"b ... c -> c b ...", |
|
dict(b=10), |
|
(10, 20, 30), |
|
[(), (10,), (5, 10)], |
|
), |
|
] |
|
|
|
|
|
def test_rearrange_imperative(): |
|
for backend in collect_test_backends(symbolic=False, layers=True): |
|
print("Test layer for ", backend.framework_name) |
|
|
|
for pattern, axes_lengths, input_shape, wrong_shapes in rearrangement_patterns: |
|
x = numpy.arange(numpy.prod(input_shape), dtype="float32").reshape(input_shape) |
|
result_numpy = rearrange(x, pattern, **axes_lengths) |
|
layer = backend.layers().Rearrange(pattern, **axes_lengths) |
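            # every shape in wrong_shapes must make the layer fail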
|
for shape in wrong_shapes: |
|
try: |
|
layer(backend.from_numpy(numpy.zeros(shape, dtype="float32"))) |
|
except BaseException: |
|
pass |
|
else: |
|
raise AssertionError("Failure expected") |
|
|
|
|
|
layer2 = pickle.loads(pickle.dumps(layer)) |
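            # a layer restored from pickle must give exactly the same result as the original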
|
result1 = backend.to_numpy(layer(backend.from_numpy(x))) |
|
result2 = backend.to_numpy(layer2(backend.from_numpy(x))) |
|
assert numpy.allclose(result_numpy, result1) |
|
assert numpy.allclose(result1, result2) |
|
|
|
just_sum = backend.layers().Reduce("...->", reduction="sum") |
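            # rearrangement is shape-only, so the gradient of the total sum w.r.t. the input is all ones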
|
|
|
variable = backend.from_numpy(x) |
|
result = just_sum(layer(variable)) |
|
|
|
result.backward() |
|
assert numpy.allclose(backend.to_numpy(variable.grad), 1) |
|
|
|
|
|
def test_rearrange_symbolic(): |
|
for backend in collect_test_backends(symbolic=True, layers=True): |
|
print("Test layer for ", backend.framework_name) |
|
|
|
for pattern, axes_lengths, input_shape, wrong_shapes in rearrangement_patterns: |
|
x = numpy.arange(numpy.prod(input_shape), dtype="float32").reshape(input_shape) |
|
result_numpy = rearrange(x, pattern, **axes_lengths) |
|
layer = backend.layers().Rearrange(pattern, **axes_lengths) |
|
input_shape_of_nones = [None] * len(input_shape) |
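            # symbolic backends are checked with a static shape and with a fully dynamic (all-None) shape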
|
shapes = [input_shape, input_shape_of_nones] |
|
|
|
for shape in shapes: |
|
symbol = backend.create_symbol(shape) |
|
eval_inputs = [(symbol, x)] |
|
|
|
result_symbol1 = layer(symbol) |
|
result1 = backend.eval_symbol(result_symbol1, eval_inputs) |
|
assert numpy.allclose(result_numpy, result1) |
|
|
|
layer2 = pickle.loads(pickle.dumps(layer)) |
|
result_symbol2 = layer2(symbol) |
|
result2 = backend.eval_symbol(result_symbol2, eval_inputs) |
|
assert numpy.allclose(result1, result2) |
|
|
|
|
|
just_sum = backend.layers().Reduce("...->", reduction="sum") |
|
|
|
result_sum1 = backend.eval_symbol(just_sum(result_symbol1), eval_inputs) |
|
result_sum2 = numpy.sum(x) |
|
|
|
assert numpy.allclose(result_sum1, result_sum2) |
|
|
|
|
|
reduction_patterns = rearrangement_patterns + [ |
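    # additional patterns where some axes are reduced away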
|
testcase("b c h w -> b ()", dict(b=10), (10, 20, 30, 40), [(10,), (10, 20, 30)]), |
|
testcase("b c (h1 h2) (w1 w2) -> b c h1 w1", dict(h1=15, h2=2, w2=2), (10, 20, 30, 40), [(10, 20, 31, 40)]), |
|
testcase("b ... c -> b", dict(b=10), (10, 20, 30, 40), [(10,), (11, 10)]), |
|
] |
|
|
|
|
|
def test_reduce_imperative(): |
|
for backend in collect_test_backends(symbolic=False, layers=True): |
|
print("Test layer for ", backend.framework_name) |
|
for reduction in REDUCTIONS: |
|
for pattern, axes_lengths, input_shape, wrong_shapes in reduction_patterns: |
|
print(backend, reduction, pattern, axes_lengths, input_shape, wrong_shapes) |
|
x = numpy.arange(1, 1 + numpy.prod(input_shape), dtype="float32").reshape(input_shape) |
|
x /= x.mean() |
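                # keep values close to 1 so every reduction in REDUCTIONS stays in a comfortable numeric range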
|
result_numpy = reduce(x, pattern, reduction, **axes_lengths) |
|
layer = backend.layers().Reduce(pattern, reduction, **axes_lengths) |
|
for shape in wrong_shapes: |
|
try: |
|
layer(backend.from_numpy(numpy.zeros(shape, dtype="float32"))) |
|
except BaseException: |
|
pass |
|
else: |
|
raise AssertionError("Failure expected") |
|
|
|
|
|
layer2 = pickle.loads(pickle.dumps(layer)) |
|
result1 = backend.to_numpy(layer(backend.from_numpy(x))) |
|
result2 = backend.to_numpy(layer2(backend.from_numpy(x))) |
|
assert numpy.allclose(result_numpy, result1) |
|
assert numpy.allclose(result1, result2) |
|
|
|
just_sum = backend.layers().Reduce("...->", reduction="sum") |
|
|
|
variable = backend.from_numpy(x) |
|
result = just_sum(layer(variable)) |
|
|
|
result.backward() |
|
grad = backend.to_numpy(variable.grad) |
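                # expected gradients of the summed output: sum -> all ones, mean -> a constant, max/min -> 0/1 indicators of the extrema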
|
if reduction == "sum": |
|
assert numpy.allclose(grad, 1) |
|
if reduction == "mean": |
|
assert numpy.allclose(grad, grad.min()) |
|
if reduction in ["max", "min"]: |
|
                    assert numpy.all(numpy.isin(grad, [0, 1]))
|
assert numpy.sum(grad) > 0.5 |
|
|
|
|
|
def test_reduce_symbolic(): |
|
for backend in collect_test_backends(symbolic=True, layers=True): |
|
print("Test layer for ", backend.framework_name) |
|
for reduction in REDUCTIONS: |
|
for pattern, axes_lengths, input_shape, wrong_shapes in reduction_patterns: |
|
x = numpy.arange(1, 1 + numpy.prod(input_shape), dtype="float32").reshape(input_shape) |
|
x /= x.mean() |
|
result_numpy = reduce(x, pattern, reduction, **axes_lengths) |
|
layer = backend.layers().Reduce(pattern, reduction, **axes_lengths) |
|
input_shape_of_nones = [None] * len(input_shape) |
|
shapes = [input_shape, input_shape_of_nones] |
|
|
|
for shape in shapes: |
|
symbol = backend.create_symbol(shape) |
|
eval_inputs = [(symbol, x)] |
|
|
|
result_symbol1 = layer(symbol) |
|
result1 = backend.eval_symbol(result_symbol1, eval_inputs) |
|
assert numpy.allclose(result_numpy, result1) |
|
|
|
layer2 = pickle.loads(pickle.dumps(layer)) |
|
result_symbol2 = layer2(symbol) |
|
result2 = backend.eval_symbol(result_symbol2, eval_inputs) |
|
assert numpy.allclose(result1, result2) |
|
|
|
|
|
def create_torch_model(use_reduce=False, add_scripted_layer=False): |
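    """Small LeNet-like network that mixes einops layers (Reduce, Rearrange, EinMix) with regular torch layers."""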
|
if not is_backend_tested("torch"): |
|
pytest.skip() |
|
else: |
|
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU |
|
from einops.layers.torch import Rearrange, Reduce, EinMix |
|
import torch.jit |
|
|
|
return Sequential( |
|
Conv2d(3, 6, kernel_size=(5, 5)), |
|
Reduce("b c (h h2) (w w2) -> b c h w", "max", h2=2, w2=2) if use_reduce else MaxPool2d(kernel_size=2), |
|
Conv2d(6, 16, kernel_size=(5, 5)), |
|
Reduce("b c (h h2) (w w2) -> b c h w", "max", h2=2, w2=2), |
|
torch.jit.script(Rearrange("b c h w -> b (c h w)")) |
|
if add_scripted_layer |
|
else Rearrange("b c h w -> b (c h w)"), |
|
Linear(16 * 5 * 5, 120), |
|
ReLU(), |
|
Linear(120, 84), |
|
ReLU(), |
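            # a pair of EinMix layers that temporarily merges the batch axis into the feature axis (exercises EinMix reshaping)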
|
EinMix("b c1 -> (b c2)", weight_shape="c1 c2", bias_shape="c2", c1=84, c2=84), |
|
EinMix("(b c2) -> b c3", weight_shape="c2 c3", bias_shape="c3", c2=84, c3=84), |
|
Linear(84, 10), |
|
) |
|
|
|
|
|
def test_torch_layer(): |
|
if not is_backend_tested("torch"): |
|
pytest.skip() |
|
else: |
|
|
|
import torch |
|
import torch.jit |
|
|
|
model1 = create_torch_model(use_reduce=True) |
|
model2 = create_torch_model(use_reduce=False) |
|
input = torch.randn([10, 3, 32, 32]) |
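        # freshly initialized models give different outputs until the weights of one are copied into the other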
|
|
|
assert not torch.allclose(model1(input), model2(input)) |
|
model2.load_state_dict(pickle.loads(pickle.dumps(model1.state_dict()))) |
|
assert torch.allclose(model1(input), model2(input)) |
|
|
|
|
|
model3 = torch.jit.trace(model2, example_inputs=input) |
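        # a traced copy must reproduce the original model, both on the tracing input and on shifted data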
|
torch.testing.assert_close(model1(input), model3(input), atol=1e-3, rtol=1e-3) |
|
torch.testing.assert_close(model1(input + 1), model3(input + 1), atol=1e-3, rtol=1e-3) |
|
|
|
model4 = torch.jit.trace(model2, example_inputs=input) |
|
torch.testing.assert_close(model1(input), model4(input), atol=1e-3, rtol=1e-3) |
|
torch.testing.assert_close(model1(input + 1), model4(input + 1), atol=1e-3, rtol=1e-3) |
|
|
|
|
|
def test_torch_layers_scripting(): |
|
if not is_backend_tested("torch"): |
|
pytest.skip() |
|
else: |
|
import torch |
|
|
|
for script_layer in [False, True]: |
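            # scripting must work for a plain model and for one that already contains a scripted Rearrange layer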
|
model1 = create_torch_model(use_reduce=True, add_scripted_layer=script_layer) |
|
model2 = torch.jit.script(model1) |
|
input = torch.randn([10, 3, 32, 32]) |
|
|
|
torch.testing.assert_close(model1(input), model2(input), atol=1e-3, rtol=1e-3) |
|
|
|
|
|
def test_keras_layer(): |
|
if not is_backend_tested("tensorflow"): |
|
pytest.skip() |
|
else: |
|
import tensorflow as tf |
|
|
|
if tf.__version__ < "2.16.": |
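            # skip older TF versions: these layers are only exercised against the newer (keras 3) interface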
|
|
|
pytest.skip() |
|
from tensorflow.keras.models import Sequential |
|
from tensorflow.keras.layers import Conv2D as Conv2d, Dense as Linear, ReLU |
|
from einops.layers.keras import Rearrange, Reduce, EinMix, keras_custom_objects |
|
|
|
def create_keras_model(): |
|
return Sequential( |
|
[ |
|
Conv2d(6, kernel_size=5, input_shape=[32, 32, 3]), |
|
Reduce("b c (h h2) (w w2) -> b c h w", "max", h2=2, w2=2), |
|
Conv2d(16, kernel_size=5), |
|
Reduce("b c (h h2) (w w2) -> b c h w", "max", h2=2, w2=2), |
|
Rearrange("b c h w -> b (c h w)"), |
|
Linear(120), |
|
ReLU(), |
|
Linear(84), |
|
ReLU(), |
|
EinMix("b c1 -> (b c2)", weight_shape="c1 c2", bias_shape="c2", c1=84, c2=84), |
|
EinMix("(b c2) -> b c3", weight_shape="c2 c3", bias_shape="c3", c2=84, c3=84), |
|
Linear(10), |
|
] |
|
) |
|
|
|
model1 = create_keras_model() |
|
model2 = create_keras_model() |
|
|
|
input = numpy.random.normal(size=[10, 32, 32, 3]).astype("float32") |
|
|
|
assert not numpy.allclose(model1.predict_on_batch(input), model2.predict_on_batch(input)) |
|
|
|
|
|
tmp_model_filename = "/tmp/einops_tf_model.h5" |
|
|
|
print("temp_path_keras1", tmp_model_filename) |
|
tf.keras.models.save_model(model1, tmp_model_filename) |
|
model3 = tf.keras.models.load_model(tmp_model_filename, custom_objects=keras_custom_objects) |
|
|
|
numpy.testing.assert_allclose(model1.predict_on_batch(input), model3.predict_on_batch(input)) |
|
|
|
weight_filename = "/tmp/einops_tf_model.weights.h5" |
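        # rebuild the architecture from JSON and load the saved weights into fresh models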
|
|
|
model4 = tf.keras.models.model_from_json(model1.to_json(), custom_objects=keras_custom_objects) |
|
model1.save_weights(weight_filename) |
|
model4.load_weights(weight_filename) |
|
model2.load_weights(weight_filename) |
|
|
|
numpy.testing.assert_allclose(model1.predict_on_batch(input), model2.predict_on_batch(input)) |
|
|
|
|
|
numpy.testing.assert_allclose(model1.predict_on_batch(input), model4.predict_on_batch(input)) |
|
|
|
|
|
def test_flax_layers(): |
|
""" |
|
One-off simple tests for Flax layers. |
|
Unfortunately, Flax layers have a different interface from other layers. |
|
""" |
|
if not is_backend_tested("jax"): |
|
pytest.skip() |
|
else: |
|
import jax |
|
import jax.numpy as jnp |
|
|
|
import flax |
|
from flax import linen as nn |
|
from einops.layers.flax import EinMix, Reduce, Rearrange |
|
|
|
class NN(nn.Module): |
|
@nn.compact |
|
def __call__(self, x): |
|
x = EinMix( |
|
"b (h h2) (w w2) c -> b h w c_out", "h2 w2 c c_out", "c_out", sizes=dict(h2=2, w2=3, c=4, c_out=5) |
|
)(x) |
|
x = Rearrange("b h w c -> b (w h c)", sizes=dict(c=5))(x) |
|
x = Reduce("b hwc -> b", "mean", dict(hwc=2 * 3 * 5))(x) |
|
return x |
|
|
|
model = NN() |
|
fixed_input = jnp.ones([10, 2 * 2, 3 * 3, 4]) |
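        # spatial sizes are chosen as multiples of h2=2 and w2=3 so the EinMix pattern decomposes evenly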
|
params = model.init(jax.random.PRNGKey(0), fixed_input) |
|
|
|
def eval_at_point(params): |
|
return jnp.linalg.norm(model.apply(params, fixed_input)) |
|
|
|
vandg = jax.value_and_grad(eval_at_point) |
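        # gradient evaluation must agree with the plain forward pass; a small step along the gradient must not increase the norm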
|
value0 = eval_at_point(params) |
|
value1, grad1 = vandg(params) |
|
assert jnp.allclose(value0, value1) |
|
|
|
        params2 = jax.tree_util.tree_map(lambda x1, x2: x1 - x2 * 0.001, params, grad1)  # one small gradient-descent step
|
|
|
value2 = eval_at_point(params2) |
|
assert value0 >= value2, (value0, value2) |
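        # parameters that include einops layers should at least serialize and deserialize without errors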
|
|
|
|
|
fbytes = flax.serialization.to_bytes(params) |
|
_loaded = flax.serialization.from_bytes(params, fbytes) |
|
|
|
|
|
def test_einmix_decomposition(): |
|
""" |
|
    Testing that EinMix correctly decomposes into smaller operations (optional reshapes around an einsum).
|
""" |
|
from einops.layers._einmix import _EinmixDebugger |
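    # _EinmixDebugger exposes the intermediate patterns and shapes that EinMix builds internally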
|
|
|
mixin1 = _EinmixDebugger( |
|
"a b c d e -> e d c b a", |
|
weight_shape="d a b", |
|
d=2, a=3, b=5, |
|
) |
|
assert mixin1.pre_reshape_pattern is None |
|
assert mixin1.post_reshape_pattern is None |
|
assert mixin1.einsum_pattern == "abcde,dab->edcba" |
|
assert mixin1.saved_weight_shape == [2, 3, 5] |
|
assert mixin1.saved_bias_shape is None |
|
|
|
mixin2 = _EinmixDebugger( |
|
"a b c d e -> e d c b a", |
|
weight_shape="d a b", |
|
bias_shape="a b c d e", |
|
a=1, b=2, c=3, d=4, e=5, |
|
) |
|
assert mixin2.pre_reshape_pattern is None |
|
assert mixin2.post_reshape_pattern is None |
|
assert mixin2.einsum_pattern == "abcde,dab->edcba" |
|
assert mixin2.saved_weight_shape == [4, 1, 2] |
|
assert mixin2.saved_bias_shape == [5, 4, 3, 2, 1] |
|
|
|
mixin3 = _EinmixDebugger( |
|
"... -> ...", |
|
weight_shape="", |
|
bias_shape="", |
|
) |
|
assert mixin3.pre_reshape_pattern is None |
|
assert mixin3.post_reshape_pattern is None |
|
assert mixin3.einsum_pattern == "...,->..." |
|
assert mixin3.saved_weight_shape == [] |
|
assert mixin3.saved_bias_shape == [] |
|
|
|
mixin4 = _EinmixDebugger( |
|
"b a ... -> b c ...", |
|
weight_shape="b a c", |
|
a=1, b=2, c=3, |
|
) |
|
assert mixin4.pre_reshape_pattern is None |
|
assert mixin4.post_reshape_pattern is None |
|
assert mixin4.einsum_pattern == "ba...,bac->bc..." |
|
assert mixin4.saved_weight_shape == [2, 1, 3] |
|
assert mixin4.saved_bias_shape is None |
|
|
|
mixin5 = _EinmixDebugger( |
|
"(b a) ... -> b c (...)", |
|
weight_shape="b a c", |
|
a=1, b=2, c=3, |
|
) |
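    # composed axes in the pattern require explicit reshapes before/after the einsum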
|
assert mixin5.pre_reshape_pattern == "(b a) ... -> b a ..." |
|
assert mixin5.pre_reshape_lengths == dict(a=1, b=2) |
|
assert mixin5.post_reshape_pattern == "b c ... -> b c (...)" |
|
assert mixin5.einsum_pattern == "ba...,bac->bc..." |
|
assert mixin5.saved_weight_shape == [2, 1, 3] |
|
assert mixin5.saved_bias_shape is None |
|
|
|
mixin6 = _EinmixDebugger( |
|
"b ... (a c) -> b ... (a d)", |
|
weight_shape="c d", |
|
bias_shape="a d", |
|
a=1, c=3, d=4, |
|
) |
|
assert mixin6.pre_reshape_pattern == "b ... (a c) -> b ... a c" |
|
assert mixin6.pre_reshape_lengths == dict(a=1, c=3) |
|
assert mixin6.post_reshape_pattern == "b ... a d -> b ... (a d)" |
|
assert mixin6.einsum_pattern == "b...ac,cd->b...ad" |
|
assert mixin6.saved_weight_shape == [3, 4] |
|
assert mixin6.saved_bias_shape == [1, 1, 4] |
|
|
|
mixin7 = _EinmixDebugger( |
|
"a ... (b c) -> a (... d b)", |
|
weight_shape="c d b", |
|
bias_shape="d b", |
|
b=2, c=3, d=4, |
|
) |
|
assert mixin7.pre_reshape_pattern == "a ... (b c) -> a ... b c" |
|
assert mixin7.pre_reshape_lengths == dict(b=2, c=3) |
|
assert mixin7.post_reshape_pattern == "a ... d b -> a (... d b)" |
|
assert mixin7.einsum_pattern == "a...bc,cdb->a...db" |
|
assert mixin7.saved_weight_shape == [3, 4, 2] |
|
assert mixin7.saved_bias_shape == [1, 4, 2] |
|
|
|
|
|
def test_einmix_restrictions(): |
|
""" |
|
    Testing that invalid EinMix definitions are rejected with EinopsError.
|
""" |
|
from einops.layers._einmix import _EinmixDebugger |
|
|
|
with pytest.raises(EinopsError): |
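        # axis b is used in weight_shape but its length is never specified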
|
_EinmixDebugger( |
|
"a b c d e -> e d c b a", |
|
weight_shape="d a b", |
|
d=2, a=3, |
|
) |
|
|
|
with pytest.raises(EinopsError): |
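        # weight_shape mentions axis w, which does not occur in the pattern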
|
_EinmixDebugger( |
|
"a b c d e -> e d c b a", |
|
weight_shape="w a b", |
|
d=2, a=3, b=1 |
|
) |
|
|
|
with pytest.raises(EinopsError): |
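        # ellipsis inside a composition, as in "(...)", is not supported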
|
_EinmixDebugger( |
|
"(...) a -> ... a", |
|
weight_shape="a", a=1, |
|
) |
|
|
|
with pytest.raises(EinopsError): |
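        # same restriction when a bias is declared as well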
|
_EinmixDebugger( |
|
"(...) a -> a ...", |
|
weight_shape="a", a=1, |
|
bias_shape='a', |
|
) |
|
|