|
from __future__ import annotations |
|
|
|
from warnings import warn |
|
import inspect |
|
from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning |
|
from .utils import expand_tuples |
|
import itertools as itl |
|
|
|
|
|
class MDNotImplementedError(NotImplementedError):
    """Signal from an implementation that the next-best match should be tried.

    ``Dispatcher.__call__`` catches this exception and falls back to the
    remaining matching implementations in resolution order.
    """
|
|
|
|
|
|
|
|
|
def ambiguity_warn(dispatcher, ambiguities):
    """ Emit an ``AmbiguityWarning`` describing ambiguous signatures

    This is the default ``on_ambiguity`` callback.

    Parameters
    ----------
    dispatcher : Dispatcher
        Dispatcher in which the ambiguity was found; only its ``name`` is used
    ambiguities : set
        Pairs of type signatures that are ambiguous within ``dispatcher``

    See Also:
        Dispatcher.add
        warning_text
    """
    message = warning_text(dispatcher.name, ambiguities)
    warn(message, AmbiguityWarning)
|
|
|
|
|
class RaiseNotImplementedError:
    """Raise ``NotImplementedError`` when called."""

    # Registered as the implementation for an ambiguous signature (see
    # ``ambiguity_register_error_ignore_dup``); the class docstring above is
    # read at runtime by ``Dispatcher.__doc__``, so it is kept verbatim.

    def __init__(self, dispatcher):
        # Only ``dispatcher.name`` is read, to build the error message.
        self.dispatcher = dispatcher

    def __call__(self, *args, **kwargs):
        arg_types = tuple(map(type, args))
        message = "Ambiguous signature for %s: <%s>" % (
            self.dispatcher.name, str_signature(arg_types))
        raise NotImplementedError(message)
|
|
|
def ambiguity_register_error_ignore_dup(dispatcher, ambiguities):
    """ Resolve ambiguities by registering an erroring implementation

    For each ambiguous pair, the super signature is computed. If every
    position of it holds the same type, the pair is skipped. Otherwise a
    ``RaiseNotImplementedError`` is registered for the super signature, so
    calls matching it fail loudly instead of picking an arbitrary winner.

    Parameters
    ----------
    dispatcher : Dispatcher
        Dispatcher in which the ambiguity was found
    ambiguities : set
        Pairs of type signatures that are ambiguous within ``dispatcher``

    See Also:
        Dispatcher.add
        ambiguity_warn
    """
    for pair in ambiguities:
        candidate = tuple(super_signature(pair))
        if len(set(candidate)) != 1:
            # Re-entering ``add`` may uncover further ambiguities; handle
            # them with this same policy.
            dispatcher.add(
                candidate, RaiseNotImplementedError(dispatcher),
                on_ambiguity=ambiguity_register_error_ignore_dup,
            )
|
|
|
|
|
|
|
|
|
# Dispatchers whose ``reorder`` was requested while ordering was halted
# (``_resolve[0]`` is False); ``restart_ordering`` drains this set.
_unresolved_dispatchers: set[Dispatcher] = set()

# One-element list used as a mutable module-level flag (avoids ``global``):
# when ``_resolve[0]`` is True, ``Dispatcher.reorder`` recomputes the
# ordering immediately; otherwise the dispatcher is deferred above.
_resolve = [True]
|
|
|
|
|
def halt_ordering():
    """Suspend ordering recomputation for every dispatcher.

    While halted, ``Dispatcher.reorder`` only queues the dispatcher in
    ``_unresolved_dispatchers``; call ``restart_ordering`` to resume.
    """
    _resolve[0] = False
|
|
|
|
|
def restart_ordering(on_ambiguity=ambiguity_warn):
    """Re-enable ordering and reorder every dispatcher deferred while halted.

    Parameters
    ----------
    on_ambiguity : callable
        Forwarded to ``Dispatcher.reorder`` for each deferred dispatcher.
    """
    _resolve[0] = True
    pending = _unresolved_dispatchers
    while pending:
        pending.pop().reorder(on_ambiguity=on_ambiguity)
