# File size: 4,389 Bytes
# 9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import pytest

from einops import EinopsError
from einops.parsing import ParsedExpression, AnonymousAxis, _ellipsis

__author__ = "Alex Rogozhnikov"


class AnonymousAxisPlaceholder:
    """Test stand-in that compares equal to any ``AnonymousAxis`` holding the same value.

    Two ``AnonymousAxis`` objects are distinct even when built from the same number,
    so expected compositions in these tests use this placeholder to mean
    "some anonymous axis of length ``value``".
    """

    def __init__(self, value: int):
        self.value = value
        assert isinstance(self.value, int), "placeholder value must be an int, got {!r}".format(value)

    def __eq__(self, other):
        # Match on the numeric value only; which AnonymousAxis instance it is does not matter.
        return isinstance(other, AnonymousAxis) and self.value == other.value

    def __repr__(self):
        # Readable form so pytest's list/composition diffs show the expected axis length.
        return "{}({!r})".format(type(self).__name__, self.value)


def test_anonymous_axes():
    # Two AnonymousAxis objects created from the same string are still different axes.
    first, second = AnonymousAxis("2"), AnonymousAxis("2")
    assert first != second

    # Both still match a placeholder carrying their value, and not one with another value.
    two, three = AnonymousAxisPlaceholder(2), AnonymousAxisPlaceholder(3)
    assert first == two
    assert second == two
    assert first != three
    assert second != three

    # The placeholder matching also holds element-wise inside lists.
    assert [first, 2, second] == [two, 2, two]


def test_elementary_axis_name():
    # Identifiers that check_axis_name must accept.
    accepted = (
        "a",
        "b",
        "h",
        "dx",
        "h1",
        "zz",
        "i9123",
        "somelongname",
        "Alex",
        "camelCase",
        "u_n_d_e_r_score",
        "unreasonablyLongAxisName",
    )
    # Empty, digit-leading, underscore-leading/trailing and ellipsis forms must be rejected.
    rejected = ("", "2b", "12", "_startWithUnderscore", "endWithUnderscore_", "_", "...", _ellipsis)

    for candidate in accepted:
        assert ParsedExpression.check_axis_name(candidate)
    for candidate in rejected:
        assert not ParsedExpression.check_axis_name(candidate)


def test_invalid_expressions():
    def assert_rejected(expression):
        # Every malformed pattern must be reported as an EinopsError.
        with pytest.raises(EinopsError):
            ParsedExpression(expression)

    # One ellipsis is allowed; a second one (top level or inside brackets) is an error.
    ParsedExpression("... a b c d")
    for malformed in (
        "... a b c d ...",
        "... a b c (d ...)",
        "(... a) b c (d ...)",
    ):
        assert_rejected(malformed)

    # Brackets must be balanced, non-nested, and not doubled.
    ParsedExpression("(a) b c (d ...)")
    for malformed in (
        "(a)) b c (d ...)",
        "(a b c (d ...)",
        "(a) (()) b c (d ...)",
        "(a) ((b c) (d ...))",
    ):
        assert_rejected(malformed)

    # Axis names: mixed case, inner underscores and non-ASCII letters are accepted;
    # a leading digit, a leading underscore, or an ellipsis glued to a name is not.
    ParsedExpression("camelCase under_scored cApiTaLs ß ...")
    for malformed in ("1a", "_pre", "...pre", "pre..."):
        assert_rejected(malformed)


def test_parse_expression():
    """Check identifiers, composition and flags that ParsedExpression produces."""
    # Plain named axes separated by irregular runs of spaces.
    parsed = ParsedExpression("a1  b1   c1    d1")
    assert parsed.identifiers == {"a1", "b1", "c1", "d1"}
    assert parsed.composition == [["a1"], ["b1"], ["c1"], ["d1"]]
    assert not parsed.has_non_unitary_anonymous_axes
    assert not parsed.has_ellipsis

    # Each empty group "()" contributes an empty composition entry.
    parsed = ParsedExpression("() () () ()")
    assert parsed.identifiers == set()
    assert parsed.composition == [[], [], [], []]
    assert not parsed.has_non_unitary_anonymous_axes
    assert not parsed.has_ellipsis

    # A top-level "1" behaves like "()" and is not recorded as an identifier.
    parsed = ParsedExpression("1 1 1 ()")
    assert parsed.identifiers == set()
    assert parsed.composition == [[], [], [], []]
    assert not parsed.has_non_unitary_anonymous_axes
    assert not parsed.has_ellipsis

    aap = AnonymousAxisPlaceholder

    # Numbers other than 1 become anonymous-axis identifiers.
    parsed = ParsedExpression("5 (3 4)")
    assert len(parsed.identifiers) == 3 and {i.value for i in parsed.identifiers} == {3, 4, 5}
    assert parsed.composition == [[aap(5)], [aap(3), aap(4)]]
    assert parsed.has_non_unitary_anonymous_axes
    assert not parsed.has_ellipsis

    # A "1" inside a group is dropped; outside a group it yields an empty entry.
    parsed = ParsedExpression("5 1 (1 4) 1")
    assert len(parsed.identifiers) == 2 and {i.value for i in parsed.identifiers} == {4, 5}
    assert parsed.composition == [[aap(5)], [], [aap(4)], []]
    assert parsed.has_non_unitary_anonymous_axes
    assert not parsed.has_ellipsis

    # Top-level ellipsis mixed with names and anonymous axes.
    parsed = ParsedExpression("name1 ... a1 12 (name2 14)")
    assert len(parsed.identifiers) == 6
    # The two identifiers left over are the anonymous axes 12 and 14.
    assert len(parsed.identifiers.difference({"name1", _ellipsis, "a1", "name2"})) == 2
    assert parsed.composition == [["name1"], _ellipsis, ["a1"], [aap(12)], ["name2", aap(14)]]
    assert parsed.has_non_unitary_anonymous_axes
    assert parsed.has_ellipsis
    assert not parsed.has_ellipsis_parenthesized

    # Ellipsis inside a bracketed group.
    parsed = ParsedExpression("(name1 ... a1 12) name2 14")
    assert len(parsed.identifiers) == 6
    assert len(parsed.identifiers.difference({"name1", _ellipsis, "a1", "name2"})) == 2
    assert parsed.composition == [["name1", _ellipsis, "a1", aap(12)], ["name2"], [aap(14)]]
    assert parsed.has_non_unitary_anonymous_axes
    assert parsed.has_ellipsis
    assert parsed.has_ellipsis_parenthesized