|
import dataclasses |
|
import typing |
|
|
|
import numpy as np |
|
import pytest |
|
|
|
from einops import EinopsError, asnumpy, pack, unpack |
|
from einops.tests import collect_test_backends |
|
|
|
|
|
def pack_unpack(xs, pattern):
    """Check that packing ``xs`` and unpacking the result restores every input.

    Args:
        xs: sequence of arrays/tensors to hand to ``pack``.
        pattern: pattern string used for both ``pack`` and ``unpack``.

    Raises:
        AssertionError: if the number of unpacked pieces differs from
            ``len(xs)`` or a piece is not close to the corresponding input.
    """
    x, ps = pack(xs, pattern)
    # Unpack the packed result ``x`` (not the input list ``xs``) using the
    # second value returned by pack(), so the round trip is really exercised.
    unpacked = unpack(x, ps, pattern)
    assert len(unpacked) == len(xs)
    for a, b in zip(unpacked, xs):
        assert np.allclose(asnumpy(a), asnumpy(b))
|
|
|
|
|
def unpack_and_pack(x, ps, pattern: str):
    """Split ``x`` with ``unpack``, re-``pack`` the pieces and compare to ``x``.

    Args:
        x: array/tensor to split.
        ps: second argument forwarded to ``unpack``.
        pattern: pattern string used for both operations.

    Returns:
        The pieces produced by ``unpack`` so callers can inspect them.

    Raises:
        AssertionError: if the re-packed result is not close to ``x``.
    """
    pieces = unpack(x, ps, pattern)
    # pack() returns a pair; only the packed array matters for this check.
    repacked, _ = pack(pieces, pattern=pattern)

    restored = asnumpy(repacked)
    original = asnumpy(x)
    assert np.allclose(restored, original)
    return pieces
|
|
|
|
|
def unpack_and_pack_against_numpy(x, ps, pattern: str):
    """Run an unpack/pack round trip on ``x`` and on its numpy copy, then compare.

    Both round trips run inside a ``CaptureException`` block, so a failure is
    recorded rather than propagated. The two runs must either fail with the
    same exception type, or both succeed with matching results.

    Args:
        x: backend array/tensor to split.
        ps: second argument forwarded to ``unpack``.
        pattern: pattern string used for ``unpack`` and ``pack``.

    Raises:
        AssertionError: if the exception types differ, or if any packed or
            unpacked result disagrees between the backend and numpy.
    """
    capturer_backend = CaptureException()
    capturer_numpy = CaptureException()

    with capturer_backend:
        unpacked = unpack(x, ps, pattern)
        packed, _ = pack(unpacked, pattern=pattern)

    with capturer_numpy:
        x_np = asnumpy(x)
        unpacked_np = unpack(x_np, ps, pattern)
        packed_np, _ = pack(unpacked_np, pattern=pattern)

    backend_error = capturer_backend.exception
    numpy_error = capturer_numpy.exception
    # Report both captured exceptions: without a message a mismatch gives no
    # hint about which side failed or why.
    assert type(numpy_error) is type(backend_error), (
        f"backend raised {backend_error!r} but numpy raised {numpy_error!r} "
        f"for ps={ps!r}, pattern={pattern!r}"
    )
    if numpy_error is not None:
        # Both sides failed with the same exception type; no results to compare.
        return

    assert np.allclose(asnumpy(packed), asnumpy(x))
    assert np.allclose(asnumpy(packed_np), asnumpy(x))
    assert len(unpacked) == len(unpacked_np)
    for a, b in zip(unpacked, unpacked_np):
        assert np.allclose(asnumpy(a), b)
|
|
|
|
|
class CaptureException:
    """Context manager that records, and suppresses, whatever its body raises.

    After the ``with`` block, ``exception`` holds the raised exception
    instance, or ``None`` when the body finished without raising.
    """

    def __init__(self):
        # Defined up front so the attribute is readable before the first use.
        self.exception = None

    def __enter__(self):
        # Reset so a reused instance never reports a stale exception.
        self.exception = None
        # Return the instance so ``with CaptureException() as c`` binds it.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.exception = exc_val
        # Returning True suppresses the exception; this applies to every
        # exception type, not only ``Exception`` subclasses.
        return True
