File size: 6,746 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
from einops import EinopsError
import keyword
import warnings
from typing import List, Optional, Set, Tuple, Union
# Internal one-character stand-in for "..." inside parsed expressions (U+2026 HORIZONTAL ELLIPSIS).
# Written as an escape so that the literal cannot be corrupted by a wrong file encoding.
_ellipsis: str = "\u2026"  # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated
class AnonymousAxis:
    """
    Axis that has a length but no name.

    Important thing: no ``__eq__`` / ``__hash__`` is defined, so comparison falls back to
    object identity and two instances are never equal to each other, even for equal lengths.
    """

    def __init__(self, value: str):
        """
        Store ``int(value)`` as the axis length.

        Raises EinopsError for length 1 (such an axis should never be created)
        and for zero or negative lengths.
        """
        length = int(value)
        # The attribute is assigned before validation, exactly like the length checks expect.
        self.value = length
        if length == 1:
            raise EinopsError("No need to create anonymous axis of length 1. Report this as an issue")
        if length < 1:
            raise EinopsError("Anonymous axis should have positive length, not {}".format(length))

    def __repr__(self):
        """Return the length followed by '-axis', e.g. '3-axis'."""
        return "{}-axis".format(self.value)
class ParsedExpression:
    """
    Non-mutable structure that contains information about one side of expression (e.g. 'b c (h w)')
    and keeps some information important for downstream.

    Attributes filled in by the constructor:
        has_ellipsis: True when the expression contained '...'.
        has_ellipsis_parenthesized: None until an ellipsis is met, then tells whether it was inside brackets.
        identifiers: all axes met in the expression (names, AnonymousAxis objects, the ellipsis marker).
        has_non_unitary_anonymous_axes: True when a numeric axis other than 1 was met.
        composition: one entry per top-level element of the expression: a list of axes for
            a single axis or a bracket group, or the bare ellipsis marker for a top-level ellipsis.
    """

    def __init__(self, expression: str, *, allow_underscore: bool = False, allow_duplicates: bool = False):
        """
        Parse ``expression`` character by character.

        Args:
            expression: one side of a pattern, e.g. 'b c (h w)' or '... (h 2)'.
            allow_underscore: accept the bare name '_' as an axis, also when it is repeated.
            allow_duplicates: accept any axis name that occurs more than once.

        Raises:
            EinopsError: on misplaced dots, unknown characters, invalid or duplicate axis
                names, nested or unbalanced brackets.
        """
        self.has_ellipsis: bool = False
        self.has_ellipsis_parenthesized: Optional[bool] = None
        self.identifiers: Set[str] = set()
        # that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition
        self.has_non_unitary_anonymous_axes: bool = False
        # composition keeps structure of composite axes, see how different corner cases are handled in tests
        self.composition: List[Union[List[str], str]] = []

        if "." in expression:
            if "..." not in expression:
                raise EinopsError("Expression may contain dots only inside ellipsis (...)")
            # exactly one '...' and no other dots anywhere in the expression
            if str.count(expression, "...") != 1 or str.count(expression, ".") != 3:
                raise EinopsError(
                    "Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor "
                )
            # from here on the ellipsis is represented by the module-level marker
            expression = expression.replace("...", _ellipsis)
            self.has_ellipsis = True

        # None while outside brackets, otherwise the list collecting axes of the open bracket group
        bracket_group: Optional[List[str]] = None

        def add_axis_name(x):
            """Register one complete token (axis name, number or ellipsis marker)."""
            if x in self.identifiers:
                if not (allow_underscore and x == "_") and not allow_duplicates:
                    raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x))
            if x == _ellipsis:
                self.identifiers.add(_ellipsis)
                if bracket_group is None:
                    self.composition.append(_ellipsis)
                    self.has_ellipsis_parenthesized = False
                else:
                    bracket_group.append(_ellipsis)
                    self.has_ellipsis_parenthesized = True
            else:
                is_number = str.isdecimal(x)
                if is_number and int(x) == 1:
                    # handling the case of anonymous axis of length 1
                    if bracket_group is None:
                        self.composition.append([])
                    else:
                        pass  # no need to think about 1s inside parenthesis
                    return
                is_axis_name, reason = self.check_axis_name_return_reason(x, allow_underscore=allow_underscore)
                if not (is_number or is_axis_name):
                    raise EinopsError("Invalid axis identifier: {}\n{}".format(x, reason))
                if is_number:
                    # numbers other than 1 become anonymous axes (never equal to each other)
                    x = AnonymousAxis(x)
                self.identifiers.add(x)
                if is_number:
                    self.has_non_unitary_anonymous_axes = True
                if bracket_group is None:
                    self.composition.append([x])
                else:
                    bracket_group.append(x)

        current_identifier = None
        for char in expression:
            if char in "() ":
                # a bracket or a space terminates the token that was being accumulated
                if current_identifier is not None:
                    add_axis_name(current_identifier)
                current_identifier = None
                if char == "(":
                    if bracket_group is not None:
                        raise EinopsError("Axis composition is one-level (brackets inside brackets not allowed)")
                    bracket_group = []
                elif char == ")":
                    if bracket_group is None:
                        raise EinopsError("Brackets are not balanced")
                    self.composition.append(bracket_group)
                    bracket_group = None
            elif str.isalnum(char) or char in ["_", _ellipsis]:
                if current_identifier is None:
                    current_identifier = char
                else:
                    current_identifier += char
            else:
                raise EinopsError("Unknown character '{}'".format(char))

        if bracket_group is not None:
            raise EinopsError('Imbalanced parentheses in expression: "{}"'.format(expression))
        # the expression may end with a token that was not followed by a separator
        if current_identifier is not None:
            add_axis_name(current_identifier)

    def flat_axes_order(self) -> List:
        """
        Return all axes of the composition as one flat list, in order of appearance.

        Fails with AssertionError when the composition holds a top-level ellipsis
        (the only entry that is not a list).
        """
        result = []
        for composed_axis in self.composition:
            assert isinstance(composed_axis, list), "does not work with ellipsis"
            for axis in composed_axis:
                result.append(axis)
        return result

    def has_composed_axes(self) -> bool:
        """Tell whether any element of the composition groups more than one axis."""
        # this will ignore 1 inside brackets
        for axes in self.composition:
            if isinstance(axes, list) and len(axes) > 1:
                return True
        return False

    @staticmethod
    def check_axis_name_return_reason(name: str, allow_underscore: bool = False) -> Tuple[bool, str]:
        """
        Check an axis name and explain a rejection.

        Returns:
            (True, "") for an accepted name, otherwise (False, reason).

        A name is rejected when it is not a python identifier or when it starts or ends with
        an underscore (the bare '_' is accepted only if ``allow_underscore`` is set).
        Python keywords and the name 'axis' are accepted, but a warning is emitted for them.
        """
        if not str.isidentifier(name):
            return False, "not a valid python identifier"
        if name[0] == "_" or name[-1] == "_":
            if name == "_" and allow_underscore:
                return True, ""
            return False, "axis name should not start or end with underscore"
        if keyword.iskeyword(name):
            warnings.warn("It is discouraged to use axes names that are keywords: {}".format(name), RuntimeWarning)
        if name in ["axis"]:
            warnings.warn(
                "It is discouraged to use 'axis' as an axis name and will raise an error in future",
                FutureWarning,
            )
        return True, ""

    @staticmethod
    def check_axis_name(name: str) -> bool:
        """
        Valid axes names are python identifiers that do not start or end with underscore.
        Keywords and the name 'axis' are still reported as valid, they only trigger a warning.
        """
        is_valid, _reason = ParsedExpression.check_axis_name_return_reason(name)
        return is_valid
|