|
from sympy.core import S |
|
from sympy.core.function import Lambda |
|
from sympy.core.power import Pow |
|
from .pycode import PythonCodePrinter, _known_functions_math, _print_known_const, _print_known_func, _unpack_integral_limits, ArrayPrinter |
|
from .codeprinter import CodePrinter |
|
|
|
|
|
# Names from the generic math-module table that are excluded from the
# NumPy table.
_not_in_numpy = ['erf', 'erfc', 'factorial', 'gamma', 'loggamma']

# (sympy name, math name) pairs inherited unchanged by the NumPy table.
_in_numpy = [(name, target) for name, target in _known_functions_math.items()
             if name not in _not_in_numpy]

# NumPy attribute names keyed by SymPy function name; the explicit entries
# below take precedence over any inherited pair with the same key.
_known_functions_numpy = {
    **dict(_in_numpy),
    'acos': 'arccos',
    'acosh': 'arccosh',
    'asin': 'arcsin',
    'asinh': 'arcsinh',
    'atan': 'arctan',
    'atan2': 'arctan2',
    'atanh': 'arctanh',
    'exp2': 'exp2',
    'sign': 'sign',
    'logaddexp': 'logaddexp',
    'logaddexp2': 'logaddexp2',
    'isinf': 'isinf',
    'isnan': 'isnan',
}
|
_known_constants_numpy = { |
|
'Exp1': 'e', |
|
'Pi': 'pi', |
|
'EulerGamma': 'euler_gamma', |
|
'NaN': 'nan', |
|
'Infinity': 'inf', |
|
} |
|
|
|
# Fully qualified 'numpy.<attr>' spellings consumed by NumPyPrinter.
_numpy_known_functions = {
    name: f'numpy.{target}' for name, target in _known_functions_numpy.items()
}
_numpy_known_constants = {
    name: f'numpy.{target}' for name, target in _known_constants_numpy.items()
}
|
|
|
class NumPyPrinter(ArrayPrinter, PythonCodePrinter):
    """
    Numpy printer which handles vectorized piecewise functions,
    logical operators, etc.

    Module-qualified names are built from the class attribute ``_module``.
    Subclasses retarget the printer at another array library by overriding
    ``_module`` together with the ``_kf``/``_kc`` lookup tables.
    """

    _module = 'numpy'
    _kf = _numpy_known_functions
    _kc = _numpy_known_constants

    def __init__(self, settings=None):
        """
        `settings` is passed to CodePrinter.__init__()

        The array module targeted by the printed code is not a constructor
        argument. It is the class attribute ``_module`` ('numpy' here,
        overridden by subclasses such as CuPyPrinter and JaxPrinter).
        """
        self.language = "Python with {}".format(self._module)
        self.printmethod = "_{}code".format(self._module)

        # Layer this class's table over the generic Python one so that the
        # array-module spellings win for names present in both.
        self._kf = {**PythonCodePrinter._kf, **self._kf}

        super().__init__(settings=settings)

    def _print_seq(self, seq):
        "General sequence printer: converts to tuple"
        items = [self._print(item) for item in seq]
        if not items:
            # '(,)' would be a syntax error; the empty tuple is spelled '()'.
            return '()'
        # The trailing comma keeps a one-item sequence a tuple: '(x,)'.
        return '({},)'.format(', '.join(items))

    def _print_NegativeInfinity(self, expr):
        return '-' + self._print(S.Infinity)

    def _print_MatMul(self, expr):
        "Matrix multiplication printer"
        coeff, matrices = expr.as_coeff_matrices()
        if coeff is not S.One:
            # A non-trivial scalar coefficient becomes the last factor of the
            # ``.dot`` chain.
            factors = list(matrices) + [coeff]
        else:
            factors = expr.args
        return '({})'.format(').dot('.join(self._print(i) for i in factors))

    def _print_MatPow(self, expr):
        "Matrix power printer"
        return '{}({}, {})'.format(
            self._module_format(self._module + '.linalg.matrix_power'),
            self._print(expr.args[0]), self._print(expr.args[1]))

    def _print_Inverse(self, expr):
        "Matrix inverse printer"
        return '{}({})'.format(
            self._module_format(self._module + '.linalg.inv'),
            self._print(expr.args[0]))

    def _print_DotProduct(self, expr):
        # ``dot`` performs matrix multiplication, so the operands are
        # oriented as (1 x n) . (n x 1); whichever is not already in that
        # orientation is transposed.
        arg1, arg2 = expr.args
        if arg1.shape[0] != 1:
            arg1 = arg1.T
        if arg2.shape[1] != 1:
            arg2 = arg2.T

        return "%s(%s, %s)" % (self._module_format(self._module + '.dot'),
                               self._print(arg1),
                               self._print(arg2))

    def _print_MatrixSolve(self, expr):
        return "%s(%s, %s)" % (self._module_format(self._module + '.linalg.solve'),
                               self._print(expr.matrix),
                               self._print(expr.vector))

    def _print_ZeroMatrix(self, expr):
        return '{}({})'.format(self._module_format(self._module + '.zeros'),
                               self._print(expr.shape))

    def _print_OneMatrix(self, expr):
        return '{}({})'.format(self._module_format(self._module + '.ones'),
                               self._print(expr.shape))

    def _print_FunctionMatrix(self, expr):
        from sympy.abc import i, j
        lamda = expr.lamda
        if not isinstance(lamda, Lambda):
            # Normalize a plain callable into a two-argument Lambda so its
            # signature and body can be printed.
            lamda = Lambda((i, j), lamda(i, j))
        return '{}(lambda {}: {}, {})'.format(
            self._module_format(self._module + '.fromfunction'),
            ', '.join(self._print(arg) for arg in lamda.args[0]),
            self._print(lamda.args[1]), self._print(expr.shape))

    def _print_HadamardProduct(self, expr):
        # Nest binary ``multiply`` calls: multiply(a, multiply(b, c)).
        func = self._module_format(self._module + '.multiply')
        head = ''.join('{}({}, '.format(func, self._print(arg))
                       for arg in expr.args[:-1])
        return head + "{}{}".format(self._print(expr.args[-1]),
                                    ')' * (len(expr.args) - 1))

    def _print_KroneckerProduct(self, expr):
        # Nest binary ``kron`` calls: kron(a, kron(b, c)).
        func = self._module_format(self._module + '.kron')
        head = ''.join('{}({}, '.format(func, self._print(arg))
                       for arg in expr.args[:-1])
        return head + "{}{}".format(self._print(expr.args[-1]),
                                    ')' * (len(expr.args) - 1))

    def _print_Adjoint(self, expr):
        return '{}({}({}))'.format(
            self._module_format(self._module + '.conjugate'),
            self._module_format(self._module + '.transpose'),
            self._print(expr.args[0]))

    def _print_DiagonalOf(self, expr):
        vect = '{}({})'.format(
            self._module_format(self._module + '.diag'),
            self._print(expr.arg))
        # Reshape the extracted diagonal into a column vector.
        return '{}({}, (-1, 1))'.format(
            self._module_format(self._module + '.reshape'), vect)

    def _print_DiagMatrix(self, expr):
        return '{}({})'.format(self._module_format(self._module + '.diagflat'),
                               self._print(expr.args[0]))

    def _print_DiagonalMatrix(self, expr):
        # Element-wise product with an identity of the same shape keeps only
        # the diagonal of the argument.
        return '{}({}, {}({}, {}))'.format(
            self._module_format(self._module + '.multiply'),
            self._print(expr.arg), self._module_format(self._module + '.eye'),
            self._print(expr.shape[0]), self._print(expr.shape[1]))

    def _print_Piecewise(self, expr):
        "Piecewise function printer"
        from sympy.logic.boolalg import ITE, simplify_logic

        def print_cond(cond):
            """ Problem having an ITE in the cond. """
            if cond.has(ITE):
                return self._print(simplify_logic(cond))
            else:
                return self._print(cond)

        exprs = '[{}]'.format(','.join(self._print(arg.expr) for arg in expr.args))
        conds = '[{}]'.format(','.join(print_cond(arg.cond) for arg in expr.args))
        # ``select`` takes the value for the first true condition, so a
        # trailing (expr, True) pair acts as the fallback only when it is the
        # last piece. Points where no condition holds evaluate to NaN.
        return '{}({}, {}, default={})'.format(
            self._module_format(self._module + '.select'), conds, exprs,
            self._print(S.NaN))

    def _print_Relational(self, expr):
        "Relational printer for Equality and Unequality"
        op = {
            '==': 'equal',
            '!=': 'not_equal',
            '<': 'less',
            '<=': 'less_equal',
            '>': 'greater',
            '>=': 'greater_equal',
        }
        if expr.rel_op in op:
            lhs = self._print(expr.lhs)
            rhs = self._print(expr.rhs)
            return '{op}({lhs}, {rhs})'.format(
                op=self._module_format(self._module + '.' + op[expr.rel_op]),
                lhs=lhs, rhs=rhs)
        return super()._print_Relational(expr)

    def _print_And(self, expr):
        "Logical And printer"
        return '{}.reduce(({}))'.format(
            self._module_format(self._module + '.logical_and'),
            ','.join(self._print(i) for i in expr.args))

    def _print_Or(self, expr):
        "Logical Or printer"
        return '{}.reduce(({}))'.format(
            self._module_format(self._module + '.logical_or'),
            ','.join(self._print(i) for i in expr.args))

    def _print_Not(self, expr):
        "Logical Not printer"
        return '{}({})'.format(
            self._module_format(self._module + '.logical_not'),
            ','.join(self._print(i) for i in expr.args))

    def _print_Pow(self, expr, rational=False):
        # NumPy rejects integer arrays raised to negative integer powers, so
        # a negative integer exponent is emitted as a float instead.
        if expr.exp.is_integer and expr.exp.is_negative:
            expr = Pow(expr.base, expr.exp.evalf(), evaluate=False)
        return self._hprint_Pow(expr, rational=rational, sqrt=self._module + '.sqrt')

    def _helper_minimum_maximum(self, op: str, *args):
        """Fold the binary ufunc ``op`` over ``args`` with functools.reduce.

        Raises NotImplementedError when no arguments are given; a single
        argument is printed as-is.
        """
        if len(args) == 0:
            raise NotImplementedError(f"Need at least one argument for {op}")
        elif len(args) == 1:
            return self._print(args[0])
        _reduce = self._module_format('functools.reduce')
        s_args = [self._print(arg) for arg in args]
        return f"{_reduce}({op}, [{', '.join(s_args)}])"

    def _print_Min(self, expr):
        return self._print_minimum(expr)

    def _print_amin(self, expr):
        return '{}({}, axis={})'.format(
            self._module_format(self._module + '.amin'),
            self._print(expr.array), self._print(expr.axis))

    def _print_minimum(self, expr):
        op = self._module_format(self._module + '.minimum')
        return self._helper_minimum_maximum(op, *expr.args)

    def _print_Max(self, expr):
        return self._print_maximum(expr)

    def _print_amax(self, expr):
        return '{}({}, axis={})'.format(
            self._module_format(self._module + '.amax'),
            self._print(expr.array), self._print(expr.axis))

    def _print_maximum(self, expr):
        op = self._module_format(self._module + '.maximum')
        return self._helper_minimum_maximum(op, *expr.args)

    def _print_arg(self, expr):
        return "%s(%s)" % (self._module_format(self._module + '.angle'), self._print(expr.args[0]))

    def _print_im(self, expr):
        return "%s(%s)" % (self._module_format(self._module + '.imag'), self._print(expr.args[0]))

    def _print_Mod(self, expr):
        return "%s(%s)" % (self._module_format(self._module + '.mod'), ', '.join(
            (self._print(arg) for arg in expr.args)))

    def _print_re(self, expr):
        return "%s(%s)" % (self._module_format(self._module + '.real'), self._print(expr.args[0]))

    def _print_sinc(self, expr):
        # The argument is divided by pi because the array-module sinc is the
        # normalized sin(pi*x)/(pi*x), whereas SymPy's sinc is sin(x)/x.
        return "%s(%s)" % (self._module_format(self._module + '.sinc'), self._print(expr.args[0]/S.Pi))

    def _print_MatrixBase(self, expr):
        if 0 in expr.shape:
            func = self._module_format(f'{self._module}.{self._zeros}')
            return f"{func}({self._print(expr.shape)})"
        func = self.known_functions.get(expr.__class__.__name__, None)
        if func is None:
            func = self._module_format(f'{self._module}.array')
        return "%s(%s)" % (func, self._print(expr.tolist()))

    def _print_Identity(self, expr):
        shape = expr.shape
        if all(dim.is_Integer for dim in shape):
            return "%s(%s)" % (self._module_format(self._module + '.eye'), self._print(expr.shape[0]))
        else:
            raise NotImplementedError("Symbolic matrix dimensions are not yet supported for identity matrices")

    def _print_BlockMatrix(self, expr):
        return '{}({})'.format(self._module_format(self._module + '.block'),
                               self._print(expr.args[0].tolist()))

    def _print_NDimArray(self, expr):
        if expr.rank() == 0:
            func = self._module_format(f'{self._module}.array')
            return f"{func}({self._print(expr[()])})"
        if 0 in expr.shape:
            func = self._module_format(f'{self._module}.{self._zeros}')
            return f"{func}({self._print(expr.shape)})"
        func = self._module_format(f'{self._module}.array')
        return f"{func}({self._print(expr.tolist())})"

    # Array-module function names used by the ArrayPrinter machinery.
    _add = "add"
    _einsum = "einsum"
    _transpose = "transpose"
    _ones = "ones"
    _zeros = "zeros"

    _print_lowergamma = CodePrinter._print_not_supported
    _print_uppergamma = CodePrinter._print_not_supported
    _print_fresnelc = CodePrinter._print_not_supported
    _print_fresnels = CodePrinter._print_not_supported
