|
from sympy.core import (S, pi, oo, symbols, Rational, Integer, |
|
GoldenRatio, EulerGamma, Catalan, Lambda, Dummy, |
|
Eq, Ne, Le, Lt, Gt, Ge, Mod) |
|
from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt, |
|
sign, floor) |
|
from sympy.logic import ITE |
|
from sympy.testing.pytest import raises |
|
from sympy.utilities.lambdify import implemented_function |
|
from sympy.tensor import IndexedBase, Idx |
|
from sympy.matrices import MatrixSymbol, SparseMatrix, Matrix |
|
|
|
from sympy.printing.codeprinter import rust_code |
|
|
|
# Symbols shared by the tests in this module:
# x, y, z are created with integer=False, real=True; k, m, n with integer=True.
x, y, z = symbols('x,y,z', integer=False, real=True)
k, m, n = symbols('k,m,n', integer=True)
|
|
|
|
|
def test_Integer():
    """Positive and negative Integer values print as plain decimal literals."""
    cases = (
        (42, "42"),
        (-56, "-56"),
    )
    for value, expected in cases:
        assert rust_code(Integer(value)) == expected
|
|
|
|
|
def test_Relational():
    """Each relational class applied to (x, y) prints with its infix operator."""
    cases = (
        (Eq, "x == y"),
        (Ne, "x != y"),
        (Le, "x <= y"),
        (Lt, "x < y"),
        (Gt, "x > y"),
        (Ge, "x >= y"),
    )
    for relation, expected in cases:
        assert rust_code(relation(x, y)) == expected
|
|
|
|
|
def test_Rational():
    """Check the printed form of Rational values, alone and inside Add/Mul."""
    three_sevenths = Rational(3, 7)
    cases = (
        (three_sevenths, "3_f64/7.0"),
        # 18/9 is expected to print as the bare integer "2".
        (Rational(18, 9), "2"),
        (Rational(3, -7), "-3_f64/7.0"),
        (Rational(-3, -7), "3_f64/7.0"),
        (x + three_sevenths, "x + 3_f64/7.0"),
        (three_sevenths*x, "(3_f64/7.0)*x"),
    )
    for expr, expected in cases:
        assert rust_code(expr) == expected
|
|
|
|
|
def test_basic_ops():
    """Check printing of elementary arithmetic on the module-level symbols."""
    cases = (
        (x + y, "x + y"),
        (x - y, "x - y"),
        (x * y, "x*y"),
        (x / y, "x*y.recip()"),
        (-x, "-x"),
        # Integer literals next to the real symbols are expected as floats.
        (2 * x, "2.0*x"),
        (y + 2, "y + 2.0"),
        # The integer symbol n is expected to be cast with `as f64`.
        (x + n, "n as f64 + x"),
    )
    for expr, expected in cases:
        assert rust_code(expr) == expected
|
|
|
def test_printmethod():
    """Check a class-level `_rust_code` hook and MatrixSymbol element printing.

    The local Abs subclass supplies its own `_rust_code` method and the
    printed result is expected to come from it.  Element (0, 0) of a 1x3
    MatrixSymbol named "a" is expected to print as 'a[0]'.
    """
    class fabs(Abs):
        def _rust_code(self, printer):
            operand = printer._print(self.args[0])
            return "%s.fabs()" % operand

    assert rust_code(fabs(x)) == "x.fabs()"

    row = MatrixSymbol("a", 1, 3)
    first_entry = row[0, 0]
    assert rust_code(first_entry) == 'a[0]'
|
|
|
|
|
def test_Functions():
    """Elementary functions print as method calls on their argument."""
    cases = (
        (sin(x) ** cos(x), "x.sin().powf(x.cos())"),
        (abs(x), "x.abs()"),
        (ceiling(x), "x.ceil()"),
        (floor(x), "x.floor()"),
        # Mod is expected to be printed as an expression built on floor().
        (Mod(x, 3), 'x - 3.0*((1_f64/3.0)*x).floor()'),
    )
    for expr, expected in cases:
        assert rust_code(expr) == expected
