|
"""Tests of tools for setting up interactive IPython sessions. """ |
|
|
|
from sympy.interactive.session import (init_ipython_session, |
|
enable_automatic_symbols, enable_automatic_int_sympification) |
|
|
|
from sympy.core import Symbol, Rational, Integer |
|
from sympy.external import import_module |
|
from sympy.testing.pytest import raises |
|
|
|
|
|
|
|
|
|
|
|
|
|
# Every test below drives a live IPython shell, so IPython is required and
# only versions >= 1.0 are accepted.
ipython = import_module("IPython", min_module_version="1.0")

if not ipython:
    # NOTE(review): presumably read by SymPy's test runner to skip this whole
    # module when IPython is missing or too old -- confirm.
    disabled = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_automatic_symbols():
    # Cells that rely on the automatic-symbols hook are run with
    # store_history=True; presumably the hook re-executes the stored input
    # (see enable_automatic_symbols).  Such cells must not produce an Out
    # value, otherwise it would be printed while the tests run.
    shell = init_ipython_session()
    shell.run_cell("from sympy import *")
    enable_automatic_symbols(shell)

    name = "verylongsymbolname"

    # An unknown name can be used inside a cell, but is not left behind in
    # the user namespace afterwards.
    assert name not in shell.user_ns
    shell.run_cell("a = {}".format(name), store_history=True)
    assert name not in shell.user_ns
    shell.run_cell("a = type({})".format(name), store_history=True)
    assert shell.user_ns['a'] == Symbol

    # An explicit assignment does persist.
    shell.run_cell("{0} = Symbol('{0}')".format(name), store_history=True)
    assert name in shell.user_ns

    # Built-in names must not be shadowed by auto-created Symbols.
    shell.run_cell("a = all == __builtin__.all", store_history=True)
    assert "all" not in shell.user_ns
    assert shell.user_ns['a'] is True

    # Names exported by SymPy must not be shadowed either.
    shell.run_cell("import sympy")
    shell.run_cell("a = factorial == sympy.factorial", store_history=True)
    assert shell.user_ns['a'] is True
|
|
|
|
|
def test_int_to_Integer():
    # Check result types with isinstance rather than ==: SymPy considers
    # 0.5 == Rational(1, 2) to be True, so an equality check alone cannot
    # tell whether integer literals were actually sympified.
    app = init_ipython_session()
    app.run_cell("from sympy import Integer")
    app.run_cell("a = 1")
    assert isinstance(app.user_ns['a'], int)

    enable_automatic_int_sympification(app)
    app.run_cell("a = 1/2")
    assert isinstance(app.user_ns['a'], Rational)
    app.run_cell("a = 1")
    assert isinstance(app.user_ns['a'], Integer)
    # An explicit int() call is left alone.
    app.run_cell("a = int(1)")
    assert isinstance(app.user_ns['a'], int)

    # Literals must also be converted when the expression spans two lines;
    # without the isinstance check a plain float 0.5 would also pass.
    app.run_cell("a = (1/\n2)")
    assert isinstance(app.user_ns['a'], Rational)
    assert app.user_ns['a'] == Rational(1, 2)
|
|
|
|
|
|
|
|
|
def test_ipythonprinting():
    # Initialize and set up the IPython session.
    app = init_ipython_session()
    app.run_cell("ip = get_ipython()")
    app.run_cell("inst = ip.instance()")
    app.run_cell("format = inst.display_formatter.format")
    app.run_cell("from sympy import Symbol")

    # IPython >= 1.0 is guaranteed by the import_module() call at module
    # level, so format() returns a (data, metadata) pair and the rendered
    # output lives in item 0; no pre-1.0 fallback is needed.

    # Printing without SymPy's printing extension: plain str() output.
    app.run_cell("a = format(Symbol('pi'))")
    app.run_cell("a2 = format(Symbol('pi')**2)")
    assert app.user_ns['a'][0]['text/plain'] == "pi"
    assert app.user_ns['a2'][0]['text/plain'] == "pi**2"

    # Printing with SymPy's printing extension: pretty output, Unicode or
    # ASCII depending on what the terminal supports.
    app.run_cell("from sympy import init_printing")
    app.run_cell("init_printing()")
    app.run_cell("a = format(Symbol('pi'))")
    app.run_cell("a2 = format(Symbol('pi')**2)")
    assert app.user_ns['a'][0]['text/plain'] in ('\N{GREEK SMALL LETTER PI}', 'pi')
    assert app.user_ns['a2'][0]['text/plain'] in (' 2\n\N{GREEK SMALL LETTER PI} ', '  2\npi ')