|
|
|
|
|
class Dispatcher:
    """ Dispatch methods based on type signature

    Use ``dispatch`` to add implementations

    Examples
    --------

    >>> from sympy.multipledispatch import dispatch
    >>> @dispatch(int)
    ... def f(x):
    ...     return x + 1

    >>> @dispatch(float)
    ... def f(x): # noqa: F811
    ...     return x - 1

    >>> f(3)
    4
    >>> f(3.0)
    2.0
    """
    __slots__ = '__name__', 'name', 'funcs', 'ordering', '_cache', 'doc'

    def __init__(self, name, doc=None):
        self.name = self.__name__ = name
        # Maps a tuple of types to its implementation.
        self.funcs = {}
        # Maps the argument-type tuple of a call to the resolved function.
        self._cache = {}
        # Signatures in resolution order, as computed by ``ordering``.
        self.ordering = []
        self.doc = doc

    def register(self, *types, **kwargs):
        """ Register dispatcher with new implementation

        >>> from sympy.multipledispatch.dispatcher import Dispatcher
        >>> f = Dispatcher('f')
        >>> @f.register(int)
        ... def inc(x):
        ...     return x + 1

        >>> @f.register(float)
        ... def dec(x):
        ...     return x - 1

        >>> @f.register(list)
        ... @f.register(tuple)
        ... def reverse(x):
        ...     return x[::-1]

        >>> f(1)
        2

        >>> f(1.0)
        0.0

        >>> f([1, 2, 3])
        [3, 2, 1]
        """
        def _(func):
            self.add(types, func, **kwargs)
            return func
        return _

    @classmethod
    def get_func_params(cls, func):
        """ Return the parameters of ``func`` (``None`` if unavailable) """
        if hasattr(inspect, "signature"):
            sig = inspect.signature(func)
            return sig.parameters.values()

    @classmethod
    def get_func_annotations(cls, func):
        """ Get annotations of function positional parameters

        Returns a tuple of annotations, or ``None`` if any positional
        parameter is unannotated (or parameters are unavailable).
        """
        params = cls.get_func_params(func)
        if params:
            Parameter = inspect.Parameter

            params = (param for param in params
                      if param.kind in
                      (Parameter.POSITIONAL_ONLY,
                       Parameter.POSITIONAL_OR_KEYWORD))

            annotations = tuple(
                param.annotation
                for param in params)

            if not any(ann is Parameter.empty for ann in annotations):
                return annotations

    def add(self, signature, func, on_ambiguity=ambiguity_warn):
        """ Add new types/method pair to dispatcher

        >>> from sympy.multipledispatch import Dispatcher
        >>> D = Dispatcher('add')
        >>> D.add((int, int), lambda x, y: x + y)
        >>> D.add((float, float), lambda x, y: x + y)

        >>> D(1, 2)
        3
        >>> D(1, 2.0)
        Traceback (most recent call last):
        ...
        NotImplementedError: Could not find signature for add: <int, float>

        When ``add`` detects a warning it calls the ``on_ambiguity`` callback
        with a dispatcher/itself, and a set of ambiguous type signature pairs
        as inputs. See ``ambiguity_warn`` for an example.
        """
        # Handle annotations: an empty signature falls back to the
        # annotations of ``func``'s positional parameters, if complete.
        if not signature:
            annotations = self.get_func_annotations(func)
            if annotations:
                signature = annotations

        # Normalise any sequence to a tuple so it is hashable and matches the
        # ``types`` tuples built by ``__call__``/``dispatch``.
        signature = tuple(signature)

        # Handle union types: a tuple in a position means "any of these".
        if any(isinstance(typ, tuple) for typ in signature):
            for typs in expand_tuples(signature):
                self.add(typs, func, on_ambiguity)
            return

        for typ in signature:
            if not isinstance(typ, type):
                str_sig = ', '.join(c.__name__ if isinstance(c, type)
                                    else str(c) for c in signature)
                raise TypeError("Tried to dispatch on non-type: %s\n"
                                "In signature: <%s>\n"
                                "In function: %s" %
                                (typ, str_sig, self.name))

        self.funcs[signature] = func
        self.reorder(on_ambiguity=on_ambiguity)
        self._cache.clear()

    def reorder(self, on_ambiguity=ambiguity_warn):
        """ Recompute ``self.ordering`` and report any ambiguities

        While ordering is halted (see ``halt_ordering``) the dispatcher is
        only queued in ``_unresolved_dispatchers`` for ``restart_ordering``.
        """
        if _resolve[0]:
            self.ordering = ordering(self.funcs)
            amb = ambiguities(self.funcs)
            if amb:
                on_ambiguity(self, amb)
        else:
            _unresolved_dispatchers.add(self)

    def __call__(self, *args, **kwargs):
        types = tuple([type(arg) for arg in args])
        try:
            func = self._cache[types]
        except KeyError:
            func = self.dispatch(*types)
            if not func:
                raise NotImplementedError(
                    'Could not find signature for %s: <%s>' %
                    (self.name, str_signature(types)))
            self._cache[types] = func
        try:
            return func(*args, **kwargs)

        except MDNotImplementedError:
            funcs = self.dispatch_iter(*types)
            # Skip the first match (presumably the implementation that just
            # failed). The default avoids leaking StopIteration when
            # ``self.ordering`` is stale, e.g. while ordering is halted and
            # ``func`` was found only by the exact-match lookup in
            # ``dispatch``.
            next(funcs, None)
            for func in funcs:
                try:
                    return func(*args, **kwargs)
                except MDNotImplementedError:
                    pass

            raise NotImplementedError("Matching functions for "
                                      "%s: <%s> found, but none completed successfully"
                                      % (self.name, str_signature(types)))

    def __str__(self):
        return "<dispatched %s>" % self.name
    __repr__ = __str__

    def dispatch(self, *types):
        """ Determine appropriate implementation for this type signature

        This method is internal. Users should call this object as a function.
        Implementation resolution occurs within the ``__call__`` method.

        >>> from sympy.multipledispatch import dispatch
        >>> @dispatch(int)
        ... def inc(x):
        ...     return x + 1

        >>> implementation = inc.dispatch(int)
        >>> implementation(3)
        4

        >>> print(inc.dispatch(float))
        None

        See Also:
            ``sympy.multipledispatch.conflict`` - module to determine resolution order
        """

        if types in self.funcs:
            return self.funcs[types]

        try:
            return next(self.dispatch_iter(*types))
        except StopIteration:
            return None

    def dispatch_iter(self, *types):
        """ Yield every implementation matching ``types``, in resolution order """
        n = len(types)
        for signature in self.ordering:
            if len(signature) == n and all(map(issubclass, types, signature)):
                result = self.funcs[signature]
                yield result

    def resolve(self, types):
        """ Determine appropriate implementation for this type signature

        .. deprecated:: 0.4.4
            Use ``dispatch(*types)`` instead
        """
        # stacklevel=2 attributes the warning to the caller, not this method.
        warn("resolve() is deprecated, use dispatch(*types)",
             DeprecationWarning, stacklevel=2)

        return self.dispatch(*types)

    def __getstate__(self):
        # ``doc`` must be included: it is a slot that ``__doc__`` reads.
        return {'name': self.name,
                'funcs': self.funcs,
                'doc': self.doc}

    def __setstate__(self, d):
        self.name = self.__name__ = d['name']
        self.funcs = d['funcs']
        # States pickled before ``doc`` was saved lack the key.
        self.doc = d.get('doc')
        self.ordering = ordering(self.funcs)
        self._cache = {}

    @property
    def __doc__(self):
        docs = ["Multiply dispatched method: %s" % self.name]

        if self.doc:
            docs.append(self.doc)

        other = []
        for sig in self.ordering[::-1]:
            func = self.funcs[sig]
            if func.__doc__:
                s = 'Inputs: <%s>\n' % str_signature(sig)
                s += '-' * len(s) + '\n'
                s += func.__doc__.strip()
                docs.append(s)
            else:
                other.append(str_signature(sig))

        if other:
            docs.append('Other signatures:\n    ' + '\n    '.join(other))

        return '\n\n'.join(docs)

    def _help(self, *args):
        return self.dispatch(*map(type, args)).__doc__

    def help(self, *args, **kwargs):
        """ Print docstring for the function corresponding to inputs """
        print(self._help(*args))

    def _source(self, *args):
        func = self.dispatch(*map(type, args))
        if not func:
            raise TypeError("No function found")
        return source(func)

    def source(self, *args, **kwargs):
        """ Print source code for the function corresponding to inputs """
        print(self._source(*args))
