# jamtur01's picture
# Upload folder using huggingface_hub
# 9c6594c verified
from typing import Any, Callable
from einops.tests import collect_test_backends
from einops.einops import _compactify_pattern_for_einsum, einsum, EinopsError
import numpy as np
import pytest
import string
class Arguments:
    """Stored positional and keyword arguments that can be replayed on a callable."""

    def __init__(self, *args: Any, **kargs: Any):
        # Keep both argument groups untouched until the instance is called.
        self.args, self.kwargs = args, kargs

    def __call__(self, function: Callable):
        """Invoke ``function`` with the stored arguments and return its result."""
        positional = self.args
        keyword = self.kwargs
        return function(*positional, **keyword)
# Cases consumed by `test_layer`. Each entry is
# (Arguments used to build the layer, input shape, expected output shape).
test_layer_cases = [
    # Weight "c_in c_out", no bias.
    (
        Arguments("b c_in h w -> w c_out h b", "c_in c_out", bias_shape=None, c_out=13, c_in=12),
        (2, 12, 3, 4),
        (4, 13, 3, 2),
    ),
    # Same pattern and weight, with a bias of shape "c_out".
    (
        Arguments("b c_in h w -> w c_out h b", "c_in c_out", bias_shape="c_out", c_out=13, c_in=12),
        (2, 12, 3, 4),
        (4, 13, 3, 2),
    ),
    # Empty weight shape, no bias.
    (
        Arguments("b c_in h w -> w c_in h b", "", bias_shape=None, c_in=12),
        (2, 12, 3, 4),
        (4, 12, 3, 2),
    ),
    # Weight "c_in h w c_out": every non-batch input axis appears in the weight.
    (
        Arguments("b c_in h w -> b c_out", "c_in h w c_out", bias_shape=None, c_in=12, h=3, w=4, c_out=5),
        (2, 12, 3, 4),
        (2, 5),
    ),
    # Weight "head c_in c_out": the `head` axis appears in input, weight and output.
    (
        Arguments("b t head c_in -> b t head c_out", "head c_in c_out", bias_shape=None, head=4, c_in=5, c_out=6),
        (2, 3, 4, 5),
        (2, 3, 4, 6),
    ),
]
# Each of the form:
# (einops_pattern, true_einsum_pattern, in_shapes, out_shape)
test_functional_cases = [
    # NOTE: `in_shapes` is always a tuple of shape tuples (one per input
    # tensor), even when there is a single input; `out_shape` is one shape tuple.
    (
        # Basic:
        "b c h w, b w -> b h",
        "abcd,ad->ac",
        ((2, 3, 4, 5), (2, 5)),
        (2, 4),
    ),
    (
        # Three tensors:
        "b c h w, b w, b c -> b h",
        "abcd,ad,ab->ac",
        ((2, 3, 40, 5), (2, 5), (2, 3)),
        (2, 40),
    ),
    (
        # Ellipsis, and full names:
        "... one two three, three four five -> ... two five",
        "...abc,cde->...be",
        ((32, 5, 2, 3, 4), (4, 5, 6)),
        (32, 5, 3, 6),
    ),
    (
        # Ellipsis at the end:
        "one two three ..., three four five -> two five ...",
        "abc...,cde->be...",
        ((2, 3, 4, 32, 5), (4, 5, 6)),
        (3, 6, 32, 5),
    ),
    (
        # Ellipsis on multiple tensors:
        "... one two three, ... three four five -> ... two five",
        "...abc,...cde->...be",
        ((32, 5, 2, 3, 4), (32, 5, 4, 5, 6)),
        (32, 5, 3, 6),
    ),
    (
        # One tensor, and underscores:
        "first_tensor second_tensor -> first_tensor",
        "ab->a",
        ((5, 4),),
        (5,),
    ),
    (
        # Trace (repeated index)
        "i i -> ",
        "aa->",
        ((5, 5),),
        (),
    ),
    (
        # Too many spaces in string:
        " one  two  ,  three four->two  four  ",
        "ab,cd->bd",
        ((2, 3), (4, 5)),
        (3, 5),
    ),
    # The following tests were inspired by numpy's einsum tests
    # https://github.com/numpy/numpy/blob/v1.23.0/numpy/core/tests/test_einsum.py
    (
        # Trace with other indices
        "i middle i -> middle",
        "aba->b",
        ((5, 10, 5),),
        (10,),
    ),
    (
        # Ellipsis in the middle:
        "i ... i -> ...",
        "a...a->...",
        ((5, 3, 2, 1, 4, 5),),
        (3, 2, 1, 4),
    ),
    (
        # Product of first and last axes:
        "i ... i -> i ...",
        "a...a->a...",
        ((5, 3, 2, 1, 4, 5),),
        (5, 3, 2, 1, 4),
    ),
    (
        # Triple diagonal
        "one one one -> one",
        "aaa->a",
        ((5, 5, 5),),
        (5,),
    ),
    (
        # Axis swap:
        "i j k -> j i k",
        "abc->bac",
        ((1, 2, 3),),
        (2, 1, 3),
    ),
    (
        # Identity:
        "... -> ...",
        "...->...",
        ((5, 4, 3, 2, 1),),
        (5, 4, 3, 2, 1),
    ),
    (
        # Elementwise product of three tensors
        "..., ..., ... -> ...",
        "...,...,...->...",
        ((3, 2), (3, 2), (3, 2)),
        (3, 2),
    ),
    (
        # Basic summation:
        "index ->",
        "a->",
        # The trailing comma matters: `((10,))` is just `(10,)`, which would
        # hand the bare int 10 to consumers that iterate over `in_shapes`.
        ((10,),),
        (),
    ),
]
def test_layer():
    """Build an EinMix layer per case and check the shape of its output.

    Backends come from ``collect_test_backends(layers=True, symbolic=False)``;
    only those whose ``framework_name`` is tensorflow, torch, oneflow or
    paddle are exercised, the rest are skipped.
    """
    layer_frameworks = ["tensorflow", "torch", "oneflow", "paddle"]
    for backend in collect_test_backends(layers=True, symbolic=False):
        if backend.framework_name not in layer_frameworks:
            continue
        einmix_type = backend.layers().EinMix
        for build_args, in_shape, out_shape in test_layer_cases:
            # `build_args` is an Arguments instance: calling it with the layer
            # class constructs the layer from the stored arguments.
            layer = build_args(einmix_type)
            print("Running", layer.einsum_pattern, "for", backend.framework_name)
            # Random float32 input, round-tripped through the backend.
            sample = np.random.uniform(size=in_shape).astype("float32")
            sample_framework = backend.from_numpy(sample)
            result_framework = layer(sample_framework)
            result = backend.to_numpy(result_framework)
            assert result.shape == out_shape
