|
import string
from typing import Any, Callable

import numpy as np
import pytest

from einops.einops import _compactify_pattern_for_einsum, einsum, EinopsError
from einops.tests import collect_test_backends
|
|
class Arguments:
    def __init__(self, *args: Any, **kwargs: Any):
        self.args = args
        self.kwargs = kwargs

    def __call__(self, function: Callable):
        return function(*self.args, **self.kwargs)
|
|
# Each case: (EinMix constructor arguments, input shape, expected output shape).
test_layer_cases = [
    (
        Arguments("b c_in h w -> w c_out h b", "c_in c_out", bias_shape=None, c_out=13, c_in=12),
        (2, 12, 3, 4),
        (4, 13, 3, 2),
    ),
    (
        Arguments("b c_in h w -> w c_out h b", "c_in c_out", bias_shape="c_out", c_out=13, c_in=12),
        (2, 12, 3, 4),
        (4, 13, 3, 2),
    ),
    (
        Arguments("b c_in h w -> w c_in h b", "", bias_shape=None, c_in=12),
        (2, 12, 3, 4),
        (4, 12, 3, 2),
    ),
    (
        Arguments("b c_in h w -> b c_out", "c_in h w c_out", bias_shape=None, c_in=12, h=3, w=4, c_out=5),
        (2, 12, 3, 4),
        (2, 5),
    ),
    (
        Arguments("b t head c_in -> b t head c_out", "head c_in c_out", bias_shape=None, head=4, c_in=5, c_out=6),
        (2, 3, 4, 5),
        (2, 3, 4, 6),
    ),
]
|
|
# Each case: (einops pattern, expected compact einsum pattern, input shapes, expected output shape).
test_functional_cases = [
    (
        "b c h w, b w -> b h",
        "abcd,ad->ac",
        ((2, 3, 4, 5), (2, 5)),
        (2, 4),
    ),
    (
        "b c h w, b w, b c -> b h",
        "abcd,ad,ab->ac",
        ((2, 3, 40, 5), (2, 5), (2, 3)),
        (2, 40),
    ),
    (
        "... one two three, three four five -> ... two five",
        "...abc,cde->...be",
        ((32, 5, 2, 3, 4), (4, 5, 6)),
        (32, 5, 3, 6),
    ),
    (
        "one two three ..., three four five -> two five ...",
        "abc...,cde->be...",
        ((2, 3, 4, 32, 5), (4, 5, 6)),
        (3, 6, 32, 5),
    ),
    (
        "... one two three, ... three four five -> ... two five",
        "...abc,...cde->...be",
        ((32, 5, 2, 3, 4), (32, 5, 4, 5, 6)),
        (32, 5, 3, 6),
    ),
    (
        "first_tensor second_tensor -> first_tensor",
        "ab->a",
        ((5, 4),),
        (5,),
    ),
    (
        "i i -> ",
        "aa->",
        ((5, 5),),
        (),
    ),
    (
        " one two , three four->two four ",
        "ab,cd->bd",
        ((2, 3), (4, 5)),
        (3, 5),
    ),
    (
        "i middle i -> middle",
        "aba->b",
        ((5, 10, 5),),
        (10,),
    ),
    (
        "i ... i -> ...",
        "a...a->...",
        ((5, 3, 2, 1, 4, 5),),
        (3, 2, 1, 4),
    ),
    (
        "i ... i -> i ...",
        "a...a->a...",
        ((5, 3, 2, 1, 4, 5),),
        (5, 3, 2, 1, 4),
    ),
    (
        "one one one -> one",
        "aaa->a",
        ((5, 5, 5),),
        (5,),
    ),
    (
        "i j k -> j i k",
        "abc->bac",
        ((1, 2, 3),),
        (2, 1, 3),
    ),
    (
        "... -> ...",
        "...->...",
        ((5, 4, 3, 2, 1),),
        (5, 4, 3, 2, 1),
    ),
    (
        "..., ..., ... -> ...",
        "...,...,...->...",
        ((3, 2), (3, 2), (3, 2)),
        (3, 2),
    ),
    (
        "index ->",
        "a->",
        ((10,),),
        (),
    ),
]
|
|
def test_layer():
    for backend in collect_test_backends(layers=True, symbolic=False):
        if backend.framework_name in ["tensorflow", "torch", "oneflow", "paddle"]:
            layer_type = backend.layers().EinMix
            for args, in_shape, out_shape in test_layer_cases:
                layer = args(layer_type)
                print("Running", layer.einsum_pattern, "for", backend.framework_name)
                input = np.random.uniform(size=in_shape).astype("float32")
                input_framework = backend.from_numpy(input)
                output_framework = layer(input_framework)
                output = backend.to_numpy(output_framework)
                assert output.shape == out_shape
|
|
valid_backends_functional = [
    "tensorflow",
    "torch",
    "jax",
    "numpy",
    "oneflow",
    "cupy",
    "tensorflow.keras",
    "paddle",
    "pytensor",
]
|
|
def test_functional():
    backends = filter(lambda x: x.framework_name in valid_backends_functional, collect_test_backends())
    for backend in backends:
        for einops_pattern, true_pattern, in_shapes, out_shape in test_functional_cases:
            print(f"Running '{einops_pattern}' for {backend.framework_name}")

            # The einops pattern should compact to the expected raw einsum pattern.
            predicted_pattern = _compactify_pattern_for_einsum(einops_pattern)
            assert predicted_pattern == true_pattern

            rstate = np.random.RandomState(0)
            in_arrays = [rstate.uniform(size=shape).astype("float32") for shape in in_shapes]
            in_arrays_framework = [backend.from_numpy(array) for array in in_arrays]

            # Check both the backend's raw einsum and the einops functional API.
            for do_manual_call in [True, False]:
                if do_manual_call:
                    out_array = backend.einsum(predicted_pattern, *in_arrays_framework)
                else:
                    out_array = einsum(*in_arrays_framework, einops_pattern)

                if tuple(out_array.shape) != out_shape:
                    raise ValueError(f"Expected output shape {out_shape} but got {out_array.shape}")

                true_out_array = np.einsum(true_pattern, *in_arrays)
                predicted_out_array = backend.to_numpy(out_array)
                np.testing.assert_array_almost_equal(predicted_out_array, true_out_array, decimal=5)
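

# A standalone, numpy-only sanity check of the calling convention exercised above
# (tensors first, pattern last). This is a minimal sketch: the test name, shapes,
# and pattern here are illustrative choices, not taken from the case tables.
def test_functional_numpy_smoke():
    rstate = np.random.RandomState(0)
    x = rstate.uniform(size=(2, 3, 4)).astype("float32")
    y = rstate.uniform(size=(4, 5)).astype("float32")
    # "batch time c_in, c_in c_out -> batch time c_out" compacts to "abc,cd->abd".
    out = einsum(x, y, "batch time c_in, c_in c_out -> batch time c_out")
    assert out.shape == (2, 3, 5)
    np.testing.assert_array_almost_equal(out, np.einsum("btc,cd->btd", x, y), decimal=5)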
|
|
def test_functional_symbolic():
    backends = filter(
        lambda x: x.framework_name in valid_backends_functional, collect_test_backends(symbolic=True, layers=False)
    )
    for backend in backends:
        for einops_pattern, true_pattern, in_shapes, out_shape in test_functional_cases:
            print(f"Running '{einops_pattern}' for symbolic {backend.framework_name}")

            predicted_pattern = _compactify_pattern_for_einsum(einops_pattern)
            assert predicted_pattern == true_pattern

            rstate = np.random.RandomState(0)
            in_syms = [backend.create_symbol(in_shape) for in_shape in in_shapes]
            in_data = [rstate.uniform(size=in_shape).astype("float32") for in_shape in in_shapes]

            expected_out_data = np.einsum(true_pattern, *in_data)

            for do_manual_call in [True, False]:
                if do_manual_call:
                    predicted_out_symbol = backend.einsum(predicted_pattern, *in_syms)
                else:
                    predicted_out_symbol = einsum(*in_syms, einops_pattern)

                predicted_out_data = backend.eval_symbol(
                    predicted_out_symbol,
                    list(zip(in_syms, in_data)),
                )
                if predicted_out_data.shape != out_shape:
                    raise ValueError(f"Expected output shape {out_shape} but got {predicted_out_data.shape}")
                np.testing.assert_array_almost_equal(predicted_out_data, expected_out_data, decimal=5)
|
|
def test_functional_errors():
    # Check that common misuses of einops.einsum fail with informative messages.
    rstate = np.random.RandomState(0)

    def create_tensor(*shape):
        return rstate.uniform(size=shape).astype("float32")

    # Singleton axes are not supported.
    with pytest.raises(NotImplementedError, match="^Singleton"):
        einsum(
            create_tensor(5, 1),
            "i () -> i",
        )

    # Shape rearrangement (grouping or ungrouping axes) is not supported.
    with pytest.raises(NotImplementedError, match="^Shape rearrangement"):
        einsum(
            create_tensor(5, 1),
            "a b -> (a b)",
        )

    with pytest.raises(NotImplementedError, match="^Shape rearrangement"):
        einsum(
            create_tensor(10, 1),
            "(a b) -> a b",
        )

    # Anonymous axes are not supported.
    with pytest.raises(NotImplementedError, match="^Anonymous axes"):
        einsum(
            create_tensor(5, 1),
            "i 2 -> i",
        )

    # Axis names must be valid identifiers.
    with pytest.raises(EinopsError, match="^Invalid axis identifier"):
        einsum(
            create_tensor(5, 1),
            "i 2j -> i",
        )

    # The pattern must contain "->".
    with pytest.raises(ValueError, match="^Einsum pattern"):
        einsum(
            create_tensor(5, 3, 2),
            "i j k",
        )

    # More axes than available single-letter einsum symbols.
    with pytest.raises(RuntimeError, match="^Too many axes"):
        einsum(
            create_tensor(1),
            " ".join(string.ascii_letters) + " extra ->",
        )

    # Output axes must appear on the input side.
    with pytest.raises(RuntimeError, match="^Unknown axis"):
        einsum(
            create_tensor(5, 1),
            "i j -> k",
        )

    # The pattern must be the last argument, after the tensors.
    with pytest.raises(ValueError, match="^The last argument"):
        einsum(
            "i j k -> i",
            create_tensor(5, 4, 3),
        )

    # At least one tensor and a pattern are required.
    with pytest.raises(ValueError, match="^`einops.einsum` takes"):
        einsum(
            "i j k -> i",
        )
    with pytest.raises(ValueError, match="^`einops.einsum` takes"):
        einsum(
            create_tensor(5, 1),
        )
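

# For contrast with the error cases above, a minimal positive-path sketch of the same
# calling convention (tensors first, pattern last), numpy-only. The test name, shape,
# and pattern here are illustrative choices, not part of the case tables above.
def test_functional_errors_positive_path():
    rstate = np.random.RandomState(0)
    x = rstate.uniform(size=(5, 4, 3)).astype("float32")
    # "i j k -> i" compacts to "abc->a": reduce over the last two axes.
    out = einsum(x, "i j k -> i")
    assert out.shape == (5,)
    np.testing.assert_array_almost_equal(out, np.einsum("abc->a", x), decimal=5)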