|
|
|
|
|
def source(func):
    """Return ``func``'s defining file name followed by its source text."""
    filename = inspect.getsourcefile(func)
    body = inspect.getsource(func)
    return 'File: %s\n\n%s' % (filename, body)
|
|
|
|
|
class MethodDispatcher(Dispatcher):
    """ Dispatch methods based on type signature

    See Also:
        Dispatcher
    """

    @classmethod
    def get_func_params(cls, func):
        # The leading parameter (the instance) is not dispatched on, so it is
        # excluded from annotation-based registration.
        if not hasattr(inspect, "signature"):
            return None
        params = inspect.signature(func).parameters.values()
        return itl.islice(params, 1, None)

    def __get__(self, instance, owner):
        # NOTE(review): the instance is stored on the shared dispatcher
        # object, so a later attribute access (on any instance) overwrites
        # it before a previously fetched reference is called.
        self.obj = instance
        self.cls = owner
        return self

    def __call__(self, *args, **kwargs):
        arg_types = tuple(type(a) for a in args)
        impl = self.dispatch(*arg_types)
        if not impl:
            raise NotImplementedError('Could not find signature for %s: <%s>' %
                                      (self.name, str_signature(arg_types)))
        return impl(self.obj, *args, **kwargs)
|
|
|
|
|
def str_signature(sig):
    """ Render a type signature as comma-separated type names

    >>> from sympy.multipledispatch.dispatcher import str_signature
    >>> str_signature((int, float))
    'int, float'
    """
    names = [typ.__name__ for typ in sig]
    return ', '.join(names)
|
|
|
|
|
def warning_text(name, amb):
    """ Build the message text used for ambiguity warnings

    The message lists each ambiguous pair of signatures, then suggests a
    ``@dispatch`` declaration on the super signature of each pair.
    """
    parts = ["\nAmbiguities exist in dispatched function %s\n\n" % (name),
             "The following signatures may result in ambiguous behavior:\n"]
    for pair in amb:
        bracketed = ['[' + str_signature(sig) + ']' for sig in pair]
        parts.append("\t" + ', '.join(bracketed) + "\n")
    parts.append("\n\nConsider making the following additions:\n\n")
    suggestions = []
    for s in amb:
        decorator = '@dispatch(' + str_signature(super_signature(s))
        suggestions.append(decorator + ')\ndef %s(...)' % name)
    parts.append('\n\n'.join(suggestions))
    return ''.join(parts)
|
|