|
|
|
|
|
def test_Pow():
    """Check how powers are printed: recip, sqrt, cbrt, exp2, exp, powi, powf.

    Also checks that a `user_functions` entry for 'Pow' given as a list of
    (condition, name, ...) tuples selects the name by its condition.
    """
    # Groups of expressions that must all print to the same string.
    shared_output = (
        ("x.recip()", (1/x, x**-1, x**-1.0)),
        ("x.sqrt()", (sqrt(x), x**S.Half, x**0.5)),
        ("x.sqrt().recip()", (1/sqrt(x), x**-S.Half, x**-0.5)),
        ("PI.recip()", (1/pi, pi**-1, pi**-1.0)),
        ("PI.sqrt().recip()", (pi**-0.5,)),
    )
    for expected, exprs in shared_output:
        for expr in exprs:
            assert rust_code(expr) == expected

    single_output = (
        (x**Rational(1, 3), "x.cbrt()"),
        (2**x, "x.exp2()"),
        (exp(x), "x.exp()"),
        (x**3, "x.powi(3)"),
        (x**(y**3), "x.powf(y.powi(3))"),
        (x**Rational(2, 3), "x.powf(2_f64/3.0)"),
    )
    for expr, expected in single_output:
        assert rust_code(expr) == expected

    # A power whose base contains an implemented function.
    g = implemented_function('g', Lambda(x, 2*x))
    nested = 1/(g(x)*3.5)**(x - y**x)/(x**2 + y)
    nested_code = "(3.5*2.0*x).powf(-x + y.powf(x))/(x.powi(2) + y)"
    assert rust_code(nested) == nested_code

    # User-supplied Pow printing chosen by whether the exponent is an integer.
    _cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi", 1),
                   (lambda base, exp: not exp.is_integer, "pow", 1)]
    overrides = {'Pow': _cond_cfunc}
    assert rust_code(x**3, user_functions=overrides) == 'x.dpowi(3)'
    assert rust_code(x**3.2, user_functions=overrides) == 'x.pow(3.2)'
|
|
|
|
|
def test_constants():
    """Check the names printed for pi, the infinities, NaN and e."""
    cases = (
        (pi, "PI"),
        (oo, "INFINITY"),
        (S.Infinity, "INFINITY"),
        (-oo, "NEG_INFINITY"),
        (S.NegativeInfinity, "NEG_INFINITY"),
        (S.NaN, "NAN"),
        (exp(1), "E"),
        (S.Exp1, "E"),
    )
    for constant, expected in cases:
        assert rust_code(constant) == expected
|
|
|
|
|
def test_constants_other():
    """Check that GoldenRatio, Catalan and EulerGamma get a `const` declaration.

    The expected output is a `const NAME: f64 = value;` line, with the value
    taken from `evalf(17)`, followed by the expression that uses the name.
    """
    cases = (
        (GoldenRatio, "const GoldenRatio: f64 = %s;\n2.0*GoldenRatio"),
        (Catalan, "const Catalan: f64 = %s;\n2.0*Catalan"),
        (EulerGamma, "const EulerGamma: f64 = %s;\n2.0*EulerGamma"),
    )
    for constant, template in cases:
        assert rust_code(2*constant) == template % constant.evalf(17)
|
|
|
|
|
def test_boolean():
    """Check printing of boolean constants and of &, | and ~ on symbols."""
    cases = (
        (True, "true"),
        (S.true, "true"),
        (False, "false"),
        (S.false, "false"),
        (k & m, "k && m"),
        (k | m, "k || m"),
        (~k, "!k"),
        (k & m & n, "k && m && n"),
        (k | m | n, "k || m || n"),
        # Mixed operators: only the Or nested inside an And is parenthesised.
        ((k & m) | n, "n || k && m"),
        ((k | m) & n, "n && (k || m)"),
    )
    for expr, expected in cases:
        assert rust_code(expr) == expected