# Framework names for which the functional einsum tests are run; backends
# whose `framework_name` is not listed here are filtered out by those tests.
valid_backends_functional = [
    "tensorflow", "torch", "jax",
    "numpy", "oneflow", "cupy",
    "tensorflow.keras", "paddle", "pytensor",
]
def test_functional():
    """Check einsum on concrete arrays for every functional backend and case.

    For each case the compactified pattern must equal the expected one. The
    einsum is then evaluated twice, once through ``backend.einsum`` with the
    compact pattern and once through ``einops.einsum`` with the einops
    pattern, and each result is compared with ``np.einsum`` in shape and in
    value (to 5 decimals).

    Raises:
        ValueError: if an output shape differs from the expected shape.
    """
    for backend in collect_test_backends():
        # Skip backends that are not listed for the functional tests.
        if backend.framework_name not in valid_backends_functional:
            continue

        for einops_pattern, true_pattern, in_shapes, out_shape in test_functional_cases:
            print(f"Running '{einops_pattern}' for {backend.framework_name}")

            # The pattern translation itself is checked first.
            predicted_pattern = _compactify_pattern_for_einsum(einops_pattern)
            assert predicted_pattern == true_pattern

            # Deterministic example data: the generator is re-seeded per case.
            rstate = np.random.RandomState(0)
            numpy_inputs = [rstate.uniform(size=shape).astype("float32") for shape in in_shapes]
            framework_inputs = [backend.from_numpy(array) for array in numpy_inputs]

            # First the backend's own einsum, then `einops.einsum`.
            for call_backend_directly in (True, False):
                out_array = (
                    backend.einsum(predicted_pattern, *framework_inputs)
                    if call_backend_directly
                    else einsum(*framework_inputs, einops_pattern)
                )

                # Shape check:
                if tuple(out_array.shape) != out_shape:
                    raise ValueError(f"Expected output shape {out_shape} but got {out_array.shape}")

                # Value check against numpy's reference result:
                reference = np.einsum(true_pattern, *numpy_inputs)
                actual = backend.to_numpy(out_array)
                np.testing.assert_array_almost_equal(actual, reference, decimal=5)