|
|
|
|
|
def test_numpy_trivial(H=13, W=17):
    """Compare ``pack`` on numpy arrays with ``np.stack`` / ``np.concatenate``.

    Also checks which patterns ``pack`` rejects with ``EinopsError`` and which
    it accepts.
    """

    def assert_identical(expected, actual):
        # Exact agreement is required: dtype, shape and every element.
        assert expected.dtype == actual.dtype
        assert expected.shape == actual.shape
        assert np.all(expected == actual)

    r, g, b = np.random.random((3, H, W))
    embeddings = np.random.random((H, W, 32))

    # With all axes named, "*" is a new axis: pack must equal np.stack there.
    for axis, pattern in ((2, "h w *"), (1, "h * w"), (0, "* h w")):
        expected = np.stack([r, g, b], axis=axis)
        actual, _ = pack([r, g, b], pattern)
        assert_identical(expected, actual)

    # With one axis named, pack must equal np.concatenate along the "*" axis.
    for axis, pattern in ((1, "h *"), (0, "* w")):
        expected = np.concatenate([r, g, b], axis=axis)
        actual, _ = pack([r, g, b], pattern)
        assert_identical(expected, actual)

    # Mixed ranks: 2-D planes get a trailing unit axis before concatenation.
    planes = [plane[:, :, np.newaxis] for plane in (r, g, b)]
    expected = np.concatenate(planes + [embeddings], axis=2)
    actual, _ = pack([r, g, b, embeddings], "h w *")
    assert_identical(expected, actual)

    # Pattern names one more axis than the 2-D inputs have.
    with pytest.raises(EinopsError):
        pack([r, g, b, embeddings], "h w nonexisting_axis *")

    # Arbitrary axis names are accepted.
    pack([r, g, b], "some_name_for_H some_name_for_w1 *")

    # Each of these patterns must be rejected (presumably as invalid or
    # duplicated axis names -- confirm against einops' naming rules).
    for rejected in ("h _w *", "h_ w *", "1h_ w *", "1 w *", "h h *"):
        with pytest.raises(EinopsError):
            pack([r, g, b, embeddings], rejected)

    # "h" and "H" are accepted together in one pattern.
    pack([r, g, b, embeddings], "h H *")
|
|
|
|
|
@dataclasses.dataclass
class UnpackTestCase:
    """A tensor shape paired with the pack/unpack pattern to apply to it."""

    # shape of the array generated for this case
    shape: typing.Tuple[int, ...]
    # whitespace-separated pattern; must contain a "*" token for dim() to work
    pattern: str

    def dim(self):
        """Return the index of the ``*`` token among the pattern's tokens.

        Raises ``ValueError`` if the pattern has no ``*`` token.
        """
        tokens = self.pattern.split()
        return tokens.index("*")

    def selfcheck(self):
        """Assert that the axis at the ``*`` position has length 5."""
        star_axis = self.dim()
        assert self.shape[star_axis] == 5
|
|
|
|
|
# Every case places the "*" axis at a different position; in each shape the
# entry at that position is 5, which the ps lists in the tests below rely on.
cases = [
    UnpackTestCase(shape, pattern)
    for shape, pattern in (
        ((5,), "*"),
        ((5, 7), "* seven"),
        ((7, 5), "seven *"),
        ((5, 3, 4), "* three four"),
        ((4, 5, 3), "four * three"),
        ((3, 4, 5), "three four *"),
    )
]
|
|
|
|
|
def test_pack_unpack_with_numpy():
    """Exercise unpack -> pack round trips on numpy arrays for every case.

    For each entry of ``cases`` a random array is split with several ``ps``
    lists: some must round-trip cleanly through ``unpack_and_pack``, others
    must raise.
    """
    case: UnpackTestCase

    def expect_failure(fn, *args):
        # Any exception type counts, mirroring pytest.raises(BaseException).
        with pytest.raises(BaseException):
            fn(*args)

    for case in cases:
        pattern = case.pattern
        x = np.random.random(case.shape)

        # explicit sizes 2 + 1 + 2, no -1
        unpack_and_pack(x, [[2], [1], [2]], pattern)

        # no -1: extra axis in the pattern, or sizes summing to 4 or 6
        expect_failure(unpack_and_pack, x, [[2], [1], [2]], pattern + " non_existent_axis")
        for ps in ([[2], [1], [1]], [[4], [1], [1]]):
            expect_failure(unpack_and_pack, x, ps, pattern)

        # a single -1 in each position
        for ps in ([[2], [1], [-1]], [[2], [-1], [2]], [[-1], [1], [2]]):
            unpack_and_pack(x, ps, pattern)
        # explicit sizes already total 5, so the -1 piece must be empty
        _, _, last = unpack_and_pack(x, [[2], [3], [-1]], pattern)
        assert last.shape[case.dim()] == 0

        # explicit sizes total 6 even before the -1 piece is counted
        for ps in ([[2], [4], [-1]], [[-1], [1], [5]]):
            expect_failure(unpack, x, ps, pattern)

        # -1 inside a two-element entry: every piece gains one axis
        for ps in (
            [[1, 2], [1, 1], [-1, 1]],
            [[1, 2], [1, -1], [1, 1]],
            [[2, -1], [1, 2], [1, 1]],
        ):
            rs = unpack_and_pack(x, ps, pattern)
            assert all(len(r.shape) == len(x.shape) + 1 for r in rs)

        # nested -1 while the remaining entries already total 6
        for ps in ([[-1, 2], [1], [5]], [[2, 2], [2], [5, -1]]):
            expect_failure(unpack, x, ps, pattern)

        # nested -1 where the leftover length is not a multiple of 3
        for ps in (
            [[2, 1], [1], [3, -1]],
            [[2, 1], [3, -1], [1]],
            [[3, -1], [2, 1], [1]],
        ):
            expect_failure(unpack, x, ps, pattern)

        # zero-sized pieces, with the -1 entry in each position
        for ps in ([[0], [5], [-1]], [[0], [-1], [5]], [[-1], [5], [0]]):
            unpack_and_pack(x, ps, pattern)

        # nested -1 entry with no elements left over for it
        unpack_and_pack(x, [[2, -1], [1, 5]], pattern)
