import pytest

from einops import EinopsError
from einops.parsing import ParsedExpression, AnonymousAxis, _ellipsis

__author__ = "Alex Rogozhnikov"


class AnonymousAxisPlaceholder:
    """Stand-in that compares equal to any ``AnonymousAxis`` with the same value.

    Lets tests spell out an expected composition with plain integers
    instead of needing the exact axis objects the parser created.
    """

    def __init__(self, value: int):
        self.value = value
        assert isinstance(self.value, int)

    def __eq__(self, other):
        # Only AnonymousAxis instances can match; everything else is unequal.
        if not isinstance(other, AnonymousAxis):
            return False
        return self.value == other.value


def test_anonymous_axes():
    """Two AnonymousAxis objects differ from each other but match a placeholder by value."""
    a, b = AnonymousAxis("2"), AnonymousAxis("2")
    # Built from the same string, yet they must not compare equal.
    assert a != b

    c = AnonymousAxisPlaceholder(2)
    d = AnonymousAxisPlaceholder(3)
    assert a == c
    assert b == c
    assert a != d
    assert b != d
    # Placeholder equality also has to work element-wise inside lists.
    assert [a, 2, b] == [c, 2, c]


def test_elementary_axis_name():
    """``check_axis_name`` accepts the valid names below and rejects the invalid ones."""
    valid_names = (
        "a",
        "b",
        "h",
        "dx",
        "h1",
        "zz",
        "i9123",
        "somelongname",
        "Alex",
        "camelCase",
        "u_n_d_e_r_score",
        "unreasonablyLongAxisName",
    )
    invalid_names = (
        "",
        "2b",
        "12",
        "_startWithUnderscore",
        "endWithUnderscore_",
        "_",
        "...",
        _ellipsis,
    )

    for name in valid_names:
        assert ParsedExpression.check_axis_name(name)

    for name in invalid_names:
        assert not ParsedExpression.check_axis_name(name)


def test_invalid_expressions():
    """Malformed patterns raise ``EinopsError``; the well-formed ones nearby parse."""

    def assert_rejected(*patterns):
        # Each pattern, in the given order, must fail to parse.
        for pattern in patterns:
            with pytest.raises(EinopsError):
                ParsedExpression(pattern)

    # double ellipsis should raise an error
    ParsedExpression("... a b c d")
    assert_rejected(
        "... a b c d ...",
        "... a b c (d ...)",
        "(... a) b c (d ...)",
    )

    # double/missing/enclosed parenthesis
    ParsedExpression("(a) b c (d ...)")
    assert_rejected(
        "(a)) b c (d ...)",
        "(a b c (d ...)",
        "(a) (()) b c (d ...)",
        "(a) ((b c) (d ...))",
    )

    # invalid identifiers
    ParsedExpression("camelCase under_scored cApiTaLs ß ...")
    assert_rejected(
        "1a",
        "_pre",
        "...pre",
        "pre...",
    )


def test_parse_expression():
    """Check identifiers, composition and flags of ``ParsedExpression`` on sample patterns."""
    # Patterns expected to have neither non-unitary anonymous axes nor an ellipsis.
    plain_cases = [
        ("a1 b1 c1 d1", {"a1", "b1", "c1", "d1"}, [["a1"], ["b1"], ["c1"], ["d1"]]),
        ("() () () ()", set(), [[], [], [], []]),
        ("1 1 1 ()", set(), [[], [], [], []]),
    ]
    for pattern, expected_identifiers, expected_composition in plain_cases:
        parsed = ParsedExpression(pattern)
        assert parsed.identifiers == expected_identifiers
        assert parsed.composition == expected_composition
        assert not parsed.has_non_unitary_anonymous_axes
        assert not parsed.has_ellipsis

    anon = AnonymousAxisPlaceholder

    parsed = ParsedExpression("5 (3 4)")
    assert len(parsed.identifiers) == 3
    assert {axis.value for axis in parsed.identifiers} == {3, 4, 5}
    assert parsed.composition == [[anon(5)], [anon(3), anon(4)]]
    assert parsed.has_non_unitary_anonymous_axes
    assert not parsed.has_ellipsis

    parsed = ParsedExpression("5 1 (1 4) 1")
    assert len(parsed.identifiers) == 2
    assert {axis.value for axis in parsed.identifiers} == {4, 5}
    assert parsed.composition == [[anon(5)], [], [anon(4)], []]

    named_and_ellipsis = {"name1", _ellipsis, "a1", "name2"}

    parsed = ParsedExpression("name1 ... a1 12 (name2 14)")
    assert len(parsed.identifiers) == 6
    # Removing the named axes and the ellipsis leaves the two anonymous axes.
    assert len(parsed.identifiers.difference(named_and_ellipsis)) == 2
    assert parsed.composition == [["name1"], _ellipsis, ["a1"], [anon(12)], ["name2", anon(14)]]
    assert parsed.has_non_unitary_anonymous_axes
    assert parsed.has_ellipsis
    assert not parsed.has_ellipsis_parenthesized

    parsed = ParsedExpression("(name1 ... a1 12) name2 14")
    assert len(parsed.identifiers) == 6
    assert len(parsed.identifiers.difference(named_and_ellipsis)) == 2
    assert parsed.composition == [["name1", _ellipsis, "a1", anon(12)], ["name2"], [anon(14)]]
    assert parsed.has_non_unitary_anonymous_axes
    assert parsed.has_ellipsis
    assert parsed.has_ellipsis_parenthesized