|
|
|
|
|
def test_print_builtin_option():
    # init_printing(print_builtin=...) controls whether builtin containers that
    # hold SymPy objects get SymPy's LaTeX rendering.
    app = init_ipython_session()
    app.run_cell("ip = get_ipython()")
    app.run_cell("inst = ip.instance()")
    app.run_cell("format = inst.display_formatter.format")
    app.run_cell("from sympy import Symbol")
    app.run_cell("from sympy import init_printing")

    # Accepted plain-text renderings of the dict formatted below: either key
    # order, with or without Unicode pretty printing.
    dict_texts = (
        "{pi: 3.14, n_i: 3}",
        '{n\N{LATIN SUBSCRIPT SMALL LETTER I}: 3, \N{GREEK SMALL LETTER PI}: 3.14}',
        "{n_i: 3, pi: 3.14}",
        '{\N{GREEK SMALL LETTER PI}: 3.14, n\N{LATIN SUBSCRIPT SMALL LETTER I}: 3}',
    )

    # IPython >= 1.0 is guaranteed by the import_module() call at module
    # level, so format() returns a (data, metadata) pair and the rendered
    # output lives in item 0; no pre-1.0 fallback is needed.

    # Before the LaTeX formatter is enabled there is no LaTeX rendering.
    app.run_cell("a = format({Symbol('pi'): 3.14, Symbol('n_i'): 3})")
    text = app.user_ns['a'][0]['text/plain']
    raises(KeyError, lambda: app.user_ns['a'][0]['text/latex'])
    assert text in dict_texts

    # With the LaTeX formatter on and use_latex=True, the dict is rendered
    # as LaTeX as well.
    app.run_cell("inst.display_formatter.formatters['text/latex'].enabled = True")
    app.run_cell("init_printing(use_latex=True)")
    app.run_cell("a = format({Symbol('pi'): 3.14, Symbol('n_i'): 3})")
    text = app.user_ns['a'][0]['text/plain']
    latex = app.user_ns['a'][0]['text/latex']
    assert text in dict_texts
    assert latex == r'$\displaystyle \left\{ n_{i} : 3, \  \pi : 3.14\right\}$'

    # Objects with an _latex overload are handled by the tuple printer too.
    # The cell source starts at column 0 so it is valid Python as-is.
    app.run_cell("""\
class WithOverload:
    def _latex(self, printer):
        return r"\\LaTeX"
""")
    app.run_cell("a = format((WithOverload(),))")
    latex = app.user_ns['a'][0]['text/latex']
    assert latex == r'$\displaystyle \left( \LaTeX,\right)$'

    # print_builtin=False: builtin containers get no LaTeX rendering.
    app.run_cell("inst.display_formatter.formatters['text/latex'].enabled = True")
    app.run_cell("init_printing(use_latex=True, print_builtin=False)")
    app.run_cell("a = format({Symbol('pi'): 3.14, Symbol('n_i'): 3})")
    text = app.user_ns['a'][0]['text/plain']
    raises(KeyError, lambda: app.user_ns['a'][0]['text/latex'])
    # Only the plain ASCII forms are accepted here; presumably the dict is no
    # longer routed through SymPy's pretty printer.
    assert text in ("{pi: 3.14, n_i: 3}", "{n_i: 3, pi: 3.14}")