def test_functional_symbolic():
    """Check einsum on symbolic tensors for every functional backend and case.

    Backends come from ``collect_test_backends(symbolic=True, layers=False)``
    and are restricted to ``valid_backends_functional``. For each case a
    symbol is created per input shape, the einsum is built both through
    ``backend.einsum`` and through ``einops.einsum``, evaluated on random
    float32 data, and compared with ``np.einsum`` in shape and in value
    (to 5 decimals).

    Raises:
        ValueError: if an evaluated output shape differs from the expected shape.
    """
    for backend in collect_test_backends(symbolic=True, layers=False):
        # Skip backends that are not listed for the functional tests.
        if backend.framework_name not in valid_backends_functional:
            continue

        for einops_pattern, true_pattern, in_shapes, out_shape in test_functional_cases:
            print(f"Running '{einops_pattern}' for symbolic {backend.framework_name}")

            # The pattern translation itself is checked first.
            predicted_pattern = _compactify_pattern_for_einsum(einops_pattern)
            assert predicted_pattern == true_pattern

            # Symbols first, then deterministic data to feed them with.
            rstate = np.random.RandomState(0)
            symbols = [backend.create_symbol(in_shape) for in_shape in in_shapes]
            feeds = [rstate.uniform(size=in_shape).astype("float32") for in_shape in in_shapes]
            expected_out_data = np.einsum(true_pattern, *feeds)

            # First the backend's own einsum, then `einops.einsum`.
            for call_backend_directly in (True, False):
                predicted_out_symbol = (
                    backend.einsum(predicted_pattern, *symbols)
                    if call_backend_directly
                    else einsum(*symbols, einops_pattern)
                )

                # Evaluate the symbolic result with each symbol bound to its data.
                bindings = list(zip(symbols, feeds))
                predicted_out_data = backend.eval_symbol(predicted_out_symbol, bindings)

                if predicted_out_data.shape != out_shape:
                    raise ValueError(f"Expected output shape {out_shape} but got {predicted_out_data.shape}")
                np.testing.assert_array_almost_equal(predicted_out_data, expected_out_data, decimal=5)
def test_functional_errors():
    """Check that invalid `einops.einsum` calls raise the expected errors.

    Each entry of the table below pairs an exception type and a regex for
    the start of its message with a zero-argument callable performing the
    offending call. The calls run in table order, each inside its own
    ``pytest.raises`` context.
    """
    # Specific backend does not matter, as errors are raised
    # during the pattern creation.
    rstate = np.random.RandomState(0)

    def create_tensor(*shape):
        # Random float32 numpy array of the requested shape.
        return rstate.uniform(size=shape).astype("float32")

    # Callables keep tensor creation lazy, so every tensor is drawn from
    # `rstate` only when its own check runs.
    failing_calls = [
        # raise NotImplementedError("Singleton () axes are not yet supported in einsum.")
        (NotImplementedError, "^Singleton", lambda: einsum(create_tensor(5, 1), "i () -> i")),
        # raise NotImplementedError("Shape rearrangement is not yet supported in einsum.")
        (NotImplementedError, "^Shape rearrangement", lambda: einsum(create_tensor(5, 1), "a b -> (a b)")),
        (NotImplementedError, "^Shape rearrangement", lambda: einsum(create_tensor(10, 1), "(a b) -> a b")),
        # raise RuntimeError("Encountered empty axis name in einsum.")
        # raise RuntimeError("Axis name in einsum must be a string.")
        # ^ Not tested, these are just a failsafe in case an unexpected error occurs.
        # raise NotImplementedError("Anonymous axes are not yet supported in einsum.")
        (NotImplementedError, "^Anonymous axes", lambda: einsum(create_tensor(5, 1), "i 2 -> i")),
        # ParsedExpression error:
        (EinopsError, "^Invalid axis identifier", lambda: einsum(create_tensor(5, 1), "i 2j -> i")),
        # raise ValueError("Einsum pattern must contain '->'.")
        (ValueError, "^Einsum pattern", lambda: einsum(create_tensor(5, 3, 2), "i j k")),
        # raise RuntimeError("Too many axes in einsum.")
        (
            RuntimeError,
            "^Too many axes",
            lambda: einsum(create_tensor(1), " ".join(string.ascii_letters) + " extra ->"),
        ),
        # raise RuntimeError("Unknown axis on right side of einsum.")
        (RuntimeError, "^Unknown axis", lambda: einsum(create_tensor(5, 1), "i j -> k")),
        # raise ValueError(
        #     "The last argument passed to `einops.einsum` must be a string,"
        #     " representing the einsum pattern."
        # )
        (ValueError, "^The last argument", lambda: einsum("i j k -> i", create_tensor(5, 4, 3))),
        # raise ValueError(
        #     "`einops.einsum` takes at minimum two arguments: the tensors,"
        #     " followed by the pattern."
        # )
        (ValueError, "^`einops.einsum` takes", lambda: einsum("i j k -> i")),
        (ValueError, "^`einops.einsum` takes", lambda: einsum(create_tensor(5, 1))),
    ]

    for expected_error, message_regex, offending_call in failing_calls:
        with pytest.raises(expected_error, match=message_regex):
            offending_call()

    # TODO: Include check for giving normal einsum pattern rather than einops.