File size: 10,977 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 |
from typing import Any, Callable
from einops.tests import collect_test_backends
from einops.einops import _compactify_pattern_for_einsum, einsum, EinopsError
import numpy as np
import pytest
import string
class Arguments:
    """Deferred call: capture positional and keyword arguments for later.

    Invoking an instance with a callable applies the captured arguments to
    it, so ``Arguments(1, k=2)(f)`` evaluates ``f(1, k=2)``.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        self.args = args
        self.kwargs = kwargs

    def __call__(self, function: Callable):
        positional, keyword = self.args, self.kwargs
        return function(*positional, **keyword)
# Each case has the form:
#   (Arguments, input_shape, expected_output_shape)
# test_layer applies the stored Arguments to the backend's EinMix layer class,
# i.e. EinMix("<pattern>", "<second positional string>", bias_shape=..., **rest).
# NOTE(review): the second positional string is presumably EinMix's weight
# shape and the remaining keyword arguments (e.g. c_out=13) presumably give
# axis lengths -- confirm against the EinMix signature.
test_layer_cases = [
# Map c_in to c_out while permuting axes, no bias.
(
Arguments("b c_in h w -> w c_out h b", "c_in c_out", bias_shape=None, c_out=13, c_in=12),
(2, 12, 3, 4),
(4, 13, 3, 2),
),
# Same mapping, with bias_shape="c_out".
(
Arguments("b c_in h w -> w c_out h b", "c_in c_out", bias_shape="c_out", c_out=13, c_in=12),
(2, 12, 3, 4),
(4, 13, 3, 2),
),
# Permutation only: empty second argument and no new output axis.
(
Arguments("b c_in h w -> w c_in h b", "", bias_shape=None, c_in=12),
(2, 12, 3, 4),
(4, 12, 3, 2),
),
# Contract c_in, h and w away, producing a (b, c_out) output.
(
Arguments("b c_in h w -> b c_out", "c_in h w c_out", bias_shape=None, c_in=12, h=3, w=4, c_out=5),
(2, 12, 3, 4),
(2, 5),
),
# The head axis appears in the input, the second argument and the output.
(
Arguments("b t head c_in -> b t head c_out", "head c_in c_out", bias_shape=None, head=4, c_in=5, c_out=6),
(2, 3, 4, 5),
(2, 3, 4, 6),
),
]
# Each entry has the form:
#   (einops_pattern, true_einsum_pattern, in_shapes, out_shape)
# where `in_shapes` is a tuple holding one shape tuple per input tensor,
# `true_einsum_pattern` is the compact single-letter pattern that
# `_compactify_pattern_for_einsum(einops_pattern)` must produce (it is also
# fed to `np.einsum` to compute reference values), and `out_shape` is the
# expected shape of the result.
test_functional_cases = [
    (
        # Basic:
        "b c h w, b w -> b h",
        "abcd,ad->ac",
        ((2, 3, 4, 5), (2, 5)),
        (2, 4),
    ),
    (
        # Three tensors:
        "b c h w, b w, b c -> b h",
        "abcd,ad,ab->ac",
        ((2, 3, 40, 5), (2, 5), (2, 3)),
        (2, 40),
    ),
    (
        # Ellipsis, and full names:
        "... one two three, three four five -> ... two five",
        "...abc,cde->...be",
        ((32, 5, 2, 3, 4), (4, 5, 6)),
        (32, 5, 3, 6),
    ),
    (
        # Ellipsis at the end:
        "one two three ..., three four five -> two five ...",
        "abc...,cde->be...",
        ((2, 3, 4, 32, 5), (4, 5, 6)),
        (3, 6, 32, 5),
    ),
    (
        # Ellipsis on multiple tensors:
        "... one two three, ... three four five -> ... two five",
        "...abc,...cde->...be",
        ((32, 5, 2, 3, 4), (32, 5, 4, 5, 6)),
        (32, 5, 3, 6),
    ),
    (
        # One tensor, and underscores:
        "first_tensor second_tensor -> first_tensor",
        "ab->a",
        ((5, 4),),
        (5,),
    ),
    (
        # Trace (repeated index)
        "i i -> ",
        "aa->",
        ((5, 5),),
        (),
    ),
    (
        # Too many spaces in string:
        " one two , three four->two four ",
        "ab,cd->bd",
        ((2, 3), (4, 5)),
        (3, 5),
    ),
    # The following tests were inspired by numpy's einsum tests
    # https://github.com/numpy/numpy/blob/v1.23.0/numpy/core/tests/test_einsum.py
    (
        # Trace with other indices
        "i middle i -> middle",
        "aba->b",
        ((5, 10, 5),),
        (10,),
    ),
    (
        # Ellipsis in the middle:
        "i ... i -> ...",
        "a...a->...",
        ((5, 3, 2, 1, 4, 5),),
        (3, 2, 1, 4),
    ),
    (
        # Product of first and last axes:
        "i ... i -> i ...",
        "a...a->a...",
        ((5, 3, 2, 1, 4, 5),),
        (5, 3, 2, 1, 4),
    ),
    (
        # Triple diagonal
        "one one one -> one",
        "aaa->a",
        ((5, 5, 5),),
        (5,),
    ),
    (
        # Axis swap:
        "i j k -> j i k",
        "abc->bac",
        ((1, 2, 3),),
        (2, 1, 3),
    ),
    (
        # Identity:
        "... -> ...",
        "...->...",
        ((5, 4, 3, 2, 1),),
        (5, 4, 3, 2, 1),
    ),
    (
        # Elementwise product of three tensors
        "..., ..., ... -> ...",
        "...,...,...->...",
        ((3, 2), (3, 2), (3, 2)),
        (3, 2),
    ),
    (
        # Basic summation:
        "index ->",
        "a->",
        # The trailing comma matters: `((10,))` would collapse to the bare
        # shape `(10,)` instead of a one-element tuple of shapes.
        ((10,),),
        (),
    ),
]
def test_layer():
    """Check the output shape of every EinMix case in ``test_layer_cases``.

    Only the tensorflow, torch, oneflow and paddle layer backends are
    exercised; any other backend returned by ``collect_test_backends`` is
    skipped. Inputs are random float32 arrays of each case's input shape.
    """
    layer_frameworks = ["tensorflow", "torch", "oneflow", "paddle"]
    for backend in collect_test_backends(layers=True, symbolic=False):
        if backend.framework_name not in layer_frameworks:
            continue
        einmix_cls = backend.layers().EinMix
        for call_args, in_shape, out_shape in test_layer_cases:
            layer = call_args(einmix_cls)
            print("Running", layer.einsum_pattern, "for", backend.framework_name)
            in_array = np.random.uniform(size=in_shape).astype("float32")
            result = backend.to_numpy(layer(backend.from_numpy(in_array)))
            assert result.shape == out_shape
# Backend framework names on which test_functional and test_functional_symbolic
# run the functional einsum cases; backends returned by collect_test_backends
# under any other framework_name are filtered out by those tests.
valid_backends_functional = [
"tensorflow",
"torch",
"jax",
"numpy",
"oneflow",
"cupy",
"tensorflow.keras",
"paddle",
"pytensor",
]
def test_functional():
    """Run every functional einsum case on each eligible non-symbolic backend.

    For each case, the einops pattern must compact to the expected
    single-letter pattern. The computation is then performed twice: once via
    the backend's own ``einsum`` with the compacted pattern, and once via
    ``einops.einsum`` with the original pattern. Each result's shape and
    values are compared against ``np.einsum`` on the same data.
    """
    eligible = (
        backend
        for backend in collect_test_backends()
        if backend.framework_name in valid_backends_functional
    )
    for backend in eligible:
        for einops_pattern, true_pattern, in_shapes, out_shape in test_functional_cases:
            print(f"Running '{einops_pattern}' for {backend.framework_name}")

            compact_pattern = _compactify_pattern_for_einsum(einops_pattern)
            assert compact_pattern == true_pattern

            # Reseed for every case so each backend receives identical data.
            rng = np.random.RandomState(0)
            np_inputs = [rng.uniform(size=shape).astype("float32") for shape in in_shapes]
            fw_inputs = [backend.from_numpy(arr) for arr in np_inputs]
            expected = np.einsum(true_pattern, *np_inputs)

            # First through the backend directly, then through einops.einsum.
            for use_backend_einsum in (True, False):
                if use_backend_einsum:
                    result = backend.einsum(compact_pattern, *fw_inputs)
                else:
                    result = einsum(*fw_inputs, einops_pattern)
                if tuple(result.shape) != out_shape:
                    raise ValueError(f"Expected output shape {out_shape} but got {result.shape}")
                np.testing.assert_array_almost_equal(backend.to_numpy(result), expected, decimal=5)
def test_functional_symbolic():
    """Run every functional einsum case on each eligible symbolic backend.

    Inputs are backend symbols of each case's shapes. The einsum expression
    is built either with the backend's own ``einsum`` (compacted pattern) or
    with ``einops.einsum`` (original pattern), then evaluated on fixed random
    data. The shape and values are compared against ``np.einsum``.
    """
    for backend in collect_test_backends(symbolic=True, layers=False):
        if backend.framework_name not in valid_backends_functional:
            continue
        for einops_pattern, true_pattern, in_shapes, out_shape in test_functional_cases:
            print(f"Running '{einops_pattern}' for symbolic {backend.framework_name}")

            compact_pattern = _compactify_pattern_for_einsum(einops_pattern)
            assert compact_pattern == true_pattern

            rng = np.random.RandomState(0)
            symbols = [backend.create_symbol(shape) for shape in in_shapes]
            data = [rng.uniform(size=shape).astype("float32") for shape in in_shapes]
            expected = np.einsum(true_pattern, *data)

            for use_backend_einsum in (True, False):
                if use_backend_einsum:
                    expression = backend.einsum(compact_pattern, *symbols)
                else:
                    expression = einsum(*symbols, einops_pattern)
                # A fresh (symbol, value) list is built for every evaluation.
                actual = backend.eval_symbol(expression, list(zip(symbols, data)))
                if actual.shape != out_shape:
                    raise ValueError(f"Expected output shape {out_shape} but got {actual.shape}")
                np.testing.assert_array_almost_equal(actual, expected, decimal=5)
def test_functional_errors():
    """Check the exception types and messages raised by ``einops.einsum``.

    The specific backend does not matter, because these errors are raised
    while the pattern is being processed, so plain numpy arrays are used.
    Every ``match`` regex is anchored with ``^``; ``pytest.raises`` applies
    it with ``re.search`` against the string form of the exception.
    """
    rstate = np.random.RandomState(0)

    def create_tensor(*shape):
        return rstate.uniform(size=shape).astype("float32")

    # raise NotImplementedError("Singleton () axes are not yet supported in einsum.")
    with pytest.raises(NotImplementedError, match="^Singleton"):
        einsum(
            create_tensor(5, 1),
            "i () -> i",
        )

    # raise NotImplementedError("Shape rearrangement is not yet supported in einsum.")
    with pytest.raises(NotImplementedError, match="^Shape rearrangement"):
        einsum(
            create_tensor(5, 1),
            "a b -> (a b)",
        )
    with pytest.raises(NotImplementedError, match="^Shape rearrangement"):
        einsum(
            create_tensor(10, 1),
            "(a b) -> a b",
        )

    # raise RuntimeError("Encountered empty axis name in einsum.")
    # raise RuntimeError("Axis name in einsum must be a string.")
    # ^ Not tested, these are just a failsafe in case an unexpected error occurs.

    # raise NotImplementedError("Anonymous axes are not yet supported in einsum.")
    with pytest.raises(NotImplementedError, match="^Anonymous axes"):
        einsum(
            create_tensor(5, 1),
            "i 2 -> i",
        )

    # ParsedExpression error:
    with pytest.raises(EinopsError, match="^Invalid axis identifier"):
        einsum(
            create_tensor(5, 1),
            "i 2j -> i",
        )

    # raise ValueError("Einsum pattern must contain '->'.")
    with pytest.raises(ValueError, match="^Einsum pattern"):
        einsum(
            create_tensor(5, 3, 2),
            "i j k",
        )

    # raise RuntimeError("Too many axes in einsum.")
    # The pattern names one more axis than string.ascii_letters has letters.
    with pytest.raises(RuntimeError, match="^Too many axes"):
        einsum(
            create_tensor(1),
            " ".join(string.ascii_letters) + " extra ->",
        )

    # raise RuntimeError("Unknown axis on right side of einsum.")
    with pytest.raises(RuntimeError, match="^Unknown axis"):
        einsum(
            create_tensor(5, 1),
            "i j -> k",
        )

    # raise ValueError(
    #     "The last argument passed to `einops.einsum` must be a string,"
    #     " representing the einsum pattern."
    # )
    with pytest.raises(ValueError, match="^The last argument"):
        einsum(
            "i j k -> i",
            create_tensor(5, 4, 3),
        )

    # raise ValueError(
    #     "`einops.einsum` takes at minimum two arguments: the tensors,"
    #     " followed by the pattern."
    # )
    # The "." is escaped so the regex requires a literal dot, not any character.
    with pytest.raises(ValueError, match=r"^`einops\.einsum` takes"):
        einsum(
            "i j k -> i",
        )
    with pytest.raises(ValueError, match=r"^`einops\.einsum` takes"):
        einsum(
            create_tensor(5, 1),
        )

    # TODO: Include check for giving normal einsum pattern rather than einops.
|