File size: 4,389 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import pytest
from einops import EinopsError
from einops.parsing import ParsedExpression, AnonymousAxis, _ellipsis
__author__ = "Alex Rogozhnikov"
class AnonymousAxisPlaceholder:
    """Test stand-in that equals any ``AnonymousAxis`` carrying the same value.

    The tests in this file assert that two ``AnonymousAxis`` objects created
    from the same text are *not* equal to each other, so expected results
    cannot be written with real axes. A placeholder compares by ``value``
    instead, which lets expected compositions be spelled out literally.

    Defining ``__eq__`` without ``__hash__`` leaves instances unhashable;
    they are only meant to be compared, not stored in sets or dict keys.
    """

    def __init__(self, value: int):
        """Store ``value``; it must be an ``int`` (checked with an assert)."""
        self.value = value
        assert isinstance(self.value, int)

    def __eq__(self, other):
        """Return True only when ``other`` is an ``AnonymousAxis`` with equal value."""
        return isinstance(other, AnonymousAxis) and self.value == other.value

    def __repr__(self):
        """Readable form so failed assertions show which placeholder mismatched."""
        return f"{type(self).__name__}({self.value!r})"
def test_anonymous_axes():
    """Anonymous axes are mutually distinct, yet match same-valued placeholders."""
    first = AnonymousAxis("2")
    second = AnonymousAxis("2")
    # Two axes built from the same text must still be different from each other.
    assert first != second

    same_value = AnonymousAxisPlaceholder(2)
    other_value = AnonymousAxisPlaceholder(3)
    # A placeholder matches every axis that carries its value...
    assert first == same_value
    assert second == same_value
    # ...and no axis that carries a different one.
    assert first != other_value
    assert second != other_value

    # The comparison must also hold element-wise inside lists.
    assert [first, 2, second] == [same_value, 2, same_value]
def test_elementary_axis_name():
    """``check_axis_name`` accepts well-formed names and rejects the rest."""
    accepted = (
        "a",
        "b",
        "h",
        "dx",
        "h1",
        "zz",
        "i9123",
        "somelongname",
        "Alex",
        "camelCase",
        "u_n_d_e_r_score",
        "unreasonablyLongAxisName",
    )
    # Empty, digit-leading, purely numeric, underscore-edged and ellipsis
    # spellings are all expected to be refused.
    rejected = (
        "",
        "2b",
        "12",
        "_startWithUnderscore",
        "endWithUnderscore_",
        "_",
        "...",
        _ellipsis,
    )

    for candidate in accepted:
        assert ParsedExpression.check_axis_name(candidate)

    for candidate in rejected:
        assert not ParsedExpression.check_axis_name(candidate)
def test_invalid_expressions():
    """Malformed patterns raise ``EinopsError``; a valid baseline per group parses."""

    def expect_rejected(patterns):
        # Every pattern is checked separately so a single accepted pattern
        # fails the test.
        for pattern in patterns:
            with pytest.raises(EinopsError):
                ParsedExpression(pattern)

    # Ellipsis: a single one parses, a second one anywhere is an error.
    ParsedExpression("... a b c d")
    expect_rejected(
        [
            "... a b c d ...",
            "... a b c (d ...)",
            "(... a) b c (d ...)",
        ]
    )

    # Parentheses: doubled, missing or enclosed (nested) ones are errors.
    ParsedExpression("(a) b c (d ...)")
    expect_rejected(
        [
            "(a)) b c (d ...)",
            "(a b c (d ...)",
            "(a) (()) b c (d ...)",
            "(a) ((b c) (d ...))",
        ]
    )

    # Identifiers: a leading digit, a leading underscore, or an ellipsis glued
    # to a name are errors; mixed case, underscores and non-ASCII letters parse.
    ParsedExpression("camelCase under_scored cApiTaLs Γ ...")
    expect_rejected(
        [
            "1a",
            "_pre",
            "...pre",
            "pre...",
        ]
    )
def test_parse_expression():
    """Check identifiers, composition and flags that ``ParsedExpression`` produces.

    Each case parses one pattern and inspects ``identifiers``, ``composition``
    and the ``has_*`` flags. Anonymous (numeric) axes are compared through
    ``AnonymousAxisPlaceholder`` because they are matched by value here.
    """
    # Plain named axes: one single-element group per name.
    parsed = ParsedExpression("a1 b1 c1 d1")
    assert parsed.identifiers == {"a1", "b1", "c1", "d1"}
    assert parsed.composition == [["a1"], ["b1"], ["c1"], ["d1"]]
    assert not parsed.has_non_unitary_anonymous_axes
    assert not parsed.has_ellipsis

    # Empty parentheses give empty groups and no identifiers.
    parsed = ParsedExpression("() () () ()")
    assert parsed.identifiers == set()
    assert parsed.composition == [[], [], [], []]
    assert not parsed.has_non_unitary_anonymous_axes
    assert not parsed.has_ellipsis

    # A literal 1 yields the same result as an empty group.
    parsed = ParsedExpression("1 1 1 ()")
    assert parsed.identifiers == set()
    assert parsed.composition == [[], [], [], []]
    assert not parsed.has_non_unitary_anonymous_axes
    assert not parsed.has_ellipsis

    aap = AnonymousAxisPlaceholder

    # Numbers other than 1 become anonymous axes, also inside parentheses.
    parsed = ParsedExpression("5 (3 4)")
    assert len(parsed.identifiers) == 3 and {i.value for i in parsed.identifiers} == {3, 4, 5}
    assert parsed.composition == [[aap(5)], [aap(3), aap(4)]]
    assert parsed.has_non_unitary_anonymous_axes
    assert not parsed.has_ellipsis

    # Mixed case: top-level 1s give empty groups, a 1 inside a group vanishes.
    parsed = ParsedExpression("5 1 (1 4) 1")
    assert len(parsed.identifiers) == 2 and {i.value for i in parsed.identifiers} == {4, 5}
    assert parsed.composition == [[aap(5)], [], [aap(4)], []]
    # Same flag checks as the sibling cases: 5 and 4 are non-unit anonymous
    # axes, and the pattern has no ellipsis.
    assert parsed.has_non_unitary_anonymous_axes
    assert not parsed.has_ellipsis

    # Ellipsis at top level, anonymous axes both outside and inside a group.
    parsed = ParsedExpression("name1 ... a1 12 (name2 14)")
    assert len(parsed.identifiers) == 6
    # Removing the four known identifiers must leave the two anonymous axes.
    assert len(parsed.identifiers.difference({"name1", _ellipsis, "a1", "name2"})) == 2
    assert parsed.composition == [["name1"], _ellipsis, ["a1"], [aap(12)], ["name2", aap(14)]]
    assert parsed.has_non_unitary_anonymous_axes
    assert parsed.has_ellipsis
    assert not parsed.has_ellipsis_parenthesized

    # Same identifiers, but the ellipsis now sits inside parentheses.
    parsed = ParsedExpression("(name1 ... a1 12) name2 14")
    assert len(parsed.identifiers) == 6
    assert len(parsed.identifiers.difference({"name1", _ellipsis, "a1", "name2"})) == 2
    assert parsed.composition == [["name1", _ellipsis, "a1", aap(12)], ["name2"], [aap(14)]]
    assert parsed.has_non_unitary_anonymous_axes
    assert parsed.has_ellipsis
    assert parsed.has_ellipsis_parenthesized
|