|
|
|
|
|
def test_builtin_containers():
    # Containers without SymPy objects must not get SymPy's pretty/LaTeX
    # treatment, while a tuple holding a Matrix must.
    app = init_ipython_session()
    app.run_cell("ip = get_ipython()")
    app.run_cell("inst = ip.instance()")
    app.run_cell("format = inst.display_formatter.format")
    app.run_cell("inst.display_formatter.formatters['text/latex'].enabled = True")
    app.run_cell("from sympy import init_printing, Matrix")
    app.run_cell('init_printing(use_latex=True, use_unicode=False)')

    app.run_cell('a = format((True, False))')
    app.run_cell('import sys')
    app.run_cell('b = format(sys.flags)')
    app.run_cell('c = format((Matrix([1, 2]),))')

    # IPython >= 1.0 is guaranteed by the import_module() call at module
    # level, so format() returns a (data, metadata) pair and the rendered
    # output lives in item 0; no pre-1.0 fallback is needed.
    a = app.user_ns['a'][0]
    b = app.user_ns['b'][0]
    c = app.user_ns['c'][0]

    # No SymPy objects inside: plain repr and no LaTeX representation.
    assert a['text/plain'] == '(True, False)'
    assert 'text/latex' not in a
    assert b['text/plain'].startswith('sys.flags(')
    assert 'text/latex' not in b

    # A 1-tuple holding a 2x1 Matrix, pretty printed in ASCII
    # (use_unicode=False); each list item is one row of the picture.
    assert c['text/plain'] == '\n'.join([
        ' [1]  ',
        '([ ],)',
        ' [2]  ',
    ])
    assert c['text/latex'] == '$\\displaystyle \\left( \\left[\\begin{matrix}1\\\\2\\end{matrix}\\right],\\right)$'
|
|
|
def test_matplotlib_bad_latex():
    # Formatting through init_printing(use_latex='matplotlib') with the PNG
    # formatter enabled; warnings are escalated to errors to surface problems.
    # NOTE(review): run_cell appears to report exceptions itself rather than
    # raise them, so a failure here may not fail the test -- TODO confirm.
    app = init_ipython_session()
    setup_cells = (
        "import IPython",
        "ip = get_ipython()",
        "inst = ip.instance()",
        "format = inst.display_formatter.format",
        "from sympy import init_printing, Matrix",
        "init_printing(use_latex='matplotlib')",
        # Enable the PNG formatter explicitly (presumably off by default here).
        "inst.display_formatter.formatters['image/png'].enabled = True",
        "import warnings",
    )
    for cell in setup_cells:
        app.run_cell(cell)

    # From IPython 2 on, only FormatterWarning is escalated; older versions
    # get a blanket "error" filter instead.
    if int(ipython.__version__.split(".")[0]) >= 2:
        app.run_cell("warnings.simplefilter('error', IPython.core.formatters.FormatterWarning)")
    else:
        app.run_cell("warnings.simplefilter('error')")

    app.run_cell("a = format(Matrix([1, 2, 3]))")

    # A Piecewise expression must format cleanly as well.
    app.run_cell("from sympy import Piecewise, Symbol, Eq")
    app.run_cell("x = Symbol('x'); pw = format(Piecewise((1, Eq(x, 0)), (0, True)))")
|
|
|
|
|
def test_override_repr_latex():
    # A Symbol subclass that overrides _repr_latex_ must have that override
    # used by the LaTeX formatter after init_printing(use_latex=True).
    app = init_ipython_session()
    app.run_cell("import IPython")
    app.run_cell("ip = get_ipython()")
    app.run_cell("inst = ip.instance()")
    app.run_cell("format = inst.display_formatter.format")
    app.run_cell("inst.display_formatter.formatters['text/latex'].enabled = True")
    app.run_cell("from sympy import init_printing")
    app.run_cell("from sympy import Symbol")
    app.run_cell("init_printing(use_latex=True)")
    # The cell source starts at column 0 so it is valid Python as-is.
    app.run_cell("""\
class SymbolWithOverload(Symbol):
    def _repr_latex_(self):
        return r"Hello " + super()._repr_latex_() + " world"
""")
    app.run_cell("a = format(SymbolWithOverload('s'))")

    # IPython >= 1.0 is guaranteed by the import_module() call at module
    # level, so format() returns a (data, metadata) pair and the rendered
    # output lives in item 0; no pre-1.0 fallback is needed.
    latex = app.user_ns['a'][0]['text/latex']
    assert latex == r'Hello $\displaystyle s$ world'
|
|