|
|
|
# Attach the generic table-driven printers for every known NumPy function
# and constant (the loop names stay module-level, as before).
for func in _numpy_known_functions:
    setattr(NumPyPrinter, '_print_' + func, _print_known_func)

for const in _numpy_known_constants:
    setattr(NumPyPrinter, '_print_' + const, _print_known_const)
|
|
|
|
|
_known_functions_scipy_special = { |
|
'Ei': 'expi', |
|
'erf': 'erf', |
|
'erfc': 'erfc', |
|
'besselj': 'jv', |
|
'bessely': 'yv', |
|
'besseli': 'iv', |
|
'besselk': 'kv', |
|
'cosm1': 'cosm1', |
|
'powm1': 'powm1', |
|
'factorial': 'factorial', |
|
'gamma': 'gamma', |
|
'loggamma': 'gammaln', |
|
'digamma': 'psi', |
|
'polygamma': 'polygamma', |
|
'RisingFactorial': 'poch', |
|
'jacobi': 'eval_jacobi', |
|
'gegenbauer': 'eval_gegenbauer', |
|
'chebyshevt': 'eval_chebyt', |
|
'chebyshevu': 'eval_chebyu', |
|
'legendre': 'eval_legendre', |
|
'hermite': 'eval_hermite', |
|
'laguerre': 'eval_laguerre', |
|
'assoc_laguerre': 'eval_genlaguerre', |
|
'beta': 'beta', |
|
'LambertW' : 'lambertw', |
|
} |
|
|
|
_known_constants_scipy_constants = { |
|
'GoldenRatio': 'golden_ratio', |
|
'Pi': 'pi', |
|
} |
|
_scipy_known_functions = {k : "scipy.special." + v for k, v in _known_functions_scipy_special.items()} |
|
_scipy_known_constants = {k : "scipy.constants." + v for k, v in _known_constants_scipy_constants.items()} |
|
|
|
class SciPyPrinter(NumPyPrinter):
    """
    NumPy printer extended with translations to ``scipy.special``,
    ``scipy.constants``, ``scipy.sparse`` and ``scipy.integrate``.

    Several special functions are printed as compound expressions (products
    or differences of SciPy calls). Those are wrapped in parentheses because
    the surrounding printer treats a SymPy function call as atomic and will
    not parenthesize it when it appears in, e.g., ``1/f(x)`` or ``x**f(x)``.
    """

    _kf = {**NumPyPrinter._kf, **_scipy_known_functions}
    _kc = {**NumPyPrinter._kc, **_scipy_known_constants}

    def __init__(self, settings=None):
        super().__init__(settings=settings)
        self.language = "Python with SciPy and NumPy"

    def _print_SparseRepMatrix(self, expr):
        i, j, data = [], [], []
        for (r, c), v in expr.todok().items():
            i.append(r)
            j.append(c)
            data.append(v)

        # Entries are SymPy objects: print them through this printer so that
        # functions/constants come out module-qualified instead of in their
        # plain str() form.
        data_str = '[{}]'.format(', '.join(self._print(v) for v in data))
        return "{name}(({data}, ({i}, {j})), shape={shape})".format(
            name=self._module_format('scipy.sparse.coo_matrix'),
            data=data_str, i=i, j=j, shape=expr.shape
        )

    _print_ImmutableSparseMatrix = _print_SparseRepMatrix

    def _print_assoc_legendre(self, expr):
        # lpmv takes (order, degree, x): the first two SymPy args are swapped.
        return "{0}({2}, {1}, {3})".format(
            self._module_format('scipy.special.lpmv'),
            self._print(expr.args[0]),
            self._print(expr.args[1]),
            self._print(expr.args[2]))

    def _print_lowergamma(self, expr):
        # scipy's gammainc is regularized, hence the gamma(a) factor. The
        # product is parenthesized so it stays intact inside a larger
        # expression.
        return "({0}({2})*{1}({2}, {3}))".format(
            self._module_format('scipy.special.gamma'),
            self._module_format('scipy.special.gammainc'),
            self._print(expr.args[0]),
            self._print(expr.args[1]))

    def _print_uppergamma(self, expr):
        # scipy's gammaincc is regularized, hence the gamma(a) factor. The
        # product is parenthesized so it stays intact inside a larger
        # expression.
        return "({0}({2})*{1}({2}, {3}))".format(
            self._module_format('scipy.special.gamma'),
            self._module_format('scipy.special.gammaincc'),
            self._print(expr.args[0]),
            self._print(expr.args[1]))

    def _print_betainc(self, expr):
        # Generalized incomplete beta over [x1, x2] from the regularized
        # scipy betainc, rescaled by beta(a, b). Implicit string
        # concatenation avoids leaking source indentation into the output.
        betainc = self._module_format('scipy.special.betainc')
        beta = self._module_format('scipy.special.beta')
        args = [self._print(arg) for arg in expr.args]
        return (f"(({betainc}({args[0]}, {args[1]}, {args[3]})"
                f" - {betainc}({args[0]}, {args[1]}, {args[2]}))"
                f" * {beta}({args[0]}, {args[1]}))")

    def _print_betainc_regularized(self, expr):
        # Difference of regularized values at x2 and x1, parenthesized so it
        # stays intact inside a larger expression.
        return "({0}({1}, {2}, {4}) - {0}({1}, {2}, {3}))".format(
            self._module_format('scipy.special.betainc'),
            self._print(expr.args[0]),
            self._print(expr.args[1]),
            self._print(expr.args[2]),
            self._print(expr.args[3]))

    def _print_fresnels(self, expr):
        # scipy.special.fresnel returns (S, C).
        return "{}({})[0]".format(
            self._module_format("scipy.special.fresnel"),
            self._print(expr.args[0]))

    def _print_fresnelc(self, expr):
        # scipy.special.fresnel returns (S, C).
        return "{}({})[1]".format(
            self._module_format("scipy.special.fresnel"),
            self._print(expr.args[0]))

    def _print_airyai(self, expr):
        # scipy.special.airy returns (Ai, Aip, Bi, Bip).
        return "{}({})[0]".format(
            self._module_format("scipy.special.airy"),
            self._print(expr.args[0]))

    def _print_airyaiprime(self, expr):
        return "{}({})[1]".format(
            self._module_format("scipy.special.airy"),
            self._print(expr.args[0]))

    def _print_airybi(self, expr):
        return "{}({})[2]".format(
            self._module_format("scipy.special.airy"),
            self._print(expr.args[0]))

    def _print_airybiprime(self, expr):
        return "{}({})[3]".format(
            self._module_format("scipy.special.airy"),
            self._print(expr.args[0]))

    def _print_bernoulli(self, expr):
        # Printed via its zeta-function rewrite rather than a direct scipy call.
        return self._print(expr._eval_rewrite_as_zeta(*expr.args))

    def _print_harmonic(self, expr):
        return self._print(expr._eval_rewrite_as_zeta(*expr.args))

    def _print_Integral(self, e):
        integration_vars, limits = _unpack_integral_limits(e)

        if len(limits) == 1:
            # Single variable: quad takes the bounds as separate arguments.
            module_str = self._module_format("scipy.integrate.quad")
            limit_str = "%s, %s" % tuple(map(self._print, limits[0]))
        else:
            # Several variables: nquad takes a sequence of (lower, upper) pairs.
            module_str = self._module_format("scipy.integrate.nquad")
            limit_str = "({})".format(", ".join(
                "(%s, %s)" % tuple(map(self._print, l)) for l in limits))

        # Both routines return (value, abserr, ...); keep only the value.
        return "{}(lambda {}: {}, {})[0]".format(
            module_str,
            ", ".join(map(self._print, integration_vars)),
            self._print(e.args[0]),
            limit_str)

    def _print_Si(self, expr):
        # scipy.special.sici returns (Si, Ci).
        return "{}({})[0]".format(
            self._module_format("scipy.special.sici"),
            self._print(expr.args[0]))

    def _print_Ci(self, expr):
        return "{}({})[1]".format(
            self._module_format("scipy.special.sici"),
            self._print(expr.args[0]))
