|
import sys |
|
import warnings |
|
from itertools import zip_longest |
|
from typing import ( |
|
TYPE_CHECKING, |
|
Any, |
|
Callable, |
|
Dict, |
|
Generator, |
|
List, |
|
Optional, |
|
Set, |
|
Tuple, |
|
Union, |
|
) |
|
|
|
from antlr4 import TerminalNode |
|
|
|
from .errors import InterpolationResolutionError |
|
|
|
if TYPE_CHECKING: |
|
from .base import Node |
|
|
|
try: |
|
from omegaconf.grammar.gen.OmegaConfGrammarLexer import OmegaConfGrammarLexer |
|
from omegaconf.grammar.gen.OmegaConfGrammarParser import OmegaConfGrammarParser |
|
from omegaconf.grammar.gen.OmegaConfGrammarParserVisitor import ( |
|
OmegaConfGrammarParserVisitor, |
|
) |
|
|
|
except ModuleNotFoundError: |
|
print( |
|
"Error importing OmegaConf's generated parsers, run `python setup.py antlr` to regenerate.", |
|
file=sys.stderr, |
|
) |
|
sys.exit(1) |
|
|
|
|
|
class GrammarVisitor(OmegaConfGrammarParserVisitor):
    """
    Visitor that evaluates an OmegaConf grammar parse tree into Python values.

    Each ``visit*`` method handles one parser rule and explicitly computes the value
    it denotes. Node interpolations and resolver interpolations are delegated to the
    callbacks given to the constructor.
    """

    def __init__(
        self,
        node_interpolation_callback: Callable[
            [str, Optional[Set[int]]],
            Optional["Node"],
        ],
        resolver_interpolation_callback: Callable[..., Any],
        memo: Optional[Set[int]],
        **kw: Any,
    ):
        """
        Constructor.

        :param node_interpolation_callback: Callback function that is called when
            needing to resolve a node interpolation. It is called with two positional
            arguments: the key's dot path (ex: `"foo.bar"`) and the `memo` object
            given to this constructor. It returns the resolved node (or None).

        :param resolver_interpolation_callback: Callback function that is called when
            needing to resolve a resolver interpolation. This function should accept
            three keyword arguments: `name` (str, the name of the resolver),
            `args` (tuple, the inputs to the resolver), and `args_str` (tuple,
            the string representation of the inputs to the resolver).

        :param memo: Optional set of integers that is forwarded unchanged as the
            second argument of every `node_interpolation_callback` call.

        :param kw: Additional keyword arguments to be forwarded to parent class.
        """
        super().__init__(**kw)
        self.node_interpolation_callback = node_interpolation_callback
        self.resolver_interpolation_callback = resolver_interpolation_callback
        self.memo = memo

    def aggregateResult(self, aggregate: List[Any], nextResult: Any) -> List[Any]:
        # Deliberately unsupported: the visit methods of this class compute their
        # results explicitly rather than through default child aggregation.
        raise NotImplementedError

    def defaultResult(self) -> List[Any]:
        # Deliberately unsupported, for the same reason as `aggregateResult()`.
        raise NotImplementedError

    def visitConfigKey(self, ctx: OmegaConfGrammarParser.ConfigKeyContext) -> str:
        """
        Return the string denoted by a config key.

        A config key has exactly one child: either an interpolation (which must
        resolve to a string) or a single token whose text is the key.

        :raises InterpolationResolutionError: if the interpolation does not resolve
            to a string.
        """
        # Function-scope import (kept as in the original; presumably avoids an
        # import cycle with `._utils`).
        from ._utils import _get_value

        assert ctx.getChildCount() == 1
        child = ctx.getChild(0)
        if isinstance(child, OmegaConfGrammarParser.InterpolationContext):
            res = _get_value(self.visitInterpolation(child))
            if not isinstance(res, str):
                raise InterpolationResolutionError(
                    f"The following interpolation is used to denote a config key and "
                    f"thus should return a string, but instead returned `{res}` of "
                    f"type `{type(res)}`: {child.getText()}"
                )
            return res

        # Plain token: its text is the key itself.
        assert isinstance(child, TerminalNode) and isinstance(child.symbol.text, str)
        return child.symbol.text

    def visitConfigValue(self, ctx: OmegaConfGrammarParser.ConfigValueContext) -> Any:
        """Return the value of a top-level config value."""
        # Two children are expected and only the first carries a value (the second
        # is presumably the end-of-input token).
        assert ctx.getChildCount() == 2
        return self.visit(ctx.getChild(0))

    def visitDictKey(self, ctx: OmegaConfGrammarParser.DictKeyContext) -> Any:
        """Return the primitive value used as a key in a dict literal."""
        return self._createPrimitive(ctx)

    def visitDictContainer(
        self, ctx: OmegaConfGrammarParser.DictContainerContext
    ) -> Dict[Any, Any]:
        """Return the dict built from the key/value pairs of a dict literal."""
        # Key/value pairs sit at odd indices; the first and last children and the
        # children in between pairs are delimiters (presumably the braces and commas).
        assert ctx.getChildCount() >= 2
        return dict(
            self.visitDictKeyValuePair(ctx.getChild(i))
            for i in range(1, ctx.getChildCount() - 1, 2)
        )

    def visitElement(self, ctx: OmegaConfGrammarParser.ElementContext) -> Any:
        """Return the value of an element by visiting its single child."""
        assert ctx.getChildCount() == 1
        return self.visit(ctx.getChild(0))

    def visitInterpolation(
        self, ctx: OmegaConfGrammarParser.InterpolationContext
    ) -> Any:
        """Return the resolved value of a node or resolver interpolation."""
        assert ctx.getChildCount() == 1
        return self.visit(ctx.getChild(0))

    def visitInterpolationNode(
        self, ctx: OmegaConfGrammarParser.InterpolationNodeContext
    ) -> Optional["Node"]:
        """
        Resolve a node interpolation through `node_interpolation_callback`.

        The dot path passed to the callback is rebuilt from the DOT / BRACKET_OPEN /
        BRACKET_CLOSE tokens and the (possibly interpolated) config keys, in order.
        The INTER_OPEN / INTER_CLOSE delimiters are dropped.
        """
        assert ctx.getChildCount() >= 3

        path_tokens = (
            OmegaConfGrammarLexer.DOT,
            OmegaConfGrammarLexer.BRACKET_OPEN,
            OmegaConfGrammarLexer.BRACKET_CLOSE,
        )
        inter_key_tokens = []  # pieces of the dot path, in source order
        for child in ctx.getChildren():
            if isinstance(child, TerminalNode):
                s = child.symbol
                if s.type in path_tokens:
                    inter_key_tokens.append(s.text)
                else:
                    assert s.type in (
                        OmegaConfGrammarLexer.INTER_OPEN,
                        OmegaConfGrammarLexer.INTER_CLOSE,
                    )
            else:
                assert isinstance(child, OmegaConfGrammarParser.ConfigKeyContext)
                inter_key_tokens.append(self.visitConfigKey(child))

        inter_key = "".join(inter_key_tokens)
        return self.node_interpolation_callback(inter_key, self.memo)

    def visitInterpolationResolver(
        self, ctx: OmegaConfGrammarParser.InterpolationResolverContext
    ) -> Any:
        """
        Resolve a resolver interpolation through `resolver_interpolation_callback`.

        Child 1 is the resolver name. Child 3 is either the closing brace token (no
        arguments) or the argument sequence.
        """
        assert 4 <= ctx.getChildCount() <= 5

        resolver_name = self.visit(ctx.getChild(1))
        maybe_seq = ctx.getChild(3)
        args = []
        args_str = []
        if isinstance(maybe_seq, TerminalNode):
            # No arguments: child 3 is already the closing brace.
            assert maybe_seq.symbol.type == OmegaConfGrammarLexer.BRACE_CLOSE
        else:
            assert isinstance(maybe_seq, OmegaConfGrammarParser.SequenceContext)
            for val, txt in self.visitSequence(maybe_seq):
                args.append(val)
                args_str.append(txt)

        return self.resolver_interpolation_callback(
            name=resolver_name,
            args=tuple(args),
            args_str=tuple(args_str),
        )

    def visitDictKeyValuePair(
        self, ctx: OmegaConfGrammarParser.DictKeyValuePairContext
    ) -> Tuple[Any, Any]:
        """Return the `(key, value)` tuple of a `key: value` entry in a dict literal."""
        from ._utils import _get_value

        assert ctx.getChildCount() == 3
        key = self.visit(ctx.getChild(0))
        colon = ctx.getChild(1)
        assert (
            isinstance(colon, TerminalNode)
            and colon.symbol.type == OmegaConfGrammarLexer.COLON
        )
        value = _get_value(self.visitElement(ctx.getChild(2)))
        return key, value

    def visitListContainer(
        self, ctx: OmegaConfGrammarParser.ListContainerContext
    ) -> List[Any]:
        """Return the list of values of a list literal (empty if it has no sequence)."""
        assert ctx.getChildCount() in (2, 3)
        if ctx.getChildCount() == 2:
            # Only the two delimiters: empty list.
            return []
        sequence = ctx.getChild(1)
        assert isinstance(sequence, OmegaConfGrammarParser.SequenceContext)
        # Keep only the values; the textual representations are not needed here.
        return [val for val, _ in self.visitSequence(sequence)]

    def visitPrimitive(self, ctx: OmegaConfGrammarParser.PrimitiveContext) -> Any:
        """Return the Python value of a primitive (see `_createPrimitive()`)."""
        return self._createPrimitive(ctx)

    def visitQuotedValue(self, ctx: OmegaConfGrammarParser.QuotedValueContext) -> str:
        """Return the string content of a quoted value (`""` when it is empty)."""
        n = ctx.getChildCount()
        assert n in [2, 3]
        # With three children the middle one is the content; with two there is none.
        return str(self.visit(ctx.getChild(1))) if n == 3 else ""

    def visitResolverName(self, ctx: OmegaConfGrammarParser.ResolverNameContext) -> str:
        """
        Return the dotted resolver name.

        Each name component (ID token or interpolation) is joined with ".".

        :raises InterpolationResolutionError: if an interpolated component does not
            resolve to a string.
        """
        from ._utils import _get_value

        assert ctx.getChildCount() >= 1
        items = []
        # Name components are at even indices; odd indices are separators.
        for child in list(ctx.getChildren())[::2]:
            if isinstance(child, TerminalNode):
                assert child.symbol.type == OmegaConfGrammarLexer.ID
                items.append(child.symbol.text)
            else:
                assert isinstance(child, OmegaConfGrammarParser.InterpolationContext)
                item = _get_value(self.visitInterpolation(child))
                if not isinstance(item, str):
                    raise InterpolationResolutionError(
                        f"The name of a resolver must be a string, but the interpolation "
                        f"{child.getText()} resolved to `{item}` which is of type "
                        f"{type(item)}"
                    )
                items.append(item)
        return ".".join(items)

    def visitSequence(
        self, ctx: OmegaConfGrammarParser.SequenceContext
    ) -> Generator[Tuple[Any, str], None, None]:
        """
        Yield `(value, original_text)` for each element of a comma-separated sequence.

        The original text is kept alongside the value because a value cannot always
        be turned back into its source text (e.g. `null` would become "None").

        A missing element is yielded as `("", "")` and triggers a UserWarning. This
        covers a leading comma, two consecutive commas, or a trailing comma.
        """
        from ._utils import _get_value

        assert ctx.getChildCount() >= 1

        def empty_str_warning() -> None:
            txt = ctx.getText()
            warnings.warn(
                f"In the sequence `{txt}` some elements are missing: please replace "
                f"them with empty quoted strings. "
                f"See https://github.com/omry/omegaconf/issues/572 for details.",
                category=UserWarning,
            )

        # Starts True so that a leading comma is detected as a missing element.
        is_previous_comma = True
        for child in ctx.getChildren():
            if isinstance(child, OmegaConfGrammarParser.ElementContext):
                yield _get_value(self.visitElement(child)), child.getText()
                is_previous_comma = False
            else:
                assert (
                    isinstance(child, TerminalNode)
                    and child.symbol.type == OmegaConfGrammarLexer.COMMA
                )
                if is_previous_comma:
                    empty_str_warning()
                    yield "", ""
                else:
                    is_previous_comma = True
        if is_previous_comma:
            # Trailing comma: the last element is missing.
            empty_str_warning()
            yield "", ""

    def visitSingleElement(
        self, ctx: OmegaConfGrammarParser.SingleElementContext
    ) -> Any:
        """Return the value of a standalone element (its first child)."""
        # The second child is presumably the end-of-input token.
        assert ctx.getChildCount() == 2
        return self.visit(ctx.getChild(0))

    def visitText(self, ctx: OmegaConfGrammarParser.TextContext) -> Any:
        """
        Return the value of a text node.

        A text that consists of a single interpolation returns that interpolation's
        resolved value as is (not necessarily a string). Otherwise all children are
        concatenated into a string by `_unescape()`.
        """
        if ctx.getChildCount() == 1:
            c = ctx.getChild(0)
            if isinstance(c, OmegaConfGrammarParser.InterpolationContext):
                return self.visitInterpolation(c)

        return self._unescape(list(ctx.getChildren()))

    def _createPrimitive(
        self,
        ctx: Union[
            OmegaConfGrammarParser.PrimitiveContext,
            OmegaConfGrammarParser.DictKeyContext,
        ],
    ) -> Any:
        """
        Convert a primitive (or dict key) node into its Python value.

        A single child is converted according to its kind:

        - interpolation: its resolved value.
        - ID / UNQUOTED_CHAR / COLON: the text.
        - NULL: None.
        - INT / FLOAT: the parsed number.
        - BOOL: True iff the text is "true", case-insensitively.
        - ESC: the unescaped text.

        Several children are concatenated and unescaped into a string.
        """
        if ctx.getChildCount() == 1:
            child = ctx.getChild(0)
            if isinstance(child, OmegaConfGrammarParser.InterpolationContext):
                return self.visitInterpolation(child)
            assert isinstance(child, TerminalNode)
            symbol = child.symbol

            if symbol.type in (
                OmegaConfGrammarLexer.ID,
                OmegaConfGrammarLexer.UNQUOTED_CHAR,
                OmegaConfGrammarLexer.COLON,
            ):
                return symbol.text
            elif symbol.type == OmegaConfGrammarLexer.NULL:
                return None
            elif symbol.type == OmegaConfGrammarLexer.INT:
                return int(symbol.text)
            elif symbol.type == OmegaConfGrammarLexer.FLOAT:
                return float(symbol.text)
            elif symbol.type == OmegaConfGrammarLexer.BOOL:
                return symbol.text.lower() == "true"
            elif symbol.type == OmegaConfGrammarLexer.ESC:
                return self._unescape([child])
            elif symbol.type == OmegaConfGrammarLexer.WS:
                raise AssertionError("WS should never be reached")
            # Explicit raise (not `assert False`) so an unexpected token type is
            # still reported under `python -O` instead of silently falling through
            # to the multi-child path below.
            raise AssertionError(symbol.type)

        return self._unescape(list(ctx.getChildren()))

    def _unescape(
        self,
        seq: List[Union[TerminalNode, OmegaConfGrammarParser.InterpolationContext]],
    ) -> str:
        """
        Concatenate all symbols / interpolations in `seq`, unescaping symbols as needed.

        Interpolations are resolved and cast to string *WITHOUT* escaping their result
        (it is assumed that whatever escaping is required was already handled during the
        resolving of the interpolation).
        """
        chrs = []
        # Pair each item with its successor (None for the last item): whether some
        # backslash tokens are unescaped depends on what follows them.
        for node, next_node in zip_longest(seq, seq[1:]):
            if isinstance(node, TerminalNode):
                s = node.symbol
                if s.type == OmegaConfGrammarLexer.ESC_INTER:
                    # Keep only the last len // 2 + 1 characters, which drops the
                    # escaping backslashes (e.g. a 3-char token keeps its last 2).
                    text = s.text[-(len(s.text) // 2 + 1) :]
                elif (
                    s.type == OmegaConfGrammarLexer.ESC
                    or (
                        # TOP_ESC is only unescaped when an interpolation follows.
                        s.type == OmegaConfGrammarLexer.TOP_ESC
                        and isinstance(
                            next_node, OmegaConfGrammarParser.InterpolationContext
                        )
                    )
                    or (
                        # QUOTED_ESC is only unescaped at the end of the sequence or
                        # when an interpolation follows.
                        s.type == OmegaConfGrammarLexer.QUOTED_ESC
                        and (
                            next_node is None
                            or isinstance(
                                next_node, OmegaConfGrammarParser.InterpolationContext
                            )
                        )
                    )
                ):
                    # Keep every second character, i.e. drop the escaping backslash
                    # of each two-character pair.
                    text = s.text[1::2]
                else:
                    text = s.text
            else:
                assert isinstance(node, OmegaConfGrammarParser.InterpolationContext)
                text = str(self.visitInterpolation(node))
            chrs.append(text)

        return "".join(chrs)
|
|