|
|
|
|
|
def test_Piecewise():
    """Check Piecewise printing as if/else, in block and inline form.

    Covers a two-branch and a three-branch Piecewise, with and without
    `assign_to`, and a Piecewise used as a factor in a larger expression.
    The last Piecewise has no `True` condition and is expected to raise
    ValueError.
    """
    two_way = Piecewise((x, x < 1), (x + 2, True))
    two_way_block = (
        "if (x < 1.0) {\n"
        " x\n"
        "} else {\n"
        " x + 2.0\n"
        "}")
    two_way_assigned = (
        "r = if (x < 1.0) {\n"
        " x\n"
        "} else {\n"
        " x + 2.0\n"
        "};")
    two_way_assigned_inline = "r = if (x < 1.0) { x } else { x + 2.0 };"
    assert rust_code(two_way) == two_way_block
    assert rust_code(two_way, assign_to="r") == two_way_assigned
    assert rust_code(two_way, assign_to="r", inline=True) == two_way_assigned_inline

    three_way = Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True))
    three_way_inline = (
        "if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 }")
    three_way_assigned_inline = (
        "r = if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 };")
    three_way_assigned = (
        "r = if (x < 1.0) {\n"
        " x\n"
        "} else if (x < 5.0) {\n"
        " x + 1.0\n"
        "} else {\n"
        " x + 2.0\n"
        "};")
    assert rust_code(three_way, inline=True) == three_way_inline
    assert rust_code(three_way, assign_to="r", inline=True) == three_way_assigned_inline
    assert rust_code(three_way, assign_to="r") == three_way_assigned

    # Piecewise as a factor, then as part of a sum.
    scaled = 2*three_way
    assert rust_code(scaled, inline=True) == (
        "2.0*if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 }")
    shifted = scaled - 42
    assert rust_code(shifted, inline=True) == (
        "2.0*if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 } - 42.0")

    # No branch with a True condition.
    no_default = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
    with raises(ValueError):
        rust_code(no_default)
|
|
|
|
|
def test_dereference_printing():
    """Symbols listed in `dereference` are printed wrapped as `(*name)`."""
    total = x + y + sin(z) + z
    printed = rust_code(total, dereference=[z])
    assert printed == "x + y + (*z) + (*z).sin()"
|
|
|
|
|
def test_sign():
    """Check that sign() prints as an inline if/else around `.signum()`.

    Three shapes are covered: sign as a factor, sign of a sum inside an
    addition, and sign of a function call on its own.
    """
    as_factor = sign(x) * y
    assert rust_code(as_factor) == (
        "y*(if (x == 0.0) { 0.0 } else { (x).signum() }) as f64")
    assert rust_code(as_factor, assign_to='r') == (
        "r = y*(if (x == 0.0) { 0.0 } else { (x).signum() }) as f64;")

    as_term = sign(x + y) + 42
    assert rust_code(as_term) == (
        "(if (x + y == 0.0) { 0.0 } else { (x + y).signum() }) + 42")
    assert rust_code(as_term, assign_to='r') == (
        "r = (if (x + y == 0.0) { 0.0 } else { (x + y).signum() }) + 42;")

    of_call = sign(cos(x))
    assert rust_code(of_call) == (
        "(if (x.cos() == 0.0) { 0.0 } else { (x.cos()).signum() })")
|
|
|
|
|
def test_reserved_words():
    """Check how a symbol named after the Rust keyword `if` is printed.

    By default the name gets a trailing underscore; `reserved_word_suffix`
    replaces that suffix, and `error_on_reserved=True` is expected to raise
    ValueError instead.
    """
    _, keyword_sym = symbols("x if")
    expr = sin(keyword_sym)

    cases = (
        ({}, "if_.sin()"),
        ({"dereference": [keyword_sym]}, "(*if_).sin()"),
        ({"reserved_word_suffix": '_unreserved'}, "if_unreserved.sin()"),
    )
    for settings, expected in cases:
        assert rust_code(expr, **settings) == expected

    raises(ValueError, lambda: rust_code(expr, error_on_reserved=True))