|
|
|
# Attach the generic table-driven printers for every SciPy-specific
# function and constant (loop names remain module-level, as before).
for func in _scipy_known_functions:
    setattr(SciPyPrinter, '_print_' + func, _print_known_func)

for const in _scipy_known_constants:
    setattr(SciPyPrinter, '_print_' + const, _print_known_const)
|
|
|
|
|
# CuPy mirrors the NumPy names; only the module prefix differs.
_cupy_known_functions = {
    name: f"cupy.{target}" for name, target in _known_functions_numpy.items()
}
_cupy_known_constants = {
    name: f"cupy.{target}" for name, target in _known_constants_numpy.items()
}
|
|
|
class CuPyPrinter(NumPyPrinter):
    """
    CuPy counterpart of NumPyPrinter: same vectorized Piecewise, logical
    operator and matrix handling, but emitting ``cupy.*`` calls.
    """

    # NumPyPrinter formats its module-qualified names from ``_module``, so
    # swapping it (and the lookup tables) retargets the output at CuPy.
    _module = 'cupy'
    _kf = _cupy_known_functions
    _kc = _cupy_known_constants

    def __init__(self, settings=None):
        # No CuPy-specific setup; NumPyPrinter does all the work.
        super().__init__(settings=settings)
|
|
|
# Attach the generic table-driven printers for the cupy-prefixed tables
# (loop names remain module-level, as before).
for func in _cupy_known_functions:
    setattr(CuPyPrinter, '_print_' + func, _print_known_func)

