|
import sympy.codegen |
|
import sympy.codegen.cfunctions |
|
from sympy.external.importtools import version_tuple |
|
from collections.abc import Iterable |
|
|
|
from sympy.core.mul import Mul |
|
from sympy.core.singleton import S |
|
from sympy.codegen.cfunctions import Sqrt |
|
from sympy.external import import_module |
|
from sympy.printing.precedence import PRECEDENCE |
|
from sympy.printing.pycode import AbstractPythonCodePrinter, ArrayPrinter |
|
import sympy |
|
|
|
# TensorFlow is an optional dependency: this binding may be falsy (presumably
# None from sympy's import_module) when TensorFlow is not importable, which is
# why TensorflowPrinter.__init__ tests its truthiness before reading
# ``tensorflow.__version__``.
tensorflow = import_module('tensorflow')
|
|
|
class TensorflowPrinter(ArrayPrinter, AbstractPythonCodePrinter):
    """
    Tensorflow printer which handles vectorized piecewise functions,
    logical operators, max/min, and relational operators.

    Elementwise functions, relationals and boolean operators are translated
    through the ``mapping`` table by ``_print_Function``; matrix expressions,
    derivatives, powers and piecewise expressions have dedicated printers.
    The ``tensorflow_version`` setting selects between legacy and current
    TensorFlow API names where they differ (``Transpose``, ``Piecewise``).
    """

    printmethod = "_tensorflowcode"

    # SymPy type -> fully qualified TensorFlow callable, looked up by exact
    # ``type(expr)`` in ``_print_Function`` and the aliases that share it.
    mapping = {
        sympy.Abs: "tensorflow.math.abs",
        sympy.sign: "tensorflow.math.sign",

        sympy.ceiling: "tensorflow.math.ceil",
        sympy.floor: "tensorflow.math.floor",
        sympy.log: "tensorflow.math.log",
        sympy.exp: "tensorflow.math.exp",
        Sqrt: "tensorflow.math.sqrt",
        sympy.cos: "tensorflow.math.cos",
        sympy.acos: "tensorflow.math.acos",
        sympy.sin: "tensorflow.math.sin",
        sympy.asin: "tensorflow.math.asin",
        sympy.tan: "tensorflow.math.tan",
        sympy.atan: "tensorflow.math.atan",
        sympy.atan2: "tensorflow.math.atan2",

        sympy.cosh: "tensorflow.math.cosh",
        sympy.acosh: "tensorflow.math.acosh",
        sympy.sinh: "tensorflow.math.sinh",
        sympy.asinh: "tensorflow.math.asinh",
        sympy.tanh: "tensorflow.math.tanh",
        sympy.atanh: "tensorflow.math.atanh",

        sympy.re: "tensorflow.math.real",
        sympy.im: "tensorflow.math.imag",
        sympy.arg: "tensorflow.math.angle",

        sympy.erf: "tensorflow.math.erf",
        sympy.loggamma: "tensorflow.math.lgamma",

        sympy.Eq: "tensorflow.math.equal",
        sympy.Ne: "tensorflow.math.not_equal",
        sympy.StrictGreaterThan: "tensorflow.math.greater",
        sympy.StrictLessThan: "tensorflow.math.less",
        sympy.LessThan: "tensorflow.math.less_equal",
        sympy.GreaterThan: "tensorflow.math.greater_equal",

        sympy.And: "tensorflow.math.logical_and",
        sympy.Or: "tensorflow.math.logical_or",
        sympy.Not: "tensorflow.math.logical_not",
        sympy.Max: "tensorflow.math.maximum",
        sympy.Min: "tensorflow.math.minimum",

        sympy.MatAdd: "tensorflow.math.add",
        sympy.HadamardProduct: "tensorflow.math.multiply",
        sympy.Trace: "tensorflow.linalg.trace",

        sympy.Determinant : "tensorflow.linalg.det",
    }

    _default_settings = dict(
        AbstractPythonCodePrinter._default_settings,
        tensorflow_version=None
    )

    def __init__(self, settings=None):
        """Create the printer.

        If the ``tensorflow_version`` setting is None and TensorFlow is
        importable, the installed ``tensorflow.__version__`` is used instead.
        The result (possibly still None) is stored on
        ``self.tensorflow_version``.
        """
        super().__init__(settings)

        version = self._settings['tensorflow_version']
        if version is None and tensorflow:
            version = tensorflow.__version__
        self.tensorflow_version = version

    def _print_Function(self, expr):
        """Print ``expr`` using its entry in ``mapping``.

        Types without an entry fall back to ``_print_Basic``.  A single
        argument prints as ``op(arg)``; several arguments are handed to the
        inherited ``_expand_fold_binary_op`` (presumably a chain of nested
        two-argument ``op`` calls).
        """
        op = self.mapping.get(type(expr), None)
        if op is None:
            return super()._print_Basic(expr)
        children = [self._print(arg) for arg in expr.args]
        if len(children) == 1:
            return "%s(%s)" % (
                self._module_format(op),
                children[0]
            )
        else:
            return self._expand_fold_binary_op(op, children)

    _print_Expr = _print_Function
    _print_Application = _print_Function
    _print_MatrixExpr = _print_Function

    _print_Relational = _print_Function
    _print_Not = _print_Function
    _print_And = _print_Function
    _print_Or = _print_Function
    _print_HadamardProduct = _print_Function
    _print_Trace = _print_Function
    _print_Determinant = _print_Function

    def _print_Inverse(self, expr):
        """Print a matrix inverse as ``tensorflow.linalg.inv(arg)``."""
        op = self._module_format('tensorflow.linalg.inv')
        return "{}({})".format(op, self._print(expr.arg))

    def _print_Transpose(self, expr):
        """Print a transpose, using ``tensorflow.matrix_transpose`` for
        TensorFlow versions before 1.14 and ``tensorflow.linalg.matrix_transpose``
        otherwise (including when the version is unknown)."""
        version = self.tensorflow_version
        if version and version_tuple(version) < version_tuple('1.14'):
            op = self._module_format('tensorflow.matrix_transpose')
        else:
            op = self._module_format('tensorflow.linalg.matrix_transpose')
        return "{}({})".format(op, self._print(expr.arg))

    def _print_Derivative(self, expr):
        """Print a derivative as nested ``tensorflow.gradients(f, x)[0]`` calls.

        One ``gradients`` call is emitted per entry of ``expr.variables``,
        innermost first.

        Raises:
            NotImplementedError: if any differentiation variable is iterable.
        """
        variables = expr.variables
        if any(isinstance(i, Iterable) for i in variables):
            raise NotImplementedError("derivation by multiple variables is not supported")

        def unfold(expr, args):
            # Peel variables off the end so the first one is applied innermost.
            if not args:
                return self._print(expr)
            return "%s(%s, %s)[0]" % (
                self._module_format("tensorflow.gradients"),
                unfold(expr, args[:-1]),
                self._print(args[-1]),
            )
        return unfold(expr.expr, variables)

    def _print_Piecewise(self, expr):
        """Print a Piecewise as nested ``tensorflow.where`` calls.

        ``tensorflow.select`` is used instead for versions before 1.0.  The
        first (expression, condition) pair becomes the outermost call; a
        lone remaining pair uses ``0`` as the value where its condition is
        false.
        """
        version = self.tensorflow_version
        if version and version_tuple(version) < version_tuple('1.0'):
            tensorflow_piecewise = "tensorflow.select"
        else:
            tensorflow_piecewise = "tensorflow.where"

        from sympy.functions.elementary.piecewise import Piecewise
        e, cond = expr.args[0].args
        if len(expr.args) == 1:
            return '{}({}, {}, {})'.format(
                self._module_format(tensorflow_piecewise),
                self._print(cond),
                self._print(e),
                0)

        return '{}({}, {}, {})'.format(
            self._module_format(tensorflow_piecewise),
            self._print(cond),
            self._print(e),
            self._print(Piecewise(*expr.args[1:])))

    def _print_Pow(self, expr):
        """Print ``base**(1/2)`` as ``tensorflow.math.sqrt(base)`` and any
        other power as ``tensorflow.math.pow(base, exp)``."""
        base, exp = expr.args
        if expr.exp == S.Half:
            return "{}({})".format(
                self._module_format("tensorflow.math.sqrt"), self._print(base))
        return "{}({}, {})".format(
            self._module_format("tensorflow.math.pow"),
            self._print(base), self._print(exp))

    def _print_MatrixBase(self, expr):
        """Print an explicit matrix as a nested-list literal.

        The literal is wrapped in ``tensorflow.Variable`` when the matrix has
        free symbols and in ``tensorflow.constant`` otherwise.
        """
        tensorflow_f = "tensorflow.Variable" if expr.free_symbols else "tensorflow.constant"
        data = "["+", ".join(["["+", ".join([self._print(j) for j in i])+"]" for i in expr.tolist()])+"]"
        return "%s(%s)" % (
            self._module_format(tensorflow_f),
            data,
        )

    def _print_MatMul(self, expr):
        """Print a matrix product as folded ``tensorflow.linalg.matmul`` calls.

        Scalar (non-MatrixExpr) factors are multiplied together and prefixed
        as ``scalar*matmul(...)``.
        """
        from sympy.matrices.expressions import MatrixExpr
        # Partition by type directly instead of an O(n**2) membership scan.
        mat_args = [arg for arg in expr.args if isinstance(arg, MatrixExpr)]
        args = [arg for arg in expr.args if not isinstance(arg, MatrixExpr)]
        matmul = self._expand_fold_binary_op(
            "tensorflow.linalg.matmul", mat_args)
        if not args:
            return matmul
        return "%s*%s" % (
            self.parenthesize(Mul.fromiter(args), PRECEDENCE["Mul"]),
            matmul,
        )

    def _print_MatPow(self, expr):
        """Print ``base**n`` as ``n`` folded ``tensorflow.linalg.matmul`` calls.

        Raises:
            NotImplementedError: if the exponent is not a positive integer
                (symbolic, zero, negative or non-integer powers would
                otherwise fail with an opaque error).
        """
        exp = expr.exp
        if not (exp.is_Integer and exp.is_positive):
            raise NotImplementedError(
                "only positive integer matrix powers are supported, got %s"
                % exp)
        return self._expand_fold_binary_op(
            "tensorflow.linalg.matmul", [expr.base]*int(exp))

    def _print_CodeBlock(self, expr):
        """Print each statement of a CodeBlock on its own line."""
        ret = []
        for subexpr in expr.args:
            ret.append(self._print(subexpr))
        return "\n".join(ret)

    def _print_isnan(self, exp):
        """Print ``isnan(x)`` as ``tensorflow.math.is_nan(x)``."""
        # Go through _module_format like every other printer so the import is
        # registered and ``fully_qualified_modules`` is honoured.
        return '{}({})'.format(
            self._module_format('tensorflow.math.is_nan'),
            self._print(*exp.args))

    def _print_isinf(self, exp):
        """Print ``isinf(x)`` as ``tensorflow.math.is_inf(x)``."""
        return '{}({})'.format(
            self._module_format('tensorflow.math.is_inf'),
            self._print(*exp.args))

    # Module / function names presumably consumed by the inherited
    # ArrayPrinter when printing tensor operations.
    _module = "tensorflow"
    _einsum = "linalg.einsum"
    _add = "math.add"
    _transpose = "transpose"
    _ones = "ones"
    _zeros = "zeros"
|
|
|
|
|
def tensorflow_code(expr, **settings):
    """Return TensorFlow source code for the SymPy expression ``expr``.

    Any keyword ``settings`` (e.g. ``tensorflow_version``) are forwarded as a
    dict to :class:`TensorflowPrinter`.
    """
    return TensorflowPrinter(settings).doprint(expr)
|
|