|
|
|
|
|
def test_ITE():
    """An ITE(condition, a, b) prints as a Rust if/else block."""
    expr = ITE(k < 1, m, n)
    expected = (
        "if (k < 1) {\n"
        " m\n"
        "} else {\n"
        " n\n"
        "}")
    assert rust_code(expr) == expected
|
|
|
|
|
def test_Indexed():
    """Indexed objects with 1, 2 and 3 indices print with one flat index."""
    dim_n, dim_m, dim_o = symbols('n m o', integer=True)
    i = Idx('i', dim_n)
    j = Idx('j', dim_m)
    k = Idx('k', dim_o)

    cases = (
        (IndexedBase('x')[j], "x[j]"),
        (IndexedBase('A')[i, j], "A[m*i + j]"),
        (IndexedBase('B')[i, j, k], "B[m*o*i + o*j + k]"),
    )
    for indexed, expected in cases:
        assert rust_code(indexed) == expected
|
|
|
|
|
def test_dummy_loops():
    """Assigning one Indexed to another over a Dummy-based Idx emits a loop.

    The expected text uses the plain names `i` and `m` for the Dummy symbols.
    """
    label, extent = symbols('i m', integer=True, cls=Dummy)
    source = IndexedBase('x')
    target = IndexedBase('y')
    idx = Idx(label, extent)

    expected = (
        "for i in 0..m {\n"
        " y[i] = x[i];\n"
        "}")
    assert rust_code(source[idx], assign_to=target[idx]) == expected
|
|
|
|
|
def test_loops():
    """Check the loops generated for a product with a repeated index.

    The expected code first initialises y[i] in its own loop, then
    accumulates A[i, j]*x[j] in a nested loop over i and j.
    """
    rows, cols = symbols('m n', integer=True)
    A = IndexedBase('A')
    x = IndexedBase('x')
    y = IndexedBase('y')
    z = IndexedBase('z')
    i = Idx('i', rows)
    j = Idx('j', cols)

    # Pure contraction: y[i] starts from 0.
    product_only = (
        "for i in 0..m {\n"
        " y[i] = 0;\n"
        "}\n"
        "for i in 0..m {\n"
        " for j in 0..n {\n"
        " y[i] = A[n*i + j]*x[j] + y[i];\n"
        " }\n"
        "}")
    assert rust_code(A[i, j]*x[j], assign_to=y[i]) == product_only

    # Contraction plus extra terms: y[i] starts from those terms.
    product_plus_terms = (
        "for i in 0..m {\n"
        " y[i] = x[i] + z[i];\n"
        "}\n"
        "for i in 0..m {\n"
        " for j in 0..n {\n"
        " y[i] = A[n*i + j]*x[j] + y[i];\n"
        " }\n"
        "}")
    assert rust_code(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i]) == product_plus_terms
|
|
|
|
|
def test_loops_multiple_contractions():
    """Check the nested loops generated for a product with three repeated indices.

    The flat index text inside a[...] and b[...] is obtained by formatting
    the corresponding index expressions with %s.
    """
    n, m, o, p = symbols('n m o p', integer=True)
    a = IndexedBase('a')
    b = IndexedBase('b')
    y = IndexedBase('y')
    i = Idx('i', m)
    j = Idx('j', n)
    k = Idx('k', o)
    l = Idx('l', p)

    a_flat = i*n*o*p + j*o*p + k*p + l
    b_flat = j*o*p + k*p + l

    # Everything up to and including the assignment carries the two %s slots.
    head = (
        "for i in 0..m {\n"
        " y[i] = 0;\n"
        "}\n"
        "for i in 0..m {\n"
        " for j in 0..n {\n"
        " for k in 0..o {\n"
        " for l in 0..p {\n"
        " y[i] = a[%s]*b[%s] + y[i];\n")
    tail = (
        " }\n"
        " }\n"
        " }\n"
        "}")
    expected = head % (a_flat, b_flat) + tail

    assert rust_code(b[j, k, l]*a[i, j, k, l], assign_to=y[i]) == expected