|
|
|
|
|
def test_pack_unpack_against_numpy():
    """Compare unpack/pack on every collected backend with the numpy results.

    Each entry of ``cases`` is converted with ``backend.from_numpy`` and
    checked through ``unpack_and_pack_against_numpy``; inputs that must be
    rejected are passed to ``unpack`` directly.
    """
    roundtrip = unpack_and_pack_against_numpy

    def expect_failure(tensor, ps, pattern):
        # Any exception type counts, mirroring pytest.raises(BaseException).
        with pytest.raises(BaseException):
            unpack(tensor, ps, pattern)

    for backend in collect_test_backends(symbolic=False, layers=False):
        print(f"test packing against numpy for {backend.framework_name}")
        # Constant switch for the zero-length checks at the end of the loop.
        check_zero_len = True

        for case in cases:
            pattern = case.pattern
            x = backend.from_numpy(np.random.random(case.shape))

            # explicit sizes 2 + 1 + 2, no -1
            roundtrip(x, [[2], [1], [2]], pattern)

            # no -1 and sizes summing to 4 or 6
            for ps in ([[2], [1], [1]], [[4], [1], [1]]):
                expect_failure(x, ps, pattern)

            # a single -1 in each position
            for ps in ([[2], [1], [-1]], [[2], [-1], [2]], [[-1], [1], [2]]):
                roundtrip(x, ps, pattern)

            # explicit sizes total 6 even before the -1 piece is counted
            for ps in ([[2], [4], [-1]], [[-1], [1], [5]]):
                expect_failure(x, ps, pattern)

            # -1 inside a two-element entry
            for ps in (
                [[1, 2], [1, 1], [-1, 1]],
                [[1, 2], [1, -1], [1, 1]],
                [[2, -1], [1, 2], [1, 1]],
            ):
                roundtrip(x, ps, pattern)

            # nested -1 while the remaining entries already total 6
            for ps in ([[-1, 2], [1], [5]], [[2, 2], [2], [5, -1]]):
                expect_failure(x, ps, pattern)

            # nested -1 where the leftover length is not a multiple of 3
            for ps in (
                [[2, 1], [1], [3, -1]],
                [[2, 1], [3, -1], [1]],
                [[3, -1], [2, 1], [1]],
            ):
                expect_failure(x, ps, pattern)

            if check_zero_len:
                # splits in which at least one piece ends up with zero length
                for ps in (
                    [[2], [3], [-1]],
                    [[0], [5], [-1]],
                    [[0], [-1], [5]],
                    [[-1], [5], [0]],
                ):
                    roundtrip(x, ps, pattern)

                # nested -1 entry with no elements left over for it
                roundtrip(x, [[2, -1], [1, 5]], pattern)
|
|
|
|
|
def test_pack_unpack_array_api():
    """Compare ``einops.array_api`` pack/unpack with the regular numpy versions.

    numpy itself is imported as the array namespace ``xp``. The test is
    skipped when the installed numpy major version is below 2.
    """
    from einops import array_api as AA
    import numpy as xp

    # Compare the parsed major version: comparing version strings
    # lexicographically would order e.g. "10.0.0" before "2.0.0".
    if int(xp.__version__.split(".")[0]) < 2:
        pytest.skip("test requires numpy >= 2.0")

    for case in cases:
        shape = case.shape
        pattern = case.pattern
        x_np = np.random.random(shape)
        x_xp = xp.from_dlpack(x_np)

        # splits that must succeed and agree between both implementations
        for ps in [
            [[2], [1], [2]],
            [[1], [1], [-1]],
            [[1], [1], [-1, 3]],
            [[2, 1], [1, 1, 1], [-1]],
        ]:
            x_np_split = unpack(x_np, ps, pattern)
            x_xp_split = AA.unpack(x_xp, ps, pattern)
            # zip() stops at the shorter input, so check the counts explicitly.
            assert len(x_np_split) == len(x_xp_split)
            for a, b in zip(x_np_split, x_xp_split):
                # NOTE(review): ``b + 0`` is kept from the original; presumably
                # it forces a fresh array before conversion -- confirm.
                assert np.allclose(a, AA.asnumpy(b + 0))

            x_agg_np, ps1 = pack(x_np_split, pattern)
            x_agg_xp, ps2 = AA.pack(x_xp_split, pattern)
            assert ps1 == ps2
            assert np.allclose(x_agg_np, AA.asnumpy(x_agg_xp))

        # splits that the regular numpy unpack must reject
        for ps in [
            [[2, 3]],
            [[1], [5]],
            [[1], [5], [-1]],
            [[1], [2, 3]],
            [[1], [5], [-1, 2]],
        ]:
            with pytest.raises(BaseException):
                unpack(x_np, ps, pattern)
|
|