File size: 4,835 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
from __future__ import annotations
from copy import copy
from typing import Any
from tomlkit.exceptions import ParseError
from tomlkit.exceptions import UnexpectedCharError
from tomlkit.toml_char import TOMLChar
class _State:
def __init__(
self,
source: Source,
save_marker: bool | None = False,
restore: bool | None = False,
) -> None:
self._source = source
self._save_marker = save_marker
self.restore = restore
def __enter__(self) -> _State:
# Entering this context manager - save the state
self._chars = copy(self._source._chars)
self._idx = self._source._idx
self._current = self._source._current
self._marker = self._source._marker
return self
def __exit__(self, exception_type, exception_val, trace):
# Exiting this context manager - restore the prior state
if self.restore or exception_type:
self._source._chars = self._chars
self._source._idx = self._idx
self._source._current = self._current
if self._save_marker:
self._source._marker = self._marker
class _StateHandler:
"""
State preserver for the Parser.
"""
def __init__(self, source: Source) -> None:
self._source = source
self._states = []
def __call__(self, *args, **kwargs):
return _State(self._source, *args, **kwargs)
def __enter__(self) -> _State:
state = self()
self._states.append(state)
return state.__enter__()
def __exit__(self, exception_type, exception_val, trace):
state = self._states.pop()
return state.__exit__(exception_type, exception_val, trace)
class Source(str):
    """
    The text of a TOML document plus a cursor used while parsing it.

    ``Source`` subclasses ``str``, so the instance is the document itself;
    slicing, ``len()`` and other ``str`` methods operate on the raw text. The
    extra state is the current character and its index, a marker index used
    by :meth:`extract`, and a ``_StateHandler`` for saving and restoring
    that state.
    """

    # Value of ``current`` once the input is exhausted. ``end()`` compares
    # against it by identity, so a literal NUL character inside the document
    # (a separate TOMLChar object) is not mistaken for the end of input.
    EOF = TOMLChar("\0")

    def __init__(self, _: str) -> None:
        # The text is stored by str.__new__; the argument is unused here.
        super().__init__()

        # Iterator over (index, TOMLChar) pairs for every character.
        self._chars = iter([(i, TOMLChar(c)) for i, c in enumerate(self)])

        self._idx = 0
        self._marker = 0
        self._current = TOMLChar("")

        self._state = _StateHandler(self)

        # Position the cursor on the first character (or EOF if empty).
        self.inc()

    def reset(self):
        """
        Advance the cursor one character, then move the marker to it.

        This does not rewind the character iterator; it is exactly
        ``inc()`` followed by ``mark()``.
        """
        # initialize both idx and current
        self.inc()

        # reset marker
        self.mark()

    @property
    def state(self) -> _StateHandler:
        """Handler used to save and restore the cursor state."""
        return self._state

    @property
    def idx(self) -> int:
        """Index of the current character (``len(self)`` at end of input)."""
        return self._idx

    @property
    def current(self) -> TOMLChar:
        """The current character, or ``EOF`` once the input is exhausted."""
        return self._current

    @property
    def marker(self) -> int:
        """Start index used by :meth:`extract` (set by :meth:`mark`)."""
        return self._marker

    def extract(self) -> str:
        """
        Extracts the value between marker and index.

        Returns ``self[marker:idx]``: the text from the marker up to, but not
        including, the current character.
        """
        return self[self._marker : self._idx]

    def inc(self, exception: type[ParseError] | None = None) -> bool:
        """
        Increments the parser if the end of the input has not been reached.

        Returns True if the cursor moved onto another character. At the end
        of input, ``idx`` becomes ``len(self)`` and ``current`` becomes
        ``EOF``. If *exception* is given, an error of that type is then
        raised for the current position; otherwise False is returned.
        """
        try:
            self._idx, self._current = next(self._chars)
            return True
        except StopIteration:
            self._idx = len(self)
            self._current = self.EOF
            if exception:
                raise self.parse_error(exception) from None

            return False

    def inc_n(self, n: int, exception: type[ParseError] | None = None) -> bool:
        """
        Increments the parser by n characters
        if the end of the input has not been reached.

        Stops at the first failed :meth:`inc` (``all`` short-circuits) and
        returns False in that case. *exception* is forwarded to every call.
        """
        return all(self.inc(exception=exception) for _ in range(n))

    def consume(self, chars, min=0, max=-1):
        """
        Consume characters while the current one is in *chars*.

        At most *max* characters are consumed (a negative *max* means no
        limit). If fewer than *min* characters were consumed, an
        ``UnexpectedCharError`` naming the current character is raised.
        """
        # ``min``/``max`` shadow the builtins but are part of the public
        # keyword interface, so they keep their names.
        while self.current in chars and max != 0:
            min -= 1
            max -= 1
            if not self.inc():
                break

        # failed to consume minimum number of characters
        if min > 0:
            raise self.parse_error(UnexpectedCharError, self.current)

    def end(self) -> bool:
        """
        Returns True if the parser has reached the end of the input.
        """
        return self._current is self.EOF

    def mark(self) -> None:
        """
        Sets the marker to the index's current position.
        """
        self._marker = self._idx

    def parse_error(
        self,
        exception: type[ParseError] = ParseError,
        *args: Any,
        **kwargs: Any,
    ) -> ParseError:
        """
        Creates a generic "parse error" at the current position.

        The exception is built as ``exception(line, col, *args, **kwargs)``
        and returned, not raised; callers raise it themselves.
        """
        line, col = self._to_linecol()
        return exception(line, col, *args, **kwargs)

    def _to_linecol(self) -> tuple[int, int]:
        """
        Return the (line, column) of the current index.

        Lines are 1-based; the column is the 0-based offset of the index
        from the start of its line.
        """
        # Only a line feed ends a line: TOML defines a newline as LF or CRLF.
        # str.splitlines() is deliberately avoided. It drops the whole "\r\n"
        # terminator, which made column offsets drift by one per preceding
        # CRLF line. It also splits on characters TOML does not treat as
        # newlines (e.g. U+2028, form feed, a lone CR), which inflated line
        # numbers.
        idx = self._idx
        line = self.count("\n", 0, idx) + 1
        line_start = self.rfind("\n", 0, idx) + 1
        return line, idx - line_start
|