for const in _cupy_known_constants:
    setattr(CuPyPrinter, '_print_' + const, _print_known_const)
|
|
|
|
|
# jax.numpy mirrors the NumPy names; only the module prefix differs.
_jax_known_functions = {
    name: f'jax.numpy.{target}' for name, target in _known_functions_numpy.items()
}
_jax_known_constants = {
    name: f'jax.numpy.{target}' for name, target in _known_constants_numpy.items()
}
|
|
|
class JaxPrinter(NumPyPrinter):
    """
    JAX counterpart of NumPyPrinter, emitting ``jax.numpy.*`` calls.
    Logical And/Or are printed as array reductions instead of the
    ``ufunc.reduce`` form NumPyPrinter uses.
    """

    _module = "jax.numpy"
    _kf = _jax_known_functions
    _kc = _jax_known_constants

    def __init__(self, settings=None):
        super().__init__(settings=settings)
        # NumPyPrinter derives '_jax.numpycode' from _module, which cannot be
        # a method name; use a plain identifier instead.
        self.printmethod = '_jaxcode'

    def _hprint_bool_reduction(self, reducer, expr):
        """Stack the operands into one array and apply ``reducer`` on axis 0."""
        reduce_fn = self._module_format(self._module + "." + reducer)
        module = self._module_format(self._module)
        operands = ",".join(self._print(arg) for arg in expr.args)
        return "{}({}.asarray([{}]), axis=0)".format(reduce_fn, module, operands)

    def _print_And(self, expr):
        "Logical And printer"
        return self._hprint_bool_reduction("all", expr)

    def _print_Or(self, expr):
        "Logical Or printer"
        return self._hprint_bool_reduction("any", expr)
|
|
|
# Attach the generic table-driven printers for the jax.numpy-prefixed
# tables (loop names remain module-level, as before).
for func in _jax_known_functions:
    setattr(JaxPrinter, '_print_' + func, _print_known_func)

for const in _jax_known_constants:
    setattr(JaxPrinter, '_print_' + const, _print_known_const)
|
|