|
|
|
|
|
def test_loops_addfactor():
    """Check the nested loops generated for (a + b)*c with repeated indices.

    The flat index text inside a[...], b[...] and c[...] is obtained by
    formatting the corresponding index expressions with %s.
    """
    m, n, o, p = symbols('m n o p', integer=True)
    a = IndexedBase('a')
    b = IndexedBase('b')
    c = IndexedBase('c')
    y = IndexedBase('y')
    i = Idx('i', m)
    j = Idx('j', n)
    k = Idx('k', o)
    l = Idx('l', p)

    full_flat = i*n*o*p + j*o*p + k*p + l
    inner_flat = j*o*p + k*p + l

    # Everything up to and including the assignment carries the three %s slots.
    head = (
        "for i in 0..m {\n"
        " y[i] = 0;\n"
        "}\n"
        "for i in 0..m {\n"
        " for j in 0..n {\n"
        " for k in 0..o {\n"
        " for l in 0..p {\n"
        " y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n")
    tail = (
        " }\n"
        " }\n"
        " }\n"
        "}")
    expected = head % (full_flat, full_flat, inner_flat) + tail

    code = rust_code((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i])
    assert code == expected
|
|
|
|
|
def test_settings():
    """Passing the unknown setting `method` is expected to raise TypeError."""
    with raises(TypeError):
        rust_code(sin(x), method="garbage")
|
|
|
|
|
def test_inline_function():
    """Implemented functions are printed by their Lambda body, not by name.

    Covers a plain body, a body using Catalan (which adds a `const`
    declaration to the output), and a body applied to an Indexed argument
    with `assign_to` (which wraps the assignment in a loop).
    """
    sym = symbols('x')

    doubled = implemented_function('g', Lambda(sym, 2*sym))
    assert rust_code(doubled(sym)) == "2*x"

    over_catalan = implemented_function('g', Lambda(sym, 2*sym/Catalan))
    catalan_code = "const Catalan: f64 = %s;\n2.0*x/Catalan" % Catalan.evalf(17)
    assert rust_code(over_catalan(sym)) == catalan_code

    A = IndexedBase('A')
    i = Idx('i', symbols('n', integer=True))
    cubic = implemented_function('g', Lambda(sym, sym*(1 + sym)*(2 + sym)))
    loop_code = (
        "for i in 0..n {\n"
        " A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
        "}")
    assert rust_code(cubic(A[i]), assign_to=A[i]) == loop_code
|
|
|
|
|
def test_user_functions():
    """Check `user_functions` given as a plain name and as a condition list.

    "ceiling" maps to a single name; "Abs" maps to two (condition, name, ...)
    entries chosen by whether the argument is an integer.
    """
    real_sym = symbols('x', integer=False)
    int_sym = symbols('n', integer=True)
    custom_functions = {
        "ceiling": "ceil",
        "Abs": [(lambda x: not x.is_integer, "fabs", 4), (lambda x: x.is_integer, "abs", 4)],
    }

    cases = (
        (ceiling(real_sym), "x.ceil()"),
        (Abs(real_sym), "fabs(x)"),
        (Abs(int_sym), "abs(n)"),
    )
    for expr, expected in cases:
        assert rust_code(expr, user_functions=custom_functions) == expected
|
|
|
|
|
def test_matrix():
    """Matrix([1, 2, 3]) prints as '[1, 2, 3]'; Matrix([[1, 2, 3]]) raises ValueError."""
    flat = Matrix([1, 2, 3])
    assert rust_code(flat) == '[1, 2, 3]'

    raises(ValueError, lambda: rust_code(Matrix([[1, 2, 3]])))
|
|
|
|
|
def test_sparse_matrix():
    """Printing a SparseMatrix is expected to raise NotImplementedError."""
    raises(NotImplementedError, lambda: rust_code(SparseMatrix([[1, 2, 3]])))
|
|