from sympy.printing.pycode import AbstractPythonCodePrinter, ArrayPrinter
from sympy.matrices.expressions import MatrixExpr
from sympy.core.mul import Mul
from sympy.printing.precedence import PRECEDENCE
from sympy.external import import_module
from sympy.codegen.cfunctions import Sqrt
from sympy import S, Integer
import sympy
torch = import_module('torch')
class TorchPrinter(ArrayPrinter, AbstractPythonCodePrinter):
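    """SymPy expression printer that emits PyTorch code.

    Mirrors the layout of sympy.printing.tensorflow.TensorflowPrinter.
    """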
printmethod = "_torchcode"
mapping = {
sympy.Abs: "torch.abs",
sympy.sign: "torch.sign",
# XXX May raise error for ints.
sympy.ceiling: "torch.ceil",
sympy.floor: "torch.floor",
sympy.log: "torch.log",
sympy.exp: "torch.exp",
Sqrt: "torch.sqrt",
sympy.cos: "torch.cos",
sympy.acos: "torch.acos",
sympy.sin: "torch.sin",
sympy.asin: "torch.asin",
sympy.tan: "torch.tan",
sympy.atan: "torch.atan",
sympy.atan2: "torch.atan2",
# XXX Also may give NaN for complex results.
sympy.cosh: "torch.cosh",
sympy.acosh: "torch.acosh",
sympy.sinh: "torch.sinh",
sympy.asinh: "torch.asinh",
sympy.tanh: "torch.tanh",
sympy.atanh: "torch.atanh",
sympy.Pow: "torch.pow",
sympy.re: "torch.real",
sympy.im: "torch.imag",
sympy.arg: "torch.angle",
# XXX May raise error for ints and complexes
sympy.erf: "torch.erf",
sympy.loggamma: "torch.lgamma",
sympy.Eq: "torch.eq",
sympy.Ne: "torch.ne",
sympy.StrictGreaterThan: "torch.gt",
sympy.StrictLessThan: "torch.lt",
sympy.LessThan: "torch.le",
sympy.GreaterThan: "torch.ge",
sympy.And: "torch.logical_and",
sympy.Or: "torch.logical_or",
sympy.Not: "torch.logical_not",
sympy.Max: "torch.max",
sympy.Min: "torch.min",
# Matrices
sympy.MatAdd: "torch.add",
sympy.HadamardProduct: "torch.mul",
sympy.Trace: "torch.trace",
# XXX May raise error for integer matrices.
sympy.Determinant: "torch.det",
}
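    # With the table above, e.g. sympy.exp(x) prints as "torch.exp(x)" and
    # sympy.Eq(x, y) as "torch.eq(x, y)" via _print_Function below.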
_default_settings = dict(
AbstractPythonCodePrinter._default_settings,
torch_version=None,
requires_grad=False,
dtype="torch.float64",
)
def __init__(self, settings=None):
super().__init__(settings)
version = self._settings['torch_version']
self.requires_grad = self._settings['requires_grad']
self.dtype = self._settings['dtype']
if version is None and torch:
version = torch.__version__
self.torch_version = version
def _print_Function(self, expr):
op = self.mapping.get(type(expr), None)
if op is None:
return super()._print_Basic(expr)
children = [self._print(arg) for arg in expr.args]
if len(children) == 1:
return "%s(%s)" % (
self._module_format(op),
children[0]
)
else:
return self._expand_fold_binary_op(op, children)
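    # n-ary applications fold left-to-right into nested binary calls, e.g.
    # Max(a, b, c) prints roughly as torch.max(torch.max(a, b), c).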
# mirrors the tensorflow version
_print_Expr = _print_Function
_print_Application = _print_Function
_print_MatrixExpr = _print_Function
_print_Relational = _print_Function
_print_Not = _print_Function
_print_And = _print_Function
_print_Or = _print_Function
_print_HadamardProduct = _print_Function
_print_Trace = _print_Function
_print_Determinant = _print_Function
def _print_Inverse(self, expr):
return '{}({})'.format(self._module_format("torch.linalg.inv"),
self._print(expr.args[0]))
def _print_Transpose(self, expr):
        if expr.arg.is_Matrix and expr.arg.shape[0] == expr.arg.shape[1]:
            # For square matrices, we can use the .t() method
            return "{}.t()".format(self._print(expr.arg))
        else:
            # For non-square matrices or more general cases, reverse the
            # dimension order (the usual matrix transpose for rank-2 arrays)
            return "{}.permute({})".format(
                self._print(expr.arg),
                ", ".join(str(i) for i in reversed(range(len(expr.arg.shape))))
            )
def _print_PermuteDims(self, expr):
return "%s.permute(%s)" % (
self._print(expr.expr),
", ".join(str(i) for i in expr.permutation.array_form)
)
def _print_Derivative(self, expr):
        # Unlike the tensorflow printer, this version handles multi-variable
        # and mixed partial derivatives.
variables = expr.variables
expr_arg = expr.expr
        # Handle multi-variable or repeated derivatives. Note that
        # expr.variables flattens repeated variables, so Derivative(f, x, x)
        # reaches this branch as (x, x).
        if len(variables) > 1:
result = self._print(expr_arg)
var_groups = {}
# Group variables by base symbol
for var in variables:
if isinstance(var, tuple):
base_var, order = var
var_groups[base_var] = var_groups.get(base_var, 0) + order
else:
var_groups[var] = var_groups.get(var, 0) + 1
# Apply gradients in sequence
for var, order in var_groups.items():
for _ in range(order):
result = "torch.autograd.grad({}, {}, create_graph=True)[0]".format(result, self._print(var))
return result
# Handle single variable case
if len(variables) == 1:
variable = variables[0]
if isinstance(variable, tuple) and len(variable) == 2:
base_var, order = variable
                if not isinstance(order, Integer):
                    raise NotImplementedError("Only integer orders are supported")
result = self._print(expr_arg)
for _ in range(order):
result = "torch.autograd.grad({}, {}, create_graph=True)[0]".format(result, self._print(base_var))
return result
return "torch.autograd.grad({}, {})[0]".format(self._print(expr_arg), self._print(variable))
return self._print(expr_arg) # Empty variables case
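    # For example, Derivative(f(x), x, x) prints roughly as
    # torch.autograd.grad(torch.autograd.grad(f(x), x, create_graph=True)[0],
    #                     x, create_graph=True)[0]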
def _print_Piecewise(self, expr):
from sympy import Piecewise
e, cond = expr.args[0].args
if len(expr.args) == 1:
return '{}({}, {}, {})'.format(
self._module_format("torch.where"),
self._print(cond),
self._print(e),
0)
return '{}({}, {}, {})'.format(
self._module_format("torch.where"),
self._print(cond),
self._print(e),
self._print(Piecewise(*expr.args[1:])))
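    # A two-piece Piecewise((a, x > 0), (b, x <= 0)) therefore prints roughly
    # as torch.where(torch.gt(x, 0), a, torch.where(torch.le(x, 0), b, 0)).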
    def _print_Pow(self, expr):
        # XXX May raise error for
        # int**float or int**complex or float**complex
        base, exp = expr.args
        if exp == S.Half:
            return "{}({})".format(
                self._module_format("torch.sqrt"), self._print(base))
        return "{}({}, {})".format(
            self._module_format("torch.pow"),
            self._print(base), self._print(exp))
def _print_MatMul(self, expr):
# Separate matrix and scalar arguments
mat_args = [arg for arg in expr.args if isinstance(arg, MatrixExpr)]
args = [arg for arg in expr.args if arg not in mat_args]
# Handle scalar multipliers if present
if args:
return "%s*%s" % (
self.parenthesize(Mul.fromiter(args), PRECEDENCE["Mul"]),
self._expand_fold_binary_op("torch.matmul", mat_args)
)
else:
return self._expand_fold_binary_op("torch.matmul", mat_args)
    def _print_MatPow(self, expr):
        # Repeated torch.mm only makes sense for positive integer powers.
        if not (expr.exp.is_Integer and expr.exp.is_positive):
            raise NotImplementedError(
                "Only positive integer matrix powers are supported")
        return self._expand_fold_binary_op("torch.mm", [expr.base]*int(expr.exp))
def _print_MatrixBase(self, expr):
        data = "[" + ", ".join(
            "[" + ", ".join(self._print(j) for j in row) + "]"
            for row in expr.tolist()) + "]"
        params = [data]
params.append(f"dtype={self.dtype}")
if self.requires_grad:
params.append("requires_grad=True")
return "{}({})".format(
self._module_format("torch.tensor"),
", ".join(params)
)
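    # e.g. Matrix([[1, 2], [3, 4]]) prints as
    # "torch.tensor([[1, 2], [3, 4]], dtype=torch.float64)", plus
    # requires_grad=True when the printer was constructed with that setting.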
def _print_isnan(self, expr):
return f'torch.isnan({self._print(expr.args[0])})'
def _print_isinf(self, expr):
return f'torch.isinf({self._print(expr.args[0])})'
def _print_Identity(self, expr):
if all(dim.is_Integer for dim in expr.shape):
return "{}({})".format(
self._module_format("torch.eye"),
self._print(expr.shape[0])
)
else:
# For symbolic dimensions, fall back to a more general approach
return "{}({}, {})".format(
self._module_format("torch.eye"),
self._print(expr.shape[0]),
self._print(expr.shape[1])
)
def _print_ZeroMatrix(self, expr):
return "{}({})".format(
self._module_format("torch.zeros"),
self._print(expr.shape)
)
def _print_OneMatrix(self, expr):
return "{}({})".format(
self._module_format("torch.ones"),
self._print(expr.shape)
)
def _print_conjugate(self, expr):
return f"{self._module_format('torch.conj')}({self._print(expr.args[0])})"
def _print_ImaginaryUnit(self, expr):
return "1j" # uses the Python built-in 1j notation for the imaginary unit
    def _print_Heaviside(self, expr):
        # torch.heaviside requires the value at zero to be a tensor, so wrap
        # the default 0.5 (and any explicit second argument) in torch.tensor.
        arg = self._print(expr.args[0])
        if len(expr.args) > 1:
            zero_val = "torch.tensor({})".format(self._print(expr.args[1]))
        else:
            zero_val = "torch.tensor(0.5)"
        return f"{self._module_format('torch.heaviside')}({arg}, {zero_val})"
    def _print_gamma(self, expr):
        # XXX torch has no direct gamma function; exp(lgamma(x)) matches
        # gamma(x) only where gamma(x) > 0.
        arg = self._print(expr.args[0])
        return "{}({}({}))".format(self._module_format("torch.exp"),
                                   self._module_format("torch.lgamma"), arg)
    def _print_polygamma(self, expr):
        n, x = expr.args
        if n == S.Zero:
            return f"{self._module_format('torch.special.digamma')}({self._print(x)})"
        if n.is_Integer and n.is_nonnegative:
            # torch.special.polygamma takes the order as a plain Python int
            return f"{self._module_format('torch.special.polygamma')}({int(n)}, {self._print(x)})"
        raise NotImplementedError(
            "polygamma is only supported for non-negative integer orders")
_module = "torch"
_einsum = "einsum"
_add = "add"
_transpose = "t"
_ones = "ones"
_zeros = "zeros"
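    # These names are consumed by the inherited ArrayPrinter machinery to
    # emit torch.einsum-based code for array tensor products and contractions.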
def torch_code(expr, requires_grad=False, dtype="torch.float64", **settings):
    """Convert a SymPy expression to a string of PyTorch code."""
    printer = TorchPrinter(
        settings=dict(settings, requires_grad=requires_grad, dtype=dtype))
    return printer.doprint(expr)
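

# Minimal usage sketch (not part of the printer): torch_code returns source
# text only, so evaluating the result needs a namespace that provides `torch`
# and tensor bindings for the free symbols.
if __name__ == "__main__":
    from sympy import symbols, sin, exp
    x, y = symbols("x y")
    print(torch_code(sin(x)*exp(y)))  # e.g. "torch.exp(y)*torch.sin(x)"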