|
|
|
from sympy.printing.pycode import AbstractPythonCodePrinter, ArrayPrinter |
|
from sympy.matrices.expressions import MatrixExpr |
|
from sympy.core.mul import Mul |
|
from sympy.printing.precedence import PRECEDENCE |
|
from sympy.external import import_module |
|
from sympy.codegen.cfunctions import Sqrt |
|
from sympy import S |
|
from sympy import Integer |
|
|
|
import sympy |
|
|
|
# Optional PyTorch dependency resolved through SymPy's import_module; the
# printer only checks its truthiness, presumably because it is None when
# PyTorch is not installed (verify against sympy.external.import_module).
torch = import_module('torch')
|
|
|
|
|
class TorchPrinter(ArrayPrinter, AbstractPythonCodePrinter):
    """Code printer that renders SymPy expressions as PyTorch source code.

    Elementwise functions, relationals, logical operators and several
    matrix operations are translated through ``mapping``; the remaining
    constructs have dedicated ``_print_*`` methods below.

    Settings (in addition to those of ``AbstractPythonCodePrinter``):

    torch_version : str or None
        Stored on ``self.torch_version``.  When ``None`` and PyTorch is
        importable, ``torch.__version__`` is used instead.
    requires_grad : bool
        When true, printed ``torch.tensor(...)`` literals receive
        ``requires_grad=True``.
    dtype : str
        Source text inserted verbatim as the ``dtype=`` argument of printed
        ``torch.tensor(...)`` literals.
    """

    printmethod = "_torchcode"

    # Exact SymPy type -> fully qualified PyTorch callable.  Looked up by
    # _print_Function (and the _print_* aliases bound to it) using
    # ``type(expr)``, so only these exact classes are matched.
    mapping = {
        # Elementwise functions.
        sympy.Abs: "torch.abs",
        sympy.sign: "torch.sign",
        sympy.ceiling: "torch.ceil",
        sympy.floor: "torch.floor",
        sympy.log: "torch.log",
        sympy.exp: "torch.exp",
        Sqrt: "torch.sqrt",
        sympy.cos: "torch.cos",
        sympy.acos: "torch.acos",
        sympy.sin: "torch.sin",
        sympy.asin: "torch.asin",
        sympy.tan: "torch.tan",
        sympy.atan: "torch.atan",
        sympy.atan2: "torch.atan2",
        sympy.cosh: "torch.cosh",
        sympy.acosh: "torch.acosh",
        sympy.sinh: "torch.sinh",
        sympy.asinh: "torch.asinh",
        sympy.tanh: "torch.tanh",
        sympy.atanh: "torch.atanh",
        # Pow also has a dedicated _print_Pow method below.
        sympy.Pow: "torch.pow",

        # Complex-number parts.
        sympy.re: "torch.real",
        sympy.im: "torch.imag",
        sympy.arg: "torch.angle",

        # Special functions.
        sympy.erf: "torch.erf",
        sympy.loggamma: "torch.lgamma",

        # Relationals.
        sympy.Eq: "torch.eq",
        sympy.Ne: "torch.ne",
        sympy.StrictGreaterThan: "torch.gt",
        sympy.StrictLessThan: "torch.lt",
        sympy.LessThan: "torch.le",
        sympy.GreaterThan: "torch.ge",

        # Logical operators and extrema.  Multi-argument forms are handed to
        # _expand_fold_binary_op by _print_Function.
        # NOTE(review): two-argument torch.max/torch.min expect tensors;
        # confirm behaviour when an argument prints as a Python scalar.
        sympy.And: "torch.logical_and",
        sympy.Or: "torch.logical_or",
        sympy.Not: "torch.logical_not",
        sympy.Max: "torch.max",
        sympy.Min: "torch.min",

        # Matrix expressions.
        sympy.MatAdd: "torch.add",
        sympy.HadamardProduct: "torch.mul",
        sympy.Trace: "torch.trace",
        sympy.Determinant: "torch.det",
    }

    _default_settings = dict(
        AbstractPythonCodePrinter._default_settings,
        torch_version=None,
        requires_grad=False,
        dtype="torch.float64",
    )

    def __init__(self, settings=None):
        super().__init__(settings)

        version = self._settings['torch_version']
        self.requires_grad = self._settings['requires_grad']
        self.dtype = self._settings['dtype']
        # Fall back to the installed PyTorch version when none was given.
        if version is None and torch:
            version = torch.__version__
        self.torch_version = version

    def _print_Function(self, expr):
        """Print ``expr`` as a call to its ``mapping`` entry.

        Single-argument expressions become ``op(arg)``.  Expressions with
        several arguments are passed to ``_expand_fold_binary_op``, a base
        class helper that presumably nests them as binary calls.  Types that
        are missing from ``mapping`` fall back to the parent's
        ``_print_Basic``.
        """
        op = self.mapping.get(type(expr))
        if op is None:
            return super()._print_Basic(expr)
        children = [self._print(arg) for arg in expr.args]
        if len(children) == 1:
            return "%s(%s)" % (self._module_format(op), children[0])
        return self._expand_fold_binary_op(op, children)

    # Route these expression families through the mapping-based printer.
    _print_Expr = _print_Function
    _print_Application = _print_Function
    _print_MatrixExpr = _print_Function
    _print_Relational = _print_Function
    _print_Not = _print_Function
    _print_And = _print_Function
    _print_Or = _print_Function
    _print_HadamardProduct = _print_Function
    _print_Trace = _print_Function
    _print_Determinant = _print_Function

    def _print_Inverse(self, expr):
        """Print a matrix inverse as ``torch.linalg.inv(arg)``."""
        return '{}({})'.format(self._module_format("torch.linalg.inv"),
                               self._print(expr.args[0]))

    def _print_Transpose(self, expr):
        """Print a transpose.

        For a 2-D argument the output is ``torch.transpose(arg, 0, 1)``.
        Any other rank falls back to a ``permute`` that reverses all axes.
        """
        printed = self._print(expr.arg)
        ndim = len(expr.arg.shape)
        if ndim == 2:
            # torch.transpose requires both dimension indices.  Putting the
            # argument inside the call also avoids precedence problems that a
            # trailing ``.t()`` has with printed products such as ``2*A``.
            return "{}({}, 0, 1)".format(
                self._module_format("torch.transpose"), printed)
        # Build the reversed permutation from integers.  Reversing the joined
        # string would mangle multi-digit axes (e.g. "10" -> "01").
        dims = ", ".join(str(i) for i in reversed(range(ndim)))
        return "({}).permute({})".format(printed, dims)

    def _print_PermuteDims(self, expr):
        """Print an axis permutation as ``expr.permute(i0, i1, ...)``."""
        return "%s.permute(%s)" % (
            self._print(expr.expr),
            ", ".join(str(i) for i in expr.permutation.array_form)
        )

    def _print_Derivative(self, expr):
        """Print a derivative as chained ``torch.autograd.grad`` calls.

        A single variable of order one yields ``torch.autograd.grad(f, x)[0]``.
        Higher or mixed orders apply one grad call per differentiation,
        grouped per variable in first-appearance order, with
        ``create_graph=True`` so that the next grad can differentiate the
        previous result.
        """
        grad = self._module_format("torch.autograd.grad")
        # ``expr.variables`` lists each variable once per order of
        # differentiation, e.g. (x, x, y) for d^3/dx^2dy.
        variables = expr.variables
        result = self._print(expr.expr)
        if not variables:
            return result
        if len(variables) == 1:
            return "{}({}, {})[0]".format(grad, result,
                                          self._print(variables[0]))

        orders = {}
        for var in variables:
            orders[var] = orders.get(var, 0) + 1
        for var, order in orders.items():
            printed_var = self._print(var)
            for _ in range(order):
                result = "{}({}, {}, create_graph=True)[0]".format(
                    grad, result, printed_var)
        return result

    def _print_Piecewise(self, expr):
        """Print a Piecewise as nested ``torch.where`` calls.

        Conditions are tested in order.  When the last remaining clause's
        condition is false, the value 0 is produced.  A clause whose
        condition is literally ``True`` is printed directly, because later
        clauses can never be reached.
        """
        from sympy import Piecewise
        e, cond = expr.args[0].args
        if cond is S.true:
            # torch.where needs a tensor condition, not a Python bool.
            return self._print(e)
        if len(expr.args) == 1:
            return '{}({}, {}, {})'.format(
                self._module_format("torch.where"),
                self._print(cond),
                self._print(e),
                0)

        return '{}({}, {}, {})'.format(
            self._module_format("torch.where"),
            self._print(cond),
            self._print(e),
            self._print(Piecewise(*expr.args[1:])))

    def _print_Pow(self, expr):
        """Print ``base**exp`` as ``torch.sqrt`` for exponent 1/2, otherwise as
        ``torch.pow(base, exp)``."""
        base, exponent = expr.args
        if exponent == S.Half:
            return "{}({})".format(
                self._module_format("torch.sqrt"), self._print(base))
        return "{}({}, {})".format(
            self._module_format("torch.pow"),
            self._print(base), self._print(exponent))

    def _print_MatMul(self, expr):
        """Print a matrix product as folded ``torch.matmul`` calls, multiplied
        on the left by the product of any non-matrix factors."""
        mat_args, scalar_args = [], []
        for arg in expr.args:
            (mat_args if isinstance(arg, MatrixExpr) else scalar_args).append(arg)

        if scalar_args:
            return "%s*%s" % (
                self.parenthesize(Mul.fromiter(scalar_args), PRECEDENCE["Mul"]),
                self._expand_fold_binary_op("torch.matmul", mat_args)
            )
        return self._expand_fold_binary_op("torch.matmul", mat_args)

    def _print_MatPow(self, expr):
        """Print a matrix power.

        A positive integer literal exponent is expanded into repeated
        ``torch.mm`` products.  Any other exponent (zero, negative or
        symbolic) uses ``torch.linalg.matrix_power(base, exp)``, because an
        empty or non-integer repetition cannot be expanded.
        """
        base, power = expr.base, expr.exp
        if power.is_Integer and power.is_positive:
            return self._expand_fold_binary_op("torch.mm", [base] * int(power))
        return "{}({}, {})".format(
            self._module_format("torch.linalg.matrix_power"),
            self._print(base), self._print(power))

    def _print_MatrixBase(self, expr):
        """Print an explicit matrix as a nested-list ``torch.tensor`` literal,
        using the configured ``dtype`` and ``requires_grad`` settings."""
        data = "[" + ", ".join(["[" + ", ".join([self._print(j) for j in i]) + "]" for i in expr.tolist()]) + "]"
        params = [data, f"dtype={self.dtype}"]
        if self.requires_grad:
            params.append("requires_grad=True")

        return "{}({})".format(
            self._module_format("torch.tensor"),
            ", ".join(params)
        )

    def _print_isnan(self, expr):
        """Print ``isnan(x)`` as ``torch.isnan(x)``."""
        return f"{self._module_format('torch.isnan')}({self._print(expr.args[0])})"

    def _print_isinf(self, expr):
        """Print ``isinf(x)`` as ``torch.isinf(x)``."""
        return f"{self._module_format('torch.isinf')}({self._print(expr.args[0])})"

    def _print_Identity(self, expr):
        """Print an identity matrix as ``torch.eye``.

        When every dimension is an integer literal, only the first dimension
        is passed.  Otherwise both dimensions are passed.
        """
        if all(dim.is_Integer for dim in expr.shape):
            return "{}({})".format(
                self._module_format("torch.eye"),
                self._print(expr.shape[0])
            )
        return "{}({}, {})".format(
            self._module_format("torch.eye"),
            self._print(expr.shape[0]),
            self._print(expr.shape[1])
        )

    def _print_ZeroMatrix(self, expr):
        """Print a zero matrix as ``torch.zeros(shape)``."""
        return "{}({})".format(
            self._module_format("torch.zeros"),
            self._print(expr.shape)
        )

    def _print_OneMatrix(self, expr):
        """Print a matrix of ones as ``torch.ones(shape)``."""
        return "{}({})".format(
            self._module_format("torch.ones"),
            self._print(expr.shape)
        )

    def _print_conjugate(self, expr):
        """Print a complex conjugate as ``torch.conj(x)``."""
        return f"{self._module_format('torch.conj')}({self._print(expr.args[0])})"

    def _print_ImaginaryUnit(self, expr):
        """Print the imaginary unit as the Python literal ``1j``."""
        return "1j"

    def _print_Heaviside(self, expr):
        """Print ``Heaviside(x[, H0])`` as ``torch.heaviside(x, H0)``, using
        ``0.5`` when no H0 argument is present."""
        # NOTE(review): torch.heaviside documents ``values`` as a Tensor.
        # A printed Python number here (e.g. 0.5 or 1/2) may be rejected at
        # runtime; confirm this and consider wrapping it in a tensor.
        args = [self._print(expr.args[0]), "0.5"]
        if len(expr.args) > 1:
            args[1] = self._print(expr.args[1])
        return f"{self._module_format('torch.heaviside')}({args[0]}, {args[1]})"

    def _print_gamma(self, expr):
        """Print ``gamma(x)`` as ``torch.special.gamma(x)``."""
        # NOTE(review): confirm that torch.special.gamma exists in the targeted
        # PyTorch version.  torch.special documents gammaln (log-gamma); a
        # plain gamma function should be verified.
        return f"{self._module_format('torch.special.gamma')}({self._print(expr.args[0])})"

    def _print_polygamma(self, expr):
        """Print ``polygamma(n, x)``.

        Order 0 maps to ``torch.special.digamma(x)``.  A positive integer
        literal order maps to ``torch.special.polygamma(n, x)``.

        Raises
        ------
        NotImplementedError
            If the order is not a non-negative integer literal.
        """
        order, arg = expr.args
        if order == S.Zero:
            return f"{self._module_format('torch.special.digamma')}({self._print(arg)})"
        if order.is_Integer and order.is_positive:
            return "{}({}, {})".format(
                self._module_format('torch.special.polygamma'),
                self._print(order), self._print(arg))
        raise NotImplementedError(
            "PyTorch only supports polygamma of non-negative integer order")

    # Names read by the ArrayPrinter base class, presumably to build
    # torch.einsum / torch.add / transpose / ones / zeros calls for array
    # expressions (verify against ArrayPrinter).
    _module = "torch"
    _einsum = "einsum"
    _add = "add"
    _transpose = "t"
    _ones = "ones"
    _zeros = "zeros"
|
|
|
def torch_code(expr, requires_grad=False, dtype="torch.float64", **settings):
    """Return PyTorch source code for the SymPy expression ``expr``.

    Parameters
    ----------
    expr : sympy.Basic
        Expression to print.
    requires_grad : bool, optional
        Forwarded to ``TorchPrinter``.  When true, printed ``torch.tensor``
        literals receive ``requires_grad=True``.
    dtype : str, optional
        Source text used as the ``dtype=`` argument of printed tensors.
    **settings
        ``assign_to`` is forwarded to ``doprint``.  Every other key is a
        printer setting passed to the ``TorchPrinter`` constructor
        (e.g. ``fully_qualified_modules`` or ``torch_version``).
    """
    # doprint() only accepts ``assign_to``.  Forwarding every keyword there
    # made any real printer setting raise TypeError, so route the others to
    # the printer constructor instead.
    assign_to = settings.pop('assign_to', None)
    printer = TorchPrinter(
        settings=dict(settings, requires_grad=requires_grad, dtype=dtype))
    return printer.doprint(expr, assign_to=